import networkx as nx
import numpy as np
import random
import matplotlib.pyplot as plt
from robustness import *
from rewire import *
from edge_addition import *


def recover_to_initial_diameter(initial_diameter, initial_lcc, attacked_graph, recovery_option=0,
                                max_edges=None):
    """Apply recovery steps to a copy of ``attacked_graph`` until LCC size and diameter are restored.

    Recovery runs in two phases:

    1. ``recover_edge`` is applied until the largest connected component is at least
       ``initial_lcc`` in size.
    2. ``recover_edge`` is applied until the diameter is no larger than ``initial_diameter``.

    :param initial_diameter: diameter of the graph before the attack.
    :param initial_lcc: LCC size of the graph before the attack.
    :param attacked_graph: the damaged graph; a copy is recovered, the argument is not
        modified by this function.
    :param recovery_option: strategy forwarded to ``recover_edge``.
        NOTE(review): the default ``0`` is presumably one of the option constants
        (e.g. RANDOM_ADD) — confirm.
    :param max_edges: optional upper bound on the number of recovery steps across both
        phases. ``None`` (the default) means unbounded. With an option that never changes
        the graph (NO_ADD / NO_REWIRE), an unbounded run never terminates.
    :return: ``(recovered_graph, num_edges)``, where ``num_edges`` is the number of recovery
        steps that were applied.
    """
    d, lcc, _ = get_robustness(attacked_graph)
    new_diameter = d
    num_edges = 0

    recovered_graph = attacked_graph.copy()

    # Phase 1: make sure the largest connected component is back to its original size.
    while lcc < initial_lcc and (max_edges is None or num_edges < max_edges):
        recovered_graph = recover_edge(recovery_option, recovered_graph)
        num_edges += 1
        lcc = get_LCC_size(recovered_graph)

    # Reconnecting components changes the diameter, so the value measured on the attacked
    # graph is stale once phase 1 added anything; re-measure before phase 2.
    if num_edges:
        new_diameter = get_diameter(recovered_graph)

    # Phase 2: shrink the diameter back to (at most) its original value.
    while new_diameter > initial_diameter and (max_edges is None or num_edges < max_edges):
        recovered_graph = recover_edge(recovery_option, recovered_graph)
        num_edges += 1
        new_diameter = get_diameter(recovered_graph)

    return recovered_graph, num_edges


def recover_to_initial_diameter_lcc_ratio(initial_diameter, initial_lcc, attacked_graph, recovery_option=RANDOM_ADD):
    """Apply recovery steps until the diameter / LCC-size ratio is back to its pre-attack value.

    A lower ratio is treated as a more robust graph. Recovery steps are applied to a copy of
    ``attacked_graph`` while the ratio is above ``initial_diameter / initial_lcc``.

    The process gives up early when either of these holds:

    * ``2 * size`` steps have been applied, where ``size`` is ``attacked_graph.size()``
      (the number of edges);
    * ``0.1 * size`` consecutive steps have failed to lower the ratio.

    :param initial_diameter: diameter of the graph before the attack.
    :param initial_lcc: LCC size of the graph before the attack (must be non-zero).
    :param attacked_graph: the damaged graph; a copy is recovered.
    :param recovery_option: strategy forwarded to ``recover_edge``.
    :return: ``(recovered_graph, num_edges)``, where ``num_edges`` is the number of recovery
        steps that were applied.
    """
    d, lcc, _ = get_robustness(attacked_graph)
    num_edges = 0

    initial_d_lcc = initial_diameter / initial_lcc
    new_d_lcc = d / lcc

    recovered_graph = attacked_graph.copy()

    size = recovered_graph.size()
    max_edges = size * 2        # hard cap on the number of recovery steps
    patience = size * 0.1       # consecutive non-improving steps tolerated

    stall = 0
    # Give up if too many edges are added or the robustness doesn't increase anymore.
    while new_d_lcc > initial_d_lcc and num_edges < max_edges and stall < patience:
        recovered_graph = recover_edge(recovery_option, recovered_graph)
        num_edges += 1
        d, lcc, _ = get_robustness(recovered_graph)

        old_d_lcc = new_d_lcc
        new_d_lcc = d / lcc
        # Robustness improves when the ratio goes DOWN: reset the stall counter only on an
        # actual improvement, count every other step as a stall.
        stall = 0 if new_d_lcc < old_d_lcc else stall + 1

    return recovered_graph, num_edges


def recover_edge(recovery_option, recovered_graph):
    """Apply a single recovery step of the selected kind and return the resulting graph.

    The option is matched, in order, against the edge-addition constants (RANDOM_ADD,
    PREFERENTIAL_MIN_MIN, PREFERENTIAL_MIN_MAX, PREFERENTIAL_MAX_MAX, PREFERENTIAL_MAX_RAND).
    Then it is matched against the rewiring constants (RANDOM_REWIRE, PREFERENTIAL_REW,
    PREFERENTIAL_RANDOM). The first match dispatches to the corresponding helper.

    :param recovery_option: one of the option constants listed above, or NO_REWIRE / NO_ADD
        to leave the graph untouched.
    :param recovered_graph: the graph to recover. Whether the helpers mutate it in place is
        not visible here — callers should always use the return value.
    :return: the graph returned by the selected helper, or ``recovered_graph`` itself for
        NO_REWIRE / NO_ADD.
    :raises ValueError: if ``recovery_option`` matches none of the known options.
    """
    # Edge addition
    if recovery_option == RANDOM_ADD:
        recovered_graph = add_random_edge(recovered_graph)
    elif recovery_option == PREFERENTIAL_MIN_MIN:
        recovered_graph = add_preferential_edge_min_min(recovered_graph)
    elif recovery_option == PREFERENTIAL_MIN_MAX:
        recovered_graph = add_preferential_edge_min_max(recovered_graph)
    elif recovery_option == PREFERENTIAL_MAX_MAX:
        recovered_graph = add_preferential_edge_max_max(recovered_graph)
    elif recovery_option == PREFERENTIAL_MAX_RAND:
        recovered_graph = add_preferential_edge_max_rand(recovered_graph)

    # Edge rewiring
    elif recovery_option == RANDOM_REWIRE:
        recovered_graph = random_rewire(recovered_graph)
    elif recovery_option == PREFERENTIAL_REW:
        recovered_graph = pref_rewire(recovered_graph)
    elif recovery_option == PREFERENTIAL_RANDOM:
        recovered_graph = pref_random_rewire(recovered_graph)

    # Do nothing
    elif recovery_option == NO_REWIRE or recovery_option == NO_ADD:
        pass
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError(f"Incorrect recovery option specified: {recovery_option!r}")

    return recovered_graph