11"""Maximum flow algorithms and network flow computations."""
22
3- from typing import Literal , Union , overload
3+ from typing import Dict , Literal , Union , overload
44
55from ngraph .lib .algorithms .base import EdgeSelect , FlowPlacement
66from ngraph .lib .algorithms .flow_init import init_flow_graph
77from ngraph .lib .algorithms .place_flow import place_flow_on_graph
8- from ngraph .lib .algorithms .spf import spf
8+ from ngraph .lib .algorithms .spf import Cost , spf
99from ngraph .lib .algorithms .types import FlowSummary
1010from ngraph .lib .graph import NodeID , StrictMultiDiGraph
1111
@@ -208,6 +208,7 @@ def calc_max_flow(
208208 capacity_attr ,
209209 flow_attr ,
210210 tolerance ,
211+ {}, # Empty cost distribution for self-loop case
211212 )
212213 else :
213214 return 0.0
@@ -220,8 +221,11 @@ def calc_max_flow(
220221 reset_flow_graph ,
221222 )
222223
224+ # Initialize cost distribution tracking
225+ cost_distribution : Dict [Cost , float ] = {}
226+
223227 # First path-finding iteration.
224- _ , pred = spf (
228+ costs , pred = spf (
225229 flow_graph , src_node , edge_select = EdgeSelect .ALL_MIN_COST_WITH_CAP_REMAINING
226230 )
227231 flow_meta = place_flow_on_graph (
@@ -236,6 +240,13 @@ def calc_max_flow(
236240 )
237241 max_flow = flow_meta .placed_flow
238242
243+ # Track cost distribution for first iteration
244+ if dst_node in costs and flow_meta .placed_flow > 0 :
245+ path_cost = costs [dst_node ]
246+ cost_distribution [path_cost ] = (
247+ cost_distribution .get (path_cost , 0.0 ) + flow_meta .placed_flow
248+ )
249+
239250 # If only one path (single augmentation) is desired, return early.
240251 if shortest_path :
241252 return _build_return_value (
@@ -247,11 +258,12 @@ def calc_max_flow(
247258 capacity_attr ,
248259 flow_attr ,
249260 tolerance ,
261+ cost_distribution ,
250262 )
251263
252264 # Otherwise, repeatedly find augmenting paths until no new flow can be placed.
253265 while True :
254- _ , pred = spf (
266+ costs , pred = spf (
255267 flow_graph , src_node , edge_select = EdgeSelect .ALL_MIN_COST_WITH_CAP_REMAINING
256268 )
257269 if dst_node not in pred :
@@ -274,6 +286,13 @@ def calc_max_flow(
274286
275287 max_flow += flow_meta .placed_flow
276288
289+ # Track cost distribution for this iteration
290+ if dst_node in costs and flow_meta .placed_flow > 0 :
291+ path_cost = costs [dst_node ]
292+ cost_distribution [path_cost ] = (
293+ cost_distribution .get (path_cost , 0.0 ) + flow_meta .placed_flow
294+ )
295+
277296 return _build_return_value (
278297 max_flow ,
279298 flow_graph ,
@@ -283,6 +302,7 @@ def calc_max_flow(
283302 capacity_attr ,
284303 flow_attr ,
285304 tolerance ,
305+ cost_distribution ,
286306 )
287307
288308
@@ -295,6 +315,7 @@ def _build_return_value(
295315 capacity_attr : str ,
296316 flow_attr : str ,
297317 tolerance : float ,
318+ cost_distribution : Dict [Cost , float ],
298319) -> Union [float , tuple ]:
299320 """Build the appropriate return value based on the requested flags."""
300321 if not (return_summary or return_graph ):
@@ -303,7 +324,13 @@ def _build_return_value(
303324 summary = None
304325 if return_summary :
305326 summary = _build_flow_summary (
306- max_flow , flow_graph , src_node , capacity_attr , flow_attr , tolerance
327+ max_flow ,
328+ flow_graph ,
329+ src_node ,
330+ capacity_attr ,
331+ flow_attr ,
332+ tolerance ,
333+ cost_distribution ,
307334 )
308335
309336 ret : list = [max_flow ]
@@ -322,6 +349,7 @@ def _build_flow_summary(
322349 capacity_attr : str ,
323350 flow_attr : str ,
324351 tolerance : float ,
352+ cost_distribution : Dict [Cost , float ],
325353) -> FlowSummary :
326354 """Build a FlowSummary from the flow graph state."""
327355 edge_flow = {}
@@ -364,6 +392,7 @@ def _build_flow_summary(
364392 residual_cap = residual_cap ,
365393 reachable = reachable ,
366394 min_cut = min_cut ,
395+ cost_distribution = cost_distribution ,
367396 )
368397
369398
0 commit comments