import networkx as nx

def npolitges(g):
    """Return how many edges of ``g`` have their 'unio' attribute set to 'POL'.

    ``g`` must expose ``edges()`` yielding node pairs and support
    ``g[a][b]['unio']`` lookups (as networkx graphs do). Every edge is
    expected to carry a 'unio' attribute; a missing one raises ``KeyError``.
    """
    return sum(1 for a, b in g.edges() if g[a][b]['unio'] == 'POL')

def quantes_unions_1(ge, o, d, un):
    """Count the edges of kind ``un`` on a shortest path from ``o`` to ``d``.

    Args:
        ge: networkx graph whose edges carry a 'unio' attribute.
        o: source node.
        d: destination node.
        un: value of 'unio' to count.

    Returns:
        The number of edges along ``nx.shortest_path(ge, o, d)`` whose
        'unio' attribute equals ``un``; 0 when ``d`` cannot be reached from
        ``o`` (and also when ``o == d``, since the path has no edges).
    """
    # nx.has_path is itself implemented with a shortest-path search, so
    # asking first and searching afterwards walks the graph twice. Search
    # once and treat "no path" as a count of zero.
    try:
        cami = nx.shortest_path(ge, o, d)
    except nx.NetworkXNoPath:
        return 0

    q = 0
    # Consecutive nodes of the path are the endpoints of each traversed edge.
    for ni, n in zip(cami, cami[1:]):
        if ge.edges[ni, n]['unio'] == un:
            q += 1
    return q

def quantes_unions_2(ge, o, d, un):
    """Count the edges of kind ``un`` on a shortest path from ``o`` to ``d``.

    Functional-style variant: the path is turned into a stream of edges,
    then into a stream of 'unio' values, and the matching ones are counted.

    Args:
        ge: networkx graph whose edges carry a 'unio' attribute.
        o: source node.
        d: destination node.
        un: value of 'unio' to count.

    Returns:
        The number of matching edges on ``nx.shortest_path(ge, o, d)``, or
        0 when there is no path between ``o`` and ``d``.
    """
    # A single search is enough: nx.has_path would run the same search
    # again just to answer yes/no.
    try:
        cami = nx.shortest_path(ge, o, d)
    except nx.NetworkXNoPath:
        return 0

    # Each (node, next node) pair is an edge key accepted by ge.edges[...].
    arestes = zip(cami, cami[1:])
    unions = (ge.edges[a]['unio'] for a in arestes)
    return sum(1 for u in unions if u == un)

#
# Choose the solution you want to try: rebind this alias to either
# quantes_unions_1 or quantes_unions_2.
#
quantes_unions = quantes_unions_1

def moviment(g, r):
    """Return the nodes that share a connected component with ``r``.

    The result is a new set holding every node of ``g`` reachable from
    ``r``, with ``r`` itself left out.
    """
    altres = set(nx.node_connected_component(g, r))
    altres.discard(r)
    return altres

def grup_d_engranatges_1(ge, rd, un):
    """Return the nodes reachable from ``rd`` through edges of kind ``un``.

    An auxiliary graph is built with only the edges of ``ge`` whose 'unio'
    attribute equals ``un``; the answer is the connected component of
    ``rd`` in that graph, without ``rd`` itself.
    """
    nomes_un = nx.Graph()
    # rd is added explicitly so it exists even when no 'un' edge touches it.
    nomes_un.add_node(rd)
    nomes_un.add_edges_from(
        (a, b) for a, b, unio in ge.edges.data('unio') if unio == un
    )
    grup = nx.node_connected_component(nomes_un, rd)
    grup.discard(rd)
    return grup

def grup_d_engranatges_2(ge, rd, un):
    """Return the nodes reachable from ``rd`` through edges of kind ``un``.

    Loop-based variant: edges of ``ge`` whose 'unio' attribute equals
    ``un`` are copied one by one into a fresh graph, and the connected
    component of ``rd`` in that graph (minus ``rd``) is returned.
    """
    filtrat = nx.Graph()
    # Guarantees rd is a node of the filtered graph even if it ends up isolated.
    filtrat.add_node(rd)
    for aresta in ge.edges.data('unio'):
        origen, desti, unio = aresta
        if unio != un:
            continue
        filtrat.add_edge(origen, desti)
    return nx.node_connected_component(filtrat, rd) - {rd}

def grup_d_engranatges_3(gm, rd, un):
    """Return the nodes of ``gm`` that ``units_per`` reports as joined to ``rd``.

    Every node other than ``rd`` is tested with
    ``units_per(gm, rd, node, un)``; those for which the call is truthy
    make up the returned set.
    """
    return {
        candidat
        for candidat in gm
        if candidat != rd and units_per(gm, rd, candidat, un)
    }

def units_per(gm, orig, dest, un):
    """Tell whether ``orig`` and ``dest`` are linked using only ``un`` edges.

    Args:
        gm: networkx graph whose edges carry a 'unio' attribute.
        orig: starting node.
        dest: target node.
        un: the 'unio' value every edge of the linking path must have.

    Returns:
        True when some path from ``orig`` to ``dest`` consists exclusively
        of edges whose 'unio' equals ``un`` (trivially True when
        ``orig == dest``), False otherwise.

    Raises:
        nx.NodeNotFound: if ``orig`` or ``dest`` is not a node of ``gm``.
    """
    for extrem in (orig, dest):
        if extrem not in gm:
            raise nx.NodeNotFound(
                "Node {} is not in the graph".format(extrem)
            )

    # Looking at one shortest path of the full graph is not enough: when
    # the graph has cycles, a shortest path may cross an edge of another
    # kind even though a path made only of 'un' edges exists, and which
    # shortest path is returned is arbitrary. Searching in the graph
    # restricted to 'un' edges gives the same answer as the connected
    # component computed by the other grup_d_engranatges_* variants.
    nomes_un = nx.Graph()
    # Endpoints are added explicitly so they exist even without 'un' edges.
    nomes_un.add_nodes_from((orig, dest))
    nomes_un.add_edges_from(
        (a, b) for a, b, unio in gm.edges.data('unio') if unio == un
    )
    return nx.has_path(nomes_un, orig, dest)

def grup_d_engranatges_4(gm, rd, un):
    """Return the nodes reachable from ``rd`` through edges of kind ``un``.

    Hand-written traversal that needs only mapping-style access:
    ``gm[node]`` must map each neighbour to the edge data, and the edge
    data must contain a 'unio' entry. ``rd`` is never part of the result.
    """
    pendents = {rd}
    explorats = set()
    grup = set()
    while pendents:
        actual = pendents.pop()
        explorats.add(actual)
        # Neighbours joined to the current node by an edge of the wanted kind.
        connectats = {
            vei for vei, dades in gm[actual].items() if dades['unio'] == un
        }
        grup |= connectats
        # Only nodes not yet expanded need to be queued.
        pendents |= connectats - explorats
    grup.discard(rd)
    return grup

#
# Choose the solution you want to try: rebind this alias to any of
# grup_d_engranatges_1 .. grup_d_engranatges_4.
#
grup_d_engranatges = grup_d_engranatges_1
