Skip to main content

dfir_lang/graph/
flat_to_partitioned.rs

//! Subgraph partitioning algorithm.
2
3use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType};
10use super::{
11    Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, HandoffKind, graph_algorithms,
12};
13use crate::diagnostic::{Diagnostic, Level};
14use crate::union_find::UnionFind;
15
16/// Helper struct for tracking barrier crossers, see [`find_barrier_crossers`].
17struct BarrierCrossers {
18    /// Edge barrier crossers, including what type.
19    pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
20    /// Singleton reference barrier crossers, considered to be [`DelayType::Stratum`].
21    pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
22}
23impl BarrierCrossers {
24    /// Iterate pairs of nodes that are across a barrier. Excludes `DelayType::NextIteration` pairs.
25    fn iter_node_pairs<'a>(
26        &'a self,
27        partitioned_graph: &'a DfirGraph,
28    ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
29        let edge_pairs_iter = self
30            .edge_barrier_crossers
31            .iter()
32            .map(|(edge_id, &delay_type)| {
33                let src_dst = partitioned_graph.edge(edge_id);
34                (src_dst, delay_type)
35            });
36        let singleton_pairs_iter = self
37            .singleton_barrier_crossers
38            .iter()
39            .map(|&src_dst| (src_dst, DelayType::Stratum));
40        edge_pairs_iter.chain(singleton_pairs_iter)
41    }
42
43    /// Insert/replace edge.
44    fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
45        if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
46            self.edge_barrier_crossers.insert(new_edge_id, delay_type);
47        }
48    }
49}
50
51/// Find all the barrier crossers.
52fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
53    let edge_barrier_crossers = partitioned_graph
54        .edges()
55        .filter(|&(_, (_src, dst))| {
56            // Ignore barriers within `loop {` blocks.
57            partitioned_graph.node_loop(dst).is_none()
58        })
59        .filter_map(|(edge_id, (_src, dst))| {
60            let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
61            let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
62            let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
63            Some((edge_id, input_barrier))
64        })
65        .collect();
66    let singleton_barrier_crossers = partitioned_graph
67        .node_ids()
68        .flat_map(|dst| {
69            partitioned_graph
70                .node_singleton_references(dst)
71                .iter()
72                .flatten()
73                .map(move |&src_ref| (src_ref, dst))
74        })
75        .collect();
76    BarrierCrossers {
77        edge_barrier_crossers,
78        singleton_barrier_crossers,
79    }
80}
81
82fn find_subgraph_unionfind(
83    partitioned_graph: &DfirGraph,
84    barrier_crossers: &BarrierCrossers,
85) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
86    // Modality (color) of nodes, push or pull.
87    // TODO(mingwei)? This does NOT consider `DelayType` barriers (which generally imply `Pull`),
88    // which makes it inconsistant with the final output in `as_code()`. But this doesn't create
89    // any bugs since we exclude `DelayType` edges from joining subgraphs anyway.
90    let mut node_color = partitioned_graph
91        .node_ids()
92        .filter_map(|node_id| {
93            let op_color = partitioned_graph.node_color(node_id)?;
94            Some((node_id, op_color))
95        })
96        .collect::<SparseSecondaryMap<_, _>>();
97
98    let mut subgraph_unionfind: UnionFind<GraphNodeId> =
99        UnionFind::with_capacity(partitioned_graph.nodes().len());
100
101    // Will contain all edges which are handoffs. Starts out with all edges and
102    // we remove from this set as we combine nodes into subgraphs.
103    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
104    // Would sort edges here for priority (for now, no sort/priority).
105
106    // Each edge gets looked at in order. However we may not know if a linear
107    // chain of operators is PUSH vs PULL until we look at the ends. A fancier
108    // algorithm would know to handle linear chains from the outside inward.
109    // But instead we just run through the edges in a loop until no more
110    // progress is made. Could have some sort of O(N^2) pathological worst
111    // case.
112    let mut progress = true;
113    while progress {
114        progress = false;
115        // TODO(mingwei): Could this iterate `handoff_edges` instead? (Modulo ownership). Then no case (1) below.
116        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
117            // Ignore (1) already added edges as well as (2) new self-cycles. (Unless reference edge).
118            if subgraph_unionfind.same_set(src, dst) {
119                // Note that the _edge_ `edge_id` might not be in the subgraph even when both `src` and `dst` are. This prevents case 2.
120                // Handoffs will be inserted later for this self-loop.
121                continue;
122            }
123
124            // Do not connect stratum crossers (next edges).
125            if barrier_crossers
126                .iter_node_pairs(partitioned_graph)
127                .any(|((x_src, x_dst), _)| {
128                    (subgraph_unionfind.same_set(x_src, src)
129                        && subgraph_unionfind.same_set(x_dst, dst))
130                        || (subgraph_unionfind.same_set(x_src, dst)
131                            && subgraph_unionfind.same_set(x_dst, src))
132                })
133            {
134                continue;
135            }
136
137            // Do not connect across loop contexts.
138            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
139                continue;
140            }
141            // Do not connect `next_iteration()`.
142            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
143                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
144            }) {
145                continue;
146            }
147
148            if can_connect_colorize(&mut node_color, src, dst) {
149                // At this point we have selected this edge and its src & dst to be
150                // within a single subgraph.
151                subgraph_unionfind.union(src, dst);
152                assert!(handoff_edges.remove(&edge_id));
153                progress = true;
154            }
155        }
156    }
157
158    (subgraph_unionfind, handoff_edges)
159}
160
161/// Builds the datastructures for checking which subgraph each node belongs to
162/// after handoffs have already been inserted to partition subgraphs.
163/// This list of nodes in each subgraph are returned in topological sort order.
164fn make_subgraph_collect(
165    partitioned_graph: &DfirGraph,
166    mut subgraph_unionfind: UnionFind<GraphNodeId>,
167) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
168    // We want the nodes of each subgraph to be listed in topo-sort order.
169    // We could do this on each subgraph, or we could do it all at once on the
170    // whole node graph by ignoring handoffs, which is what we do here:
171    let topo_sort = graph_algorithms::topo_sort(
172        partitioned_graph
173            .nodes()
174            .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
175            .map(|(node_id, _)| node_id),
176        |v| {
177            partitioned_graph
178                .node_predecessor_nodes(v)
179                .filter(|&pred_id| {
180                    let pred = partitioned_graph.node(pred_id);
181                    !matches!(pred, GraphNode::Handoff { .. })
182                })
183        },
184    )
185    .expect("Subgraphs are in-out trees.");
186
187    let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
188    for node_id in topo_sort {
189        let repr_node = subgraph_unionfind.find(node_id);
190        if !grouped_nodes.contains_key(repr_node) {
191            grouped_nodes.insert(repr_node, Default::default());
192        }
193        grouped_nodes[repr_node].push(node_id);
194    }
195    grouped_nodes
196}
197
198/// Find subgraph and insert handoffs.
199/// Modifies barrier_crossers so that the edge OUT of an inserted handoff has
200/// the DelayType data.
201fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
202    // Algorithm:
203    // 1. Each node begins as its own subgraph.
204    // 2. Collect edges. (Future optimization: sort so edges which should not be split across a handoff come first).
205    // 3. For each edge, try to join `(to, from)` into the same subgraph.
206
207    // TODO(mingwei):
208    // self.partitioned_graph.assert_valid();
209
210    let (subgraph_unionfind, handoff_edges) =
211        find_subgraph_unionfind(partitioned_graph, barrier_crossers);
212
213    // Insert handoffs between subgraphs (or on subgraph self-loop edges)
214    for edge_id in handoff_edges {
215        let (src_id, dst_id) = partitioned_graph.edge(edge_id);
216
217        // Already has a handoff, no need to insert one.
218        let src_node = partitioned_graph.node(src_id);
219        let dst_node = partitioned_graph.node(dst_id);
220        if matches!(src_node, GraphNode::Handoff { .. })
221            || matches!(dst_node, GraphNode::Handoff { .. })
222        {
223            continue;
224        }
225
226        let hoff = GraphNode::Handoff {
227            kind: HandoffKind::Vec,
228            src_span: src_node.span(),
229            dst_span: dst_node.span(),
230        };
231        let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
232
233        // Update barrier_crossers for inserted node.
234        barrier_crossers.replace_edge(edge_id, out_edge_id);
235    }
236
237    // Determine node's subgraph and subgraph's nodes.
238    // This list of nodes in each subgraph are to be in topological sort order.
239    // Eventually returned directly in the [`DfirGraph`].
240    let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
241    for (_repr_node, member_nodes) in grouped_nodes {
242        partitioned_graph.insert_subgraph(member_nodes).unwrap();
243    }
244}
245
246/// Set `src` or `dst` color if `None` based on the other (if possible):
247/// `None` indicates an op could be pull or push i.e. unary-in & unary-out.
248/// So in that case we color `src` or `dst` based on its newfound neighbor (the other one).
249///
250/// Returns if `src` and `dst` can be in the same subgraph.
251fn can_connect_colorize(
252    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
253    src: GraphNodeId,
254    dst: GraphNodeId,
255) -> bool {
256    // Pull -> Pull
257    // Push -> Push
258    // Pull -> [Computation] -> Push
259    // Push -> [Handoff] -> Pull
260    let can_connect = match (node_color.get(src), node_color.get(dst)) {
261        // Linear chain, can't connect because it may cause future conflicts.
262        // But if it doesn't in the _future_ we can connect it (once either/both ends are determined).
263        (None, None) => false,
264
265        // Infer left side.
266        (None, Some(Color::Pull | Color::Comp)) => {
267            node_color.insert(src, Color::Pull);
268            true
269        }
270        (None, Some(Color::Push | Color::Hoff)) => {
271            node_color.insert(src, Color::Push);
272            true
273        }
274
275        // Infer right side.
276        (Some(Color::Pull | Color::Hoff), None) => {
277            node_color.insert(dst, Color::Pull);
278            true
279        }
280        (Some(Color::Comp | Color::Push), None) => {
281            node_color.insert(dst, Color::Push);
282            true
283        }
284
285        // Both sides already specified.
286        (Some(Color::Pull), Some(Color::Pull)) => true,
287        (Some(Color::Pull), Some(Color::Comp)) => true,
288        (Some(Color::Pull), Some(Color::Push)) => true,
289
290        (Some(Color::Comp), Some(Color::Pull)) => false,
291        (Some(Color::Comp), Some(Color::Comp)) => false,
292        (Some(Color::Comp), Some(Color::Push)) => true,
293
294        (Some(Color::Push), Some(Color::Pull)) => false,
295        (Some(Color::Push), Some(Color::Comp)) => false,
296        (Some(Color::Push), Some(Color::Push)) => true,
297
298        // Handoffs are not part of subgraphs.
299        (Some(Color::Hoff), Some(_)) => false,
300        (Some(_), Some(Color::Hoff)) => false,
301    };
302    can_connect
303}
304
305/// Topologically sorts subgraphs and marks tick-boundary (`defer_tick` / `defer_tick_lazy`)
306/// handoffs with their delay type for double-buffered codegen in `as_code`.
307///
308/// Returns an error if there is an intra-tick cycle (i.e. the subgraph DAG has a cycle when
309/// tick-boundary edges are excluded).
310fn order_subgraphs(
311    partitioned_graph: &mut DfirGraph,
312    barrier_crossers: &BarrierCrossers,
313) -> Result<(), Diagnostic> {
314    // Build a subgraph-level directed graph, excluding tick-boundary edges.
315    let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
316
317    // Track which handoff edges are tick-boundary, keyed by (src_sg, dst_sg).
318    let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
319
320    // Iterate handoffs between subgraphs.
321    for (hoff_id, hoff) in partitioned_graph.nodes() {
322        if !matches!(hoff, GraphNode::Handoff { .. }) {
323            continue;
324        }
325
326        // Handoffs may have 0 successors if only used by reference. Skip ordering those.
327        if partitioned_graph.node_degree_out(hoff_id) == 0 {
328            continue;
329        }
330        assert_eq!(1, partitioned_graph.node_degree_out(hoff_id));
331
332        let (succ_edge, succ) = partitioned_graph.node_successors(hoff_id).next().unwrap();
333
334        let succ_edge_delaytype = barrier_crossers
335            .edge_barrier_crossers
336            .get(succ_edge)
337            .copied();
338        // Tick edges are excluded from the topo sort — they are cross-tick by design.
339        if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
340            tick_edges.push((succ_edge, delay_type));
341            continue;
342        }
343
344        assert_eq!(1, partitioned_graph.node_degree_in(hoff_id));
345        let (_edge_id, pred) = partitioned_graph.node_predecessors(hoff_id).next().unwrap();
346
347        let pred_sg = partitioned_graph
348            .node_subgraph(pred)
349            .expect("Handoff pred not in subgraph, may be a doubled/adjacent handoff");
350        let succ_sg = partitioned_graph
351            .node_subgraph(succ)
352            .expect("Handoff succ not in subgraph, may be a doubled/adjacent handoff");
353
354        sg_preds.entry(succ_sg).or_default().push(pred_sg);
355    }
356    // Include singleton reference edges.
357    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
358        assert_ne!(pred, succ);
359        // For handoff nodes (which have no subgraph), use the predecessor's subgraph.
360        let pred_sg = if let Some(sg) = partitioned_graph.node_subgraph(pred) {
361            sg
362        } else {
363            // pred is a handoff node — find its predecessor operator's subgraph.
364            let (_edge, pred_pred) = partitioned_graph
365                .node_predecessors(pred)
366                .next()
367                .expect("handoff must have a predecessor");
368            partitioned_graph.node_subgraph(pred_pred).unwrap()
369        };
370        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
371        if pred_sg == succ_sg {
372            continue;
373        }
374        sg_preds.entry(succ_sg).or_default().push(pred_sg);
375
376        // For handoff nodes: borrower must run before pipe consumer.
377        // All handoffs should have at most one successor.
378        if matches!(partitioned_graph.node(pred), GraphNode::Handoff { .. }) {
379            assert!(
380                partitioned_graph.node_degree_out(pred) <= 1,
381                "handoff should have at most one successor"
382            );
383            if let Some((_edge, consumer)) = partitioned_graph.node_successors(pred).next() {
384                let consumer_sg = partitioned_graph.node_subgraph(consumer).unwrap();
385                if consumer_sg != succ_sg {
386                    sg_preds.entry(consumer_sg).or_default().push(succ_sg);
387                }
388            }
389        }
390    }
391
392    // Topological sort — rejects intra-tick cycles.
393    if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
394        sg_preds.get(&v).into_iter().flatten().copied()
395    }) {
396        let span = cycle
397            .first()
398            .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
399            .map(|n| partitioned_graph.node(n).span())
400            .unwrap_or_else(Span::call_site);
401        return Err(Diagnostic::spanned(
402            span,
403            Level::Error,
404            "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
405        ));
406    }
407
408    // Mark tick-boundary handoffs with their delay type.
409    // These handoffs are excluded from the intra-tick topo ordering in
410    // `as_code`; instead, their double-buffered handoff semantics defer data
411    // across the tick boundary to the next tick.
412    for (edge_id, delay_type) in tick_edges {
413        let (hoff, _dst) = partitioned_graph.edge(edge_id);
414        assert!(matches!(
415            partitioned_graph.node(hoff),
416            GraphNode::Handoff {
417                kind: HandoffKind::Vec,
418                ..
419            }
420        ));
421        partitioned_graph.set_handoff_delay_type(hoff, delay_type);
422    }
423    Ok(())
424}
425
426/// Main method for this module. Partitions a flat [`DfirGraph`] into one with subgraphs.
427///
428/// Returns an error if an intra-tick cycle exists in the graph.
429pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
430    // Pre-find barrier crossers (input edges with a `DelayType`).
431    let mut barrier_crossers = find_barrier_crossers(&flat_graph);
432    let mut partitioned_graph = flat_graph;
433
434    // Partition into subgraphs.
435    make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
436
437    // Topologically order subgraphs and mark tick-boundary handoffs for double-buffering.
438    order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
439
440    Ok(partitioned_graph)
441}