1use std::collections::{BTreeMap, BTreeSet};
4
5use itertools::Itertools;
6use proc_macro2::Span;
7use slotmap::{SecondaryMap, SparseSecondaryMap};
8
9use super::meta_graph::DfirGraph;
10use super::ops::{DelayType, FloType};
11use super::{
12 Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, HandoffKind, graph_algorithms,
13};
14use crate::diagnostic::{Diagnostic, Level};
15use crate::union_find::UnionFind;
16
17struct BarrierCrossers {
19 pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
21 pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
23}
24impl BarrierCrossers {
25 fn iter_node_pairs<'a>(
27 &'a self,
28 partitioned_graph: &'a DfirGraph,
29 ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
30 let edge_pairs_iter = self
31 .edge_barrier_crossers
32 .iter()
33 .map(|(edge_id, &delay_type)| {
34 let src_dst = partitioned_graph.edge(edge_id);
35 (src_dst, delay_type)
36 });
37 let singleton_pairs_iter = self
38 .singleton_barrier_crossers
39 .iter()
40 .map(|&src_dst| (src_dst, DelayType::Stratum));
41 edge_pairs_iter.chain(singleton_pairs_iter)
42 }
43
44 fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
46 if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
47 self.edge_barrier_crossers.insert(new_edge_id, delay_type);
48 }
49 }
50}
51
52fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
54 let edge_barrier_crossers = partitioned_graph
55 .edges()
56 .filter(|&(_, (_src, dst))| {
57 partitioned_graph.node_loop(dst).is_none()
59 })
60 .filter_map(|(edge_id, (_src, dst))| {
61 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
62 let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
63 let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
64 Some((edge_id, input_barrier))
65 })
66 .collect();
67
68 let mut singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)> = partitioned_graph
70 .node_ids()
71 .flat_map(|dst| {
72 partitioned_graph
73 .node_singleton_references(dst)
74 .iter()
75 .filter_map(|r| r.node_id)
76 .map(move |src_ref| (src_ref, dst))
77 })
78 .collect();
79
80 let refs_by_target = partitioned_graph.node_singleton_reference_groups();
84 for (_singleton, groups) in refs_by_target {
86 for (group_a, group_b) in groups.values().tuple_windows() {
88 for &(node_a, _, _) in group_a {
90 for &(node_b, _, _) in group_b {
91 assert_ne!(
93 node_a, node_b,
94 "encounted conflicted or cyclical singleton references\n{:?}\n{:?}",
95 group_a, group_b,
96 );
97 singleton_barrier_crossers.push((node_a, node_b));
98 }
99 }
100 }
101 }
102
103 BarrierCrossers {
104 edge_barrier_crossers,
105 singleton_barrier_crossers,
106 }
107}
108
109fn find_subgraph_unionfind(
110 partitioned_graph: &DfirGraph,
111 barrier_crossers: &BarrierCrossers,
112) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
113 let mut node_color = partitioned_graph
118 .node_ids()
119 .filter_map(|node_id| {
120 let op_color = partitioned_graph.node_color(node_id)?;
121 Some((node_id, op_color))
122 })
123 .collect::<SparseSecondaryMap<_, _>>();
124
125 let mut subgraph_unionfind: UnionFind<GraphNodeId> =
126 UnionFind::with_capacity(partitioned_graph.nodes().len());
127
128 let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
131 let mut progress = true;
140 while progress {
141 progress = false;
142 for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
144 if subgraph_unionfind.same_set(src, dst) {
146 continue;
149 }
150
151 if barrier_crossers
153 .iter_node_pairs(partitioned_graph)
154 .any(|((x_src, x_dst), _)| {
155 (subgraph_unionfind.same_set(x_src, src)
156 && subgraph_unionfind.same_set(x_dst, dst))
157 || (subgraph_unionfind.same_set(x_src, dst)
158 && subgraph_unionfind.same_set(x_dst, src))
159 })
160 {
161 continue;
162 }
163
164 if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
166 continue;
167 }
168 if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
170 Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
171 }) {
172 continue;
173 }
174
175 if can_connect_colorize(&mut node_color, src, dst) {
176 subgraph_unionfind.union(src, dst);
179 assert!(handoff_edges.remove(&edge_id));
180 progress = true;
181 }
182 }
183 }
184
185 (subgraph_unionfind, handoff_edges)
186}
187
188fn make_subgraph_collect(
192 partitioned_graph: &DfirGraph,
193 mut subgraph_unionfind: UnionFind<GraphNodeId>,
194) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
195 let topo_sort = graph_algorithms::topo_sort(
199 partitioned_graph
200 .nodes()
201 .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
202 .map(|(node_id, _)| node_id),
203 |v| {
204 partitioned_graph
205 .node_predecessor_nodes(v)
206 .filter(|&pred_id| {
207 let pred = partitioned_graph.node(pred_id);
208 !matches!(pred, GraphNode::Handoff { .. })
209 })
210 },
211 )
212 .expect("Subgraphs are in-out trees.");
213
214 let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
215 for node_id in topo_sort {
216 let repr_node = subgraph_unionfind.find(node_id);
217 if !grouped_nodes.contains_key(repr_node) {
218 grouped_nodes.insert(repr_node, Default::default());
219 }
220 grouped_nodes[repr_node].push(node_id);
221 }
222 grouped_nodes
223}
224
225fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
229 let (subgraph_unionfind, handoff_edges) =
238 find_subgraph_unionfind(partitioned_graph, barrier_crossers);
239
240 for edge_id in handoff_edges {
242 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
243
244 let src_node = partitioned_graph.node(src_id);
246 let dst_node = partitioned_graph.node(dst_id);
247 if matches!(src_node, GraphNode::Handoff { .. })
248 || matches!(dst_node, GraphNode::Handoff { .. })
249 {
250 continue;
251 }
252
253 let hoff = GraphNode::Handoff {
254 kind: HandoffKind::Vec,
255 src_span: src_node.span(),
256 dst_span: dst_node.span(),
257 };
258 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
259
260 barrier_crossers.replace_edge(edge_id, out_edge_id);
262 }
263
264 let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
268 for (_repr_node, member_nodes) in grouped_nodes {
269 partitioned_graph.insert_subgraph(member_nodes).unwrap();
270 }
271}
272
273fn can_connect_colorize(
279 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
280 src: GraphNodeId,
281 dst: GraphNodeId,
282) -> bool {
283 let can_connect = match (node_color.get(src), node_color.get(dst)) {
288 (None, None) => false,
291
292 (None, Some(Color::Pull | Color::Comp)) => {
294 node_color.insert(src, Color::Pull);
295 true
296 }
297 (None, Some(Color::Push | Color::Hoff)) => {
298 node_color.insert(src, Color::Push);
299 true
300 }
301
302 (Some(Color::Pull | Color::Hoff), None) => {
304 node_color.insert(dst, Color::Pull);
305 true
306 }
307 (Some(Color::Comp | Color::Push), None) => {
308 node_color.insert(dst, Color::Push);
309 true
310 }
311
312 (Some(Color::Pull), Some(Color::Pull)) => true,
314 (Some(Color::Pull), Some(Color::Comp)) => true,
315 (Some(Color::Pull), Some(Color::Push)) => true,
316
317 (Some(Color::Comp), Some(Color::Pull)) => false,
318 (Some(Color::Comp), Some(Color::Comp)) => false,
319 (Some(Color::Comp), Some(Color::Push)) => true,
320
321 (Some(Color::Push), Some(Color::Pull)) => false,
322 (Some(Color::Push), Some(Color::Comp)) => false,
323 (Some(Color::Push), Some(Color::Push)) => true,
324
325 (Some(Color::Hoff), Some(_)) => false,
327 (Some(_), Some(Color::Hoff)) => false,
328 };
329 can_connect
330}
331
/// Records execution-order dependencies between subgraphs and rejects cyclical orderings.
///
/// Dependencies come from two sources:
/// 1. **Handoffs**: the subgraph feeding a handoff must precede the subgraph consuming it.
///    The exception is a handoff whose outgoing edge is a `Tick`/`TickLazy` barrier crosser.
///    Such edges add no dependency; their handoff instead receives that delay type at the end.
/// 2. **Singleton barrier crossers**: the subgraph of `pred` (or, if `pred` has no subgraph,
///    that of its predecessor) must precede the subgraph of `succ`. When `pred` is a handoff,
///    the subgraph consuming that handoff must also come after `succ`'s subgraph.
///
/// # Errors
/// Returns a [`Diagnostic`] if the collected dependencies contain a cycle. The diagnostic is
/// spanned at the first node of a subgraph in the cycle, falling back to `Span::call_site`.
///
/// # Panics
/// If structural invariants are violated:
/// * a handoff has more than one successor;
/// * a non-tick handoff does not have exactly one predecessor;
/// * a node expected to be in a subgraph is not;
/// * a tick-delayed handoff is not a `HandoffKind::Vec` handoff.
fn order_subgraphs(
    partitioned_graph: &mut DfirGraph,
    barrier_crossers: &BarrierCrossers,
) -> Result<(), Diagnostic> {
    // For each subgraph, the subgraphs that must run before it.
    let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();

    // Handoff outgoing edges with a tick delay type; applied after the cycle check.
    let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();

    for (hoff_id, hoff) in partitioned_graph.nodes() {
        // Only handoff nodes contribute dependencies in this pass.
        if !matches!(hoff, GraphNode::Handoff { .. }) {
            continue;
        }

        // A handoff without a consumer constrains nothing.
        if partitioned_graph.node_degree_out(hoff_id) == 0 {
            continue;
        }
        assert_eq!(1, partitioned_graph.node_degree_out(hoff_id));

        let (succ_edge, succ) = partitioned_graph.node_successors(hoff_id).next().unwrap();

        let succ_edge_delaytype = barrier_crossers
            .edge_barrier_crossers
            .get(succ_edge)
            .copied();
        // Tick-delayed edges are kept out of `sg_preds`, so they do not take part in the
        // same-tick cycle check below.
        if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
            tick_edges.push((succ_edge, delay_type));
            continue;
        }

        assert_eq!(1, partitioned_graph.node_degree_in(hoff_id));
        let (_edge_id, pred) = partitioned_graph.node_predecessors(hoff_id).next().unwrap();

        let pred_sg = partitioned_graph
            .node_subgraph(pred)
            .expect("Handoff pred not in subgraph, may be a doubled/adjacent handoff");
        let succ_sg = partitioned_graph
            .node_subgraph(succ)
            .expect("Handoff succ not in subgraph, may be a doubled/adjacent handoff");

        // The producer subgraph runs before the consumer subgraph.
        sg_preds.entry(succ_sg).or_default().push(pred_sg);
    }
    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
        assert_ne!(pred, succ);
        // If `pred` is not in a subgraph (a handoff, per the `expect` message below), use the
        // subgraph of its first predecessor instead.
        let pred_sg = if let Some(sg) = partitioned_graph.node_subgraph(pred) {
            sg
        } else {
            let (_edge, pred_pred) = partitioned_graph
                .node_predecessors(pred)
                .next()
                .expect("handoff must have a predecessor");
            partitioned_graph.node_subgraph(pred_pred).unwrap()
        };
        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
        // Both ends already share a subgraph, so no inter-subgraph ordering is needed.
        if pred_sg == succ_sg {
            continue;
        }
        sg_preds.entry(succ_sg).or_default().push(pred_sg);

        // When the referenced node is a handoff, its consumer must run after `succ`'s subgraph.
        // NOTE(review): presumably so `succ` observes the handoff's contents before the
        // consumer takes them. Confirm.
        if matches!(partitioned_graph.node(pred), GraphNode::Handoff { .. }) {
            assert!(
                partitioned_graph.node_degree_out(pred) <= 1,
                "handoff should have at most one successor"
            );
            if let Some((_edge, consumer)) = partitioned_graph.node_successors(pred).next() {
                let consumer_sg = partitioned_graph.node_subgraph(consumer).unwrap();
                if consumer_sg != succ_sg {
                    sg_preds.entry(consumer_sg).or_default().push(succ_sg);
                }
            }
        }
    }

    // Any cycle among the collected dependencies is a same-tick cycle; report it at the first
    // node of the first subgraph in the cycle, if there is one.
    if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
        sg_preds.get(&v).into_iter().flatten().copied()
    }) {
        let span = cycle
            .first()
            .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
            .map(|n| partitioned_graph.node(n).span())
            .unwrap_or_else(Span::call_site);
        return Err(Diagnostic::spanned(
            span,
            Level::Error,
            "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
        ));
    }

    // Apply each recorded tick delay type to its (`Vec`) handoff, i.e. the edge's source.
    for (edge_id, delay_type) in tick_edges {
        let (hoff, _dst) = partitioned_graph.edge(edge_id);
        assert!(matches!(
            partitioned_graph.node(hoff),
            GraphNode::Handoff {
                kind: HandoffKind::Vec,
                ..
            }
        ));
        partitioned_graph.set_handoff_delay_type(hoff, delay_type);
    }
    Ok(())
}
452
453pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
457 let mut barrier_crossers = find_barrier_crossers(&flat_graph);
459 let mut partitioned_graph = flat_graph;
460
461 make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
463
464 order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
466
467 Ok(partitioned_graph)
468}