#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Code adapté de flot.py d'Anthony Labarre
    
"""Implémentation d'un graphe orienté pondéré à l'aide d'un dictionnaire: les
clés sont les sommets, et les valeurs sont les successeurs du sommet donné. 
"""

# Imports ---------------------------------------------------------------------
from dictionnaireadjacenceorientepondere import DictionnaireAdjacenceOrientePondere
from queue import Queue


def reseau_residuel(reseau, flot):
    """Renvoie le réseau résiduel associé à un flot et à un réseau.
    
    >>> from dictionnaireadjacenceorientepondere import DictionnaireAdjacenceOrientePondere
    >>> G = DictionnaireAdjacenceOrientePondere()
    >>> # exemple de CLRS
    >>> G.ajouter_arcs([("s",  "v1", 16), ("s",  "v2", 13), ("v1", "v3", 12), ("v2", "v1",  4), ("v2", "v4", 14), ("v3", "v2",  9), ("v3", "t",  20), ("v4", "v3",  7),("v4", "t",   4)])
    >>> flot = dict.fromkeys({(u, v) for u, v, _ in G.arcs()}, 0)
    >>> H = reseau_residuel(G, flot)
    >>> sorted(H.arcs())
    [('s', 'v1', 16), ('s', 'v2', 13), ('v1', 'v3', 12), ('v2', 'v1', 4), ('v2', 'v4', 14), ('v3', 't', 20), ('v3', 'v2', 9), ('v4', 't', 4), ('v4', 'v3', 7)]
    """
    nouveau_reseau = type(reseau)()
    capacites_residuelles = dict()
    arcs = reseau.arcs()
    
    for u, v, capacite in arcs:
        capacites_residuelles[(u, v)] = capacite - flot[(u, v)]
    
    for v, u, _ in arcs:
        capacites_residuelles[(u, v)] = flot[(v, u)]
    
    for (u, v), capacite in capacites_residuelles.items():
        if capacite > 0:
            nouveau_reseau.ajouter_arc(u, v, capacite)
    return nouveau_reseau


def reconstruire_chemin_oriente_pondere(debut, fin, parents, graphe):
    """Renvoie un chemin entre les sommets debut et fin du graphe sur base des
    informations sur les parents, ou None si le chemin ne peut être
    reconstruit."""
    chemin = type(graphe)()
    sommet = fin
    
    while sommet != debut:
        if sommet not in parents:
            return None
        chemin.ajouter_arc(
            parents[sommet], sommet,
            graphe.poids_arc(parents[sommet], sommet)
        )
        sommet = parents[sommet]   
    return chemin


def chemin_augmentant(reseau, source, puits):
    """Renvoie un chemin augmentant dans le réseau donné, ou None s'il n'en
    existe pas.
    
    >>> from dictionnaireadjacenceorientepondere import DictionnaireAdjacenceOrientePondere
    >>> G = DictionnaireAdjacenceOrientePondere()
    >>> # exemple de CLRS
    >>> G.ajouter_arcs([("s",  "v1", 16), ("s",  "v2", 13), ("v1", "v3", 12), ("v2", "v1",  4), ("v2", "v4", 14), ("v3", "v2",  9), ("v3", "t",  20), ("v4", "v3",  7),("v4", "t",   4)])
    >>> sorted(chemin_augmentant(G, 's', 't').arcs()) in ([('s', 'v2', 13), ('v2', 'v4', 14), ('v4', 't', 4)], [('s', 'v1', 16), ('v1', 'v3', 12), ('v3', 't', 20)])
    True
    """
    # on construit l'arbre de parcours en largeur à partir de la source, et on
    # renvoie le chemin obtenu de la source au puits s'il existe -- None sinon
    deja_visites = dict.fromkeys(reseau.sommets(), False)
    a_traiter = Queue()
    a_traiter.put(source)

    parents = dict()
    while not a_traiter.empty():
        sommet = a_traiter.get()
        if not deja_visites[sommet]:
            deja_visites[sommet] = True
            for suivant in sorted(reseau.successeurs(sommet)):
                if not deja_visites[suivant]:
                    a_traiter.put(suivant)
                    if suivant not in parents:
                        parents[suivant] = sommet

    return reconstruire_chemin_oriente_pondere(source, puits, parents, reseau)


def ford_fulkerson(reseau, source, puits):
    """Renvoie un flot maximum pour le réseau donné sous la forme d'un
    dictionnaire dont les clés sont les arcs du réseau et les valeurs sont les
    flots circulant sur l'arc correspondant. 
    
    >>> from dictionnaireadjacenceorientepondere import DictionnaireAdjacenceOrientePondere
    >>> G = DictionnaireAdjacenceOrientePondere()
    >>> # exemple de CLRS
    >>> G.ajouter_arcs([("s",  "v1", 16), ("s",  "v2", 13), ("v1", "v3", 12), ("v2", "v1",  4), ("v2", "v4", 14), ("v3", "v2",  9), ("v3", "t",  20), ("v4", "v3",  7),("v4", "t",   4)])
    >>> source = "s"
    >>> puits = "t"
    >>> sorted(ford_fulkerson(G, source, puits).items())
    [(('s', 'v1'), 12), (('s', 'v2'), 11), (('v1', 'v3'), 12), (('v2', 'v1'), 0), (('v2', 'v4'), 11), (('v3', 't'), 19), (('v3', 'v2'), 0), (('v4', 't'), 4), (('v4', 'v3'), 7)]
    """
    arcs = reseau.arcs()
    sommets = reseau.sommets()
    flot = dict()
    for u, v, _ in arcs:
        flot[(u,v)] = flot[(v,u)] = 0

    while True:
        # trouver un chemin augmentant dans le réseau résiduel, ou arrêter
        # s'il n'y en a pas
        chemin = chemin_augmentant(reseau_residuel(reseau, flot), source, puits)
        if chemin is None:
            break
        # calculer la capacité résiduelle minimale du réseau le long du
        # chemin trouvé
        arcs_chemin = chemin.arcs()
        cap_res_min = min(poids for u, v, poids in arcs_chemin)
        
        # augmenter le flot sur chaque arc du chemin appartenant au réseau du
        # minimum autorisé
        for u, v, _ in arcs_chemin:
            if reseau.contient_arc(u, v):
                flot[(u, v)] += cap_res_min
                flot[(v, u)] = -flot[(u, v)] 
            else:
                flot[(v, u)] -= cap_res_min
                flot[(u, v)] = -flot[(v, u)]
    return flot
