import itertools
import networkx as nx


def afegir_linia(g, noml, parades):
    nx.add_path(g, parades, linia=noml)


def crea_xarxa_metro(nomf='metrobcn.txt'):
    g = nx.Graph()
    with open(nomf, 'r') as f :
        parades = []
        for lin in f:
            lin = lin.strip()
            if lin[0] == '=':
                afegir_linia(g, lin[1:], parades)
                parades = []
            else:
                parades.append(lin)
    return g


def cjt_parades_1(g, noml):
    # Recorregut utilitzant l'iterador d'arestes
    s = set()
    for a, b in g.edges():
        if g[a][b]['linia'] == noml:
            s.add(a)
            s.add(b)
    return s

def cjt_parades_2(g, noml):
    # Solució amb iteradors
    arestes = g.edges(data='linia')
    arestes_linia = filter(lambda a: a[2] == noml, arestes)
    arestes_linia_parell_nodes = map(lambda a: a[:2], arestes_linia)
    nodes_linia = itertools.chain.from_iterable(arestes_linia_parell_nodes)
    cjt_nodes = set(nodes_linia)
    return cjt_nodes

def cjt_parades_3(g, noml):
    # Recorregut clàssic: per cada node, per cada vei 
    s = set()
    for node in g:
        for vei in g[node]:
            if g[node][vei]['linia'] == noml:
                s.add(node)
                s.add(vei)
    return s

# Tria la solució de cjt_parades que vols provar
#cjt_parades = cjt_parades_1
#cjt_parades = cjt_parades_2
cjt_parades = cjt_parades_3


def cjt_linies_1(g, parada):
    s = set()
    for veina in g[parada] :
        s.add(g[parada][veina]['linia'])
    return s
    
def cjt_linies_2(g, parada):
    # Equivalent a la versió anterior però amb "generator expressions"
    return set( g[parada][veina]['linia'] for veina in g[parada])

def cjt_linies_3(g, parada):
    # Solució amb iteradors
    arestes_parada = g.edges(parada, data='linia')
    linies_parada = map(lambda a: a[2], arestes_parada)
    cjt_l = set(linies_parada)
    return cjt_l
    
# Tria la solució de cjt_linies que vols provar
#cjt_linies = cjt_linies_1
#cjt_linies = cjt_linies_2
cjt_linies = cjt_linies_3


def recorregut_1(g, pini, pfi):
    try:
        return nx.shortest_path(g, pini, pfi)
    except (nx.exception.NetworkXError, nx.exception.NetworkXNoPath):
        return []

def recorregut_2(g, pini, pfi):
    if nx.has_path(g, pini, pfi):
        r = nx.shortest_path(g, pini, pfi)
    else:
        r = []
    return r
    
# Tria la solució de recorregut que vols provar
#recorregut = recorregut_1
recorregut = recorregut_2


def llista_parades_1(g, linia):
    parades = cjt_parades(g, linia)
    sg = nx.Graph.subgraph(g, parades)
    extrems = []
    for p in parades :
        if  sg.degree(p) == 1 :
            extrems.append(p)
    pini, pfi = extrems
    if pini > pfi:
        pini, pfi = pfi, pini
    return nx.shortest_path(sg, pini, pfi)

def llista_parades_2(g, linia):
    # Equivalent a la versió anterior però amb "generator expressions"
    parades = cjt_parades(g, linia)
    sg = nx.Graph.subgraph(g, parades)
    pini, pfi = (p for p in parades if sg.degree(p) == 1)
    if pini > pfi:
        pini, pfi = pfi, pini
    return nx.shortest_path(sg, pini, pfi)

def llista_parades_3(g, linia):
    # Equivalent a la versió anterior però amb filter
    parades = cjt_parades(g, linia)
    sg = nx.Graph.subgraph(g, parades)
    pini, pfi = filter(lambda p: sg.degree(p) == 1, parades)
    if pini > pfi:
        pini, pfi = pfi, pini
    return nx.shortest_path(sg, pini, pfi)

# Tria la solució de llista_parades que vols provar
#llista_parades = llista_parades_1
#llista_parades = llista_parades_2
llista_parades = llista_parades_3
