Skip to main content

dfir_lang/graph/
flat_graph_builder.rs

//! Build a flat graph from [`DfirStatement`]s.
2
3use std::borrow::Cow;
4use std::collections::btree_map::Entry;
5use std::collections::{BTreeMap, BTreeSet};
6
7use itertools::Itertools;
8use proc_macro2::Span;
9use quote::ToTokens;
10use syn::spanned::Spanned;
11use syn::{Error, Ident, ItemUse};
12
13use crate::diagnostic::{Diagnostic, Diagnostics, Level};
14use crate::graph::ops::next_iteration::NEXT_ITERATION;
15use crate::graph::ops::{FloType, Persistence, PortListSpec, RangeTrait};
16use crate::graph::{
17    DfirGraph, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId, HandoffKind, PortIndexValue,
18    graph_algorithms,
19};
20use crate::parse::{DfirCode, DfirStatement, Operator, Pipeline};
21use crate::pretty_span::PrettySpan;
22
/// The two open ends of a (partial) pipeline: where it can be fed and where it emits.
///
/// Each end pairs the port index written at that end with the node it attaches to, which may
/// not be resolved yet (see [`GraphDet`]). An end is `None` when there is nothing to connect,
/// for example after `mod` is used outside of a module (see `FlatGraphBuilder::add_pipeline`).
#[derive(Clone, Debug)]
struct Ends {
    /// Input end of the pipeline.
    inn: Option<(PortIndexValue, GraphDet)>,
    /// Output end of the pipeline.
    out: Option<(PortIndexValue, GraphDet)>,
}
28
/// Whether the node at one end of a pipeline is known yet.
#[derive(Clone, Debug)]
enum GraphDet {
    /// The end is attached to this concrete graph node.
    Determined(GraphNodeId),
    /// The end refers to a variable name. It is looked up later, by
    /// `FlatGraphBuilder::helper_resolve_name`, so forward references work.
    Undetermined(Ident),
}
34
/// Variable name info for each ident, see [`FlatGraphBuilder::varname_ends`].
#[derive(Debug)]
struct VarnameInfo {
    /// What the variable name resolves to.
    pub ends: Ends,
    /// Set to true if the varname reference creates an illegal self-referential cycle.
    /// Written by `FlatGraphBuilder::helper_resolve_name`, which also refuses to resolve any
    /// name that already has this flag set.
    pub illegal_cycle: bool,
    /// Set to true once the in port is used. Used to track unused ports.
    /// Written when the name is resolved for its input side.
    pub inn_used: bool,
    /// Set to true once the out port is used. Used to track unused ports.
    /// Written when the name is resolved for its output side.
    pub out_used: bool,
}
47impl VarnameInfo {
48    pub fn new(ends: Ends) -> Self {
49        Self {
50            ends,
51            illegal_cycle: false,
52            inn_used: false,
53            out_used: false,
54        }
55    }
56}
57
/// Wrapper around [`DfirGraph`] to build a flat graph from AST code.
#[derive(Debug, Default)]
pub struct FlatGraphBuilder {
    /// Spanned error/warning/etc diagnostics to emit.
    diagnostics: Diagnostics,

    /// [`DfirGraph`] being built.
    flat_graph: DfirGraph,
    /// Variable names, filled in as [`DfirStatement::Named`] statements are added.
    varname_ends: BTreeMap<Ident, VarnameInfo>,
    /// Each (out -> inn) link inputted. These are only turned into graph edges when the
    /// builder is built, after all names have been assigned.
    links: Vec<Ends>,

    /// Use statements.
    uses: Vec<ItemUse>,

    /// If the flat graph is being loaded as a module, then two initial ModuleBoundary nodes are inserted into the graph. One
    /// for the input into the module and one for the output out of the module.
    /// Stored as `(input node, output node)`.
    module_boundary_nodes: Option<(GraphNodeId, GraphNodeId)>,
}
78
79/// Output of [`FlatGraphBuilder::build`].
80pub struct FlatGraphBuilderOutput {
81    /// The flat DFIR graph.
82    pub flat_graph: DfirGraph,
83    /// Any `use` statements.
84    pub uses: Vec<ItemUse>,
85    /// Any emitted diagnostics (warnings, errors).
86    pub diagnostics: Diagnostics,
87}
88
89impl FlatGraphBuilder {
90    /// Create a new empty graph builder.
91    pub fn new() -> Self {
92        Default::default()
93    }
94
95    /// Convert the DFIR code AST into a graph builder.
96    pub fn from_dfir(input: DfirCode) -> Self {
97        let mut builder = Self::default();
98        builder.add_dfir(input, None, None);
99        builder
100    }
101
102    /// Build into an unpartitioned [`DfirGraph`], returning a struct containing the flat graph, any diagnostics, and
103    /// other outputs.
104    ///
105    /// If any diagnostics are errors, `Err` is returned and the underlying graph is lost.
106    pub fn build(mut self) -> Result<FlatGraphBuilderOutput, Diagnostics> {
107        self.finalize_connect_operator_links();
108        self.process_operator_errors();
109
110        if self.diagnostics.has_error() {
111            Err(self.diagnostics)
112        } else {
113            Ok(FlatGraphBuilderOutput {
114                flat_graph: self.flat_graph,
115                uses: self.uses,
116                diagnostics: self.diagnostics,
117            })
118        }
119    }
120
121    /// Adds all [`DfirStatement`]s within the [`DfirCode`] to this [`DfirGraph`].
122    ///
123    /// Optional configuration:
124    /// * In the given loop context `current_loop`.
125    /// * With the given operator tag `operator_tag`.
126    pub fn add_dfir(
127        &mut self,
128        dfir: DfirCode,
129        current_loop: Option<GraphLoopId>,
130        operator_tag: Option<&str>,
131    ) {
132        for stmt in dfir.statements {
133            self.add_statement_internal(stmt, current_loop, operator_tag);
134        }
135    }
136
137    /// Add a single [`DfirStatement`] line to this [`DfirGraph`] in the root context.
138    pub fn add_statement(&mut self, stmt: DfirStatement) {
139        self.add_statement_internal(stmt, None, None);
140    }
141
142    /// Add a single [`DfirStatement`] line to this [`DfirGraph`] with given configuration.
143    ///
144    /// Optional configuration:
145    /// * In the given loop context `current_loop`.
146    /// * With the given operator tag `operator_tag`.
147    fn add_statement_internal(
148        &mut self,
149        stmt: DfirStatement,
150        current_loop: Option<GraphLoopId>,
151        operator_tag: Option<&str>,
152    ) {
153        match stmt {
154            DfirStatement::Use(yuse) => {
155                self.uses.push(yuse);
156            }
157            DfirStatement::Named(named) => {
158                let stmt_span = named.span();
159                let ends = self.add_pipeline(
160                    named.pipeline,
161                    Some(&named.name),
162                    current_loop,
163                    operator_tag,
164                );
165                self.assign_varname_checked(named.name, stmt_span, ends);
166            }
167            DfirStatement::Pipeline(pipeline_stmt) => {
168                let ends =
169                    self.add_pipeline(pipeline_stmt.pipeline, None, current_loop, operator_tag);
170                Self::helper_check_unused_port(&mut self.diagnostics, &ends, true);
171                Self::helper_check_unused_port(&mut self.diagnostics, &ends, false);
172            }
173            DfirStatement::Loop(loop_statement) => {
174                let inner_loop = self.flat_graph.insert_loop(current_loop);
175                for stmt in loop_statement.statements {
176                    self.add_statement_internal(stmt, Some(inner_loop), operator_tag);
177                }
178            }
179        }
180    }
181
182    /// Programatically add an pipeline, optionally adding `pred_name` as a single predecessor and
183    /// assigning it all to `asgn_name`.
184    ///
185    /// In DFIR syntax, equivalent to [`Self::add_statement`] of (if all names are supplied):
186    /// ```text
187    /// #asgn_name = #pred_name -> #pipeline;
188    /// ```
189    ///
190    /// But with, optionally:
191    /// * A `current_loop` to put the operator in.
192    /// * An `operator_tag` to tag the operator with, for debugging/tracing.
193    pub fn append_assign_pipeline(
194        &mut self,
195        asgn_name: Option<&Ident>,
196        pred_name: Option<&Ident>,
197        pipeline: Pipeline,
198        current_loop: Option<GraphLoopId>,
199        operator_tag: Option<&str>,
200    ) {
201        let span = pipeline.span();
202        let mut ends = self.add_pipeline(pipeline, asgn_name, current_loop, operator_tag);
203
204        // Connect `pred_name` if supplied.
205        if let Some(pred_name) = pred_name {
206            if let Some(pred_varname_info) = self.varname_ends.get(pred_name) {
207                // Update ends for `asgn_name`.
208                ends = self.connect_ends(pred_varname_info.ends.clone(), ends);
209            } else {
210                self.diagnostics.push(Diagnostic::spanned(
211                    pred_name.span(),
212                    Level::Error,
213                    format!(
214                        "Cannot find referenced name `{}`; name was never assigned.",
215                        pred_name
216                    ),
217                ));
218            }
219        }
220
221        // Assign `asgn_name` if supplied.
222        if let Some(asgn_name) = asgn_name {
223            self.assign_varname_checked(asgn_name.clone(), span, ends);
224        }
225    }
226}
227
228/// Internal methods.
229impl FlatGraphBuilder {
230    /// Assign a variable name to a pipeline, checking for conflicts.
231    fn assign_varname_checked(&mut self, name: Ident, stmt_span: Span, ends: Ends) {
232        match self.varname_ends.entry(name) {
233            Entry::Vacant(vacant_entry) => {
234                vacant_entry.insert(VarnameInfo::new(ends));
235            }
236            Entry::Occupied(occupied_entry) => {
237                let prev_conflict = occupied_entry.key();
238                self.diagnostics.push(Diagnostic::spanned(
239                    prev_conflict.span(),
240                    Level::Error,
241                    format!(
242                        "Existing assignment to `{}` conflicts with later assignment: {} (1/2)",
243                        prev_conflict,
244                        PrettySpan(stmt_span),
245                    ),
246                ));
247                self.diagnostics.push(Diagnostic::spanned(
248                    stmt_span,
249                    Level::Error,
250                    format!(
251                        "Name assignment to `{}` conflicts with existing assignment: {} (2/2)",
252                        prev_conflict,
253                        PrettySpan(prev_conflict.span())
254                    ),
255                ));
256            }
257        }
258    }
259
260    /// Helper: Add a pipeline, i.e. `a -> b -> c`. Return the input and output [`Ends`] for it.
261    fn add_pipeline(
262        &mut self,
263        pipeline: Pipeline,
264        current_varname: Option<&Ident>,
265        current_loop: Option<GraphLoopId>,
266        operator_tag: Option<&str>,
267    ) -> Ends {
268        match pipeline {
269            Pipeline::Paren(ported_pipeline_paren) => {
270                let (inn_port, pipeline_paren, out_port) =
271                    PortIndexValue::from_ported(ported_pipeline_paren);
272                let og_ends = self.add_pipeline(
273                    *pipeline_paren.pipeline,
274                    current_varname,
275                    current_loop,
276                    operator_tag,
277                );
278                Self::helper_combine_ends(&mut self.diagnostics, og_ends, inn_port, out_port)
279            }
280            Pipeline::Name(pipeline_name) => {
281                let (inn_port, ident, out_port) = PortIndexValue::from_ported(pipeline_name);
282
283                // Mingwei: We could lookup non-forward references immediately, but easier to just
284                // have one consistent code path: `GraphDet::Undetermined`.
285                Ends {
286                    inn: Some((inn_port, GraphDet::Undetermined(ident.clone()))),
287                    out: Some((out_port, GraphDet::Undetermined(ident))),
288                }
289            }
290            Pipeline::ModuleBoundary(pipeline_name) => {
291                let Some((input_node, output_node)) = self.module_boundary_nodes else {
292                    self.diagnostics.push(
293                        Error::new(
294                            pipeline_name.span(),
295                            "`mod` is only usable inside of a module.",
296                        )
297                        .into(),
298                    );
299
300                    return Ends {
301                        inn: None,
302                        out: None,
303                    };
304                };
305
306                let (inn_port, _, out_port) = PortIndexValue::from_ported(pipeline_name);
307
308                Ends {
309                    inn: Some((inn_port, GraphDet::Determined(output_node))),
310                    out: Some((out_port, GraphDet::Determined(input_node))),
311                }
312            }
313            Pipeline::Link(pipeline_link) => {
314                // Add the nested LHS and RHS of this link.
315                let lhs_ends = self.add_pipeline(
316                    *pipeline_link.lhs,
317                    current_varname,
318                    current_loop,
319                    operator_tag,
320                );
321                let rhs_ends = self.add_pipeline(
322                    *pipeline_link.rhs,
323                    current_varname,
324                    current_loop,
325                    operator_tag,
326                );
327
328                self.connect_ends(lhs_ends, rhs_ends)
329            }
330            Pipeline::Operator(operator) => {
331                let op_span = Some(operator.span());
332                let (node_id, ends) =
333                    self.add_operator(current_varname, current_loop, operator, op_span);
334                if let Some(operator_tag) = operator_tag {
335                    self.flat_graph
336                        .set_operator_tag(node_id, operator_tag.to_owned());
337                }
338                ends
339            }
340        }
341    }
342
343    /// Connects two [`Ends`] together. Returns the outer [`Ends`] for the connection.
344    ///
345    /// Links the inner ends together by adding it to `self.links`.
346    fn connect_ends(&mut self, lhs_ends: Ends, rhs_ends: Ends) -> Ends {
347        // Outer (first and last) ends.
348        let outer_ends = Ends {
349            inn: lhs_ends.inn,
350            out: rhs_ends.out,
351        };
352        // Inner (link) ends.
353        let link_ends = Ends {
354            out: lhs_ends.out,
355            inn: rhs_ends.inn,
356        };
357        self.links.push(link_ends);
358        outer_ends
359    }
360
361    /// Adds an operator to the graph, returning its [`GraphNodeId`] the input and output [`Ends`] for it.
362    fn add_operator(
363        &mut self,
364        current_varname: Option<&Ident>,
365        current_loop: Option<GraphLoopId>,
366        operator: Operator,
367        op_span: Option<Span>,
368    ) -> (GraphNodeId, Ends) {
369        let node_id = self.flat_graph.insert_node(
370            GraphNode::Operator(operator),
371            current_varname.cloned(),
372            current_loop,
373        );
374        let ends = Ends {
375            inn: Some((
376                PortIndexValue::Elided(op_span),
377                GraphDet::Determined(node_id),
378            )),
379            out: Some((
380                PortIndexValue::Elided(op_span),
381                GraphDet::Determined(node_id),
382            )),
383        };
384        (node_id, ends)
385    }
386
387    /// Connects operator links as a final building step. Processes all the links stored in
388    /// `self.links` and actually puts them into the graph.
389    fn finalize_connect_operator_links(&mut self) {
390        // `->` edges
391        for Ends { out, inn } in std::mem::take(&mut self.links) {
392            let out_opt = self.helper_resolve_name(out, false);
393            let inn_opt = self.helper_resolve_name(inn, true);
394            // `None` already have errors in `self.diagnostics`.
395            if let (Some((out_port, out_node)), Some((inn_port, inn_node))) = (out_opt, inn_opt) {
396                let _ = self.finalize_connect_operators(out_port, out_node, inn_port, inn_node);
397            }
398        }
399
400        // Resolve the singleton references for each node.
401        for node_id in self.flat_graph.node_ids().collect::<Vec<_>>() {
402            if let GraphNode::Operator(operator) = self.flat_graph.node(node_id) {
403                let singletons_referenced = operator
404                    .singletons_referenced
405                    .clone()
406                    .into_iter()
407                    .map(|singleton_ref| {
408                        let port_det = self
409                            .varname_ends
410                            .get(&singleton_ref)
411                            .filter(|varname_info| !varname_info.illegal_cycle)
412                            .map(|varname_info| &varname_info.ends)
413                            .and_then(|ends| ends.out.as_ref())
414                            .cloned();
415                        if let Some((_port, node_id)) = self.helper_resolve_name(port_det, false) {
416                            Some(node_id)
417                        } else {
418                            self.diagnostics.push(Diagnostic::spanned(
419                                singleton_ref.span(),
420                                Level::Error,
421                                format!(
422                                    "Cannot find referenced name `{}`; name was never assigned.",
423                                    singleton_ref
424                                ),
425                            ));
426                            None
427                        }
428                    })
429                    .collect();
430
431                self.flat_graph
432                    .set_node_singleton_references(node_id, singletons_referenced);
433            }
434        }
435    }
436
    /// Recursively resolve a variable name. For handling forward (and backward) name references
    /// after all names have been assigned.
    /// Returns `None` if the name is not resolvable, either because it was never assigned or
    /// because it contains a self-referential cycle.
    ///
    /// `is_in` set to `true` means the _input_ side will be returned. `false` means the _output_ side will be returned.
    ///
    /// A `None` `port_det` resolves to `None` without emitting a diagnostic. Every name that is
    /// successfully followed gets its `inn_used`/`out_used` flag set (depending on `is_in`).
    ///
    /// Despite the name, the resolution is an iterative loop, bounded by a backup limit on the
    /// number of names followed; hitting the limit is reported as an error.
    fn helper_resolve_name(
        &mut self,
        mut port_det: Option<(PortIndexValue, GraphDet)>,
        is_in: bool,
    ) -> Option<(PortIndexValue, GraphNodeId)> {
        const BACKUP_RECURSION_LIMIT: usize = 1024;

        // The chain of names followed so far, in order. Used for cycle detection and messages.
        let mut names = Vec::new();
        for _ in 0..BACKUP_RECURSION_LIMIT {
            // `?`: an absent end cannot be resolved.
            match port_det? {
                (port, GraphDet::Determined(node_id)) => {
                    return Some((port, node_id));
                }
                (port, GraphDet::Undetermined(ident)) => {
                    let Some(varname_info) = self.varname_ends.get_mut(&ident) else {
                        self.diagnostics.push(Diagnostic::spanned(
                            ident.span(),
                            Level::Error,
                            format!("Cannot find name `{}`; name was never assigned.", ident),
                        ));
                        return None;
                    };
                    // Check for a self-referential cycle.
                    let cycle_found = names.contains(&ident);
                    if !cycle_found {
                        names.push(ident);
                    };
                    // Also bail out if this name was already flagged by an earlier resolution.
                    if cycle_found || varname_info.illegal_cycle {
                        let len = names.len();
                        for (i, name) in names.into_iter().enumerate() {
                            self.diagnostics.push(Diagnostic::spanned(
                                name.span(),
                                Level::Error,
                                format!(
                                    "Name `{}` forms or references an illegal self-referential cycle ({}/{}).",
                                    name,
                                    i + 1,
                                    len
                                ),
                            ));
                            // Flag the name so that any future reference to it hits the
                            // `illegal_cycle` check above and is reported as well.
                            // The `unwrap` holds: every pushed name was found in the map first.
                            self.varname_ends.get_mut(&name).unwrap().illegal_cycle = true;
                        }
                        return None;
                    }

                    // No self-cycle.
                    // Mark the requested side of the name as used and follow it one hop.
                    let prev = if is_in {
                        varname_info.inn_used = true;
                        &varname_info.ends.inn
                    } else {
                        varname_info.out_used = true;
                        &varname_info.ends.out
                    };
                    // Combine the end stored for the name with the port written on this reference.
                    port_det = Self::helper_combine_end(
                        &mut self.diagnostics,
                        prev.clone(),
                        port,
                        if is_in { "input" } else { "output" },
                    );
                }
            }
        }
        // Only reached when the loop ran out of iterations without resolving or failing.
        self.diagnostics.push(Diagnostic::spanned(
            Span::call_site(),
            Level::Error,
            format!(
                "Reached the recursion limit {} while resolving names. This is either a dfir bug or you have an absurdly long chain of names: `{}`.",
                BACKUP_RECURSION_LIMIT,
                names.iter().map(ToString::to_string).collect::<Vec<_>>().join("` -> `"),
            )
        ));
        None
    }
518
519    /// Connect two operators on the given port indexes.
520    fn finalize_connect_operators(
521        &mut self,
522        src_port: PortIndexValue,
523        src: GraphNodeId,
524        dst_port: PortIndexValue,
525        dst: GraphNodeId,
526    ) -> GraphEdgeId {
527        {
528            /// Helper to emit conflicts when a port is used twice.
529            fn emit_conflict(
530                inout: &str,
531                old: &PortIndexValue,
532                new: &PortIndexValue,
533                diagnostics: &mut Diagnostics,
534            ) {
535                // TODO(mingwei): Use `MultiSpan` once `proc_macro2` supports it.
536                diagnostics.push(Diagnostic::spanned(
537                    old.span(),
538                    Level::Error,
539                    format!(
540                        "{} connection conflicts with below ({}) (1/2)",
541                        inout,
542                        PrettySpan(new.span()),
543                    ),
544                ));
545                diagnostics.push(Diagnostic::spanned(
546                    new.span(),
547                    Level::Error,
548                    format!(
549                        "{} connection conflicts with above ({}) (2/2)",
550                        inout,
551                        PrettySpan(old.span()),
552                    ),
553                ));
554            }
555
556            // Handle src's successor port conflicts:
557            if src_port.is_specified() {
558                for conflicting_port in self
559                    .flat_graph
560                    .node_successor_edges(src)
561                    .map(|edge_id| self.flat_graph.edge_ports(edge_id).0)
562                    .filter(|&port| port == &src_port)
563                {
564                    emit_conflict("Output", conflicting_port, &src_port, &mut self.diagnostics);
565                }
566            }
567
568            // Handle dst's predecessor port conflicts:
569            if dst_port.is_specified() {
570                for conflicting_port in self
571                    .flat_graph
572                    .node_predecessor_edges(dst)
573                    .map(|edge_id| self.flat_graph.edge_ports(edge_id).1)
574                    .filter(|&port| port == &dst_port)
575                {
576                    emit_conflict("Input", conflicting_port, &dst_port, &mut self.diagnostics);
577                }
578            }
579        }
580        self.flat_graph.insert_edge(src, src_port, dst, dst_port)
581    }
582
    /// Process operators and emit operator errors.
    ///
    /// Runs the individual passes in a fixed order. The operator instances are made first
    /// because `check_operator_errors` looks them up (via `node_op_inst`).
    fn process_operator_errors(&mut self) {
        self.make_operator_instances();
        self.check_operator_errors();
        self.warn_unused_port_indexing();
        self.check_loop_errors();
    }
590
    /// Make `OperatorInstance`s for each operator node.
    ///
    /// Delegates to `insert_node_op_insts_all` on the flat graph, handing it `self.diagnostics`
    /// to report into.
    fn make_operator_instances(&mut self) {
        self.flat_graph
            .insert_node_op_insts_all(&mut self.diagnostics);
    }
596
597    /// Validates that operators have valid number of inputs, outputs, & arguments.
598    /// Adds errors (and warnings) to `self.diagnostics`.
599    fn check_operator_errors(&mut self) {
600        /// Returns true if an error was found.
601        fn emit_arity_error(
602            op_span: Span,
603            op_name: &str,
604            is_in: bool,
605            is_hard: bool,
606            degree: usize,
607            range: &dyn RangeTrait<usize>,
608            diagnostics: &mut Diagnostics,
609        ) -> bool {
610            let message = format!(
611                "`{}` {} have {} {}, actually has {}.",
612                op_name,
613                if is_hard { "must" } else { "should" },
614                range.human_string(),
615                if is_in { "input(s)" } else { "output(s)" },
616                degree,
617            );
618            let out_of_range = !range.contains(&degree);
619            if out_of_range {
620                diagnostics.push(Diagnostic::spanned(
621                    op_span,
622                    if is_hard {
623                        Level::Error
624                    } else {
625                        Level::Warning
626                    },
627                    message,
628                ));
629            }
630            out_of_range
631        }
632
633        for (node_id, node) in self.flat_graph.nodes() {
634            match node {
635                GraphNode::Operator(operator) => {
636                    let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
637                        // Error already emitted by `insert_node_op_insts_all`.
638                        continue;
639                    };
640                    let op_constraints = op_inst.op_constraints;
641                    let op_name = operator.name_string();
642
643                    // Check number of args
644                    if op_constraints.num_args != operator.args.len() {
645                        self.diagnostics.push(Diagnostic::spanned(
646                            operator.span(),
647                            Level::Error,
648                            format!(
649                                "`{}` expects {} argument(s), received {}.",
650                                op_name,
651                                op_constraints.num_args,
652                                operator.args.len()
653                            ),
654                        ));
655                    }
656
657                    // Check input/output (port) arity
658                    let inn_degree = self.flat_graph.node_degree_in(node_id);
659                    let _ = emit_arity_error(
660                        operator.span(),
661                        &op_name,
662                        true,
663                        true,
664                        inn_degree,
665                        op_constraints.hard_range_inn,
666                        &mut self.diagnostics,
667                    ) || emit_arity_error(
668                        operator.span(),
669                        &op_name,
670                        true,
671                        false,
672                        inn_degree,
673                        op_constraints.soft_range_inn,
674                        &mut self.diagnostics,
675                    );
676
677                    let out_degree = self.flat_graph.node_degree_out(node_id);
678                    let _ = emit_arity_error(
679                        operator.span(),
680                        &op_name,
681                        false,
682                        true,
683                        out_degree,
684                        op_constraints.hard_range_out,
685                        &mut self.diagnostics,
686                    ) || emit_arity_error(
687                        operator.span(),
688                        &op_name,
689                        false,
690                        false,
691                        out_degree,
692                        op_constraints.soft_range_out,
693                        &mut self.diagnostics,
694                    );
695
696                    fn emit_port_error<'a>(
697                        op_span: Span,
698                        op_name: &str,
699                        expected_ports_fn: Option<fn() -> PortListSpec>,
700                        actual_ports_iter: impl Iterator<Item = &'a PortIndexValue>,
701                        input_output: &'static str,
702                        diagnostics: &mut Diagnostics,
703                    ) {
704                        let Some(expected_ports_fn) = expected_ports_fn else {
705                            return;
706                        };
707                        let PortListSpec::Fixed(expected_ports) = (expected_ports_fn)() else {
708                            // Separate check inside of `demux` special case.
709                            return;
710                        };
711                        let expected_ports: Vec<_> = expected_ports.into_iter().collect();
712
713                        // Reject unexpected ports.
714                        let ports: BTreeSet<_> = actual_ports_iter
715                            // Use `inspect` before collecting into `BTreeSet` to ensure we get
716                            // both error messages on duplicated port names.
717                            .inspect(|actual_port_iv| {
718                                // For each actually used port `port_index_value`, check if it is expected.
719                                let is_expected = expected_ports.iter().any(|port_index| {
720                                    actual_port_iv == &&port_index.clone().into()
721                                });
722                                // If it is not expected, emit a diagnostic error.
723                                if !is_expected {
724                                    diagnostics.push(Diagnostic::spanned(
725                                        actual_port_iv.span(),
726                                        Level::Error,
727                                        format!(
728                                            "`{}` received unexpected {} port: {}. Expected one of: `{}`",
729                                            op_name,
730                                            input_output,
731                                            actual_port_iv.as_error_message_string(),
732                                            Itertools::intersperse(
733                                                expected_ports
734                                                    .iter()
735                                                    .map(|port| port.to_token_stream().to_string())
736                                                    .map(Cow::Owned),
737                                                Cow::Borrowed("`, `"),
738                                            ).collect::<String>()
739                                        ),
740                                    ))
741                                }
742                            })
743                            .collect();
744
745                        // List missing expected ports.
746                        let missing: Vec<_> = expected_ports
747                            .into_iter()
748                            .filter_map(|expected_port| {
749                                let tokens = expected_port.to_token_stream();
750                                if !ports.contains(&&expected_port.into()) {
751                                    Some(tokens)
752                                } else {
753                                    None
754                                }
755                            })
756                            .collect();
757                        if !missing.is_empty() {
758                            diagnostics.push(Diagnostic::spanned(
759                                op_span,
760                                Level::Error,
761                                format!(
762                                    "`{}` missing expected {} port(s): `{}`.",
763                                    op_name,
764                                    input_output,
765                                    Itertools::intersperse(
766                                        missing.into_iter().map(|port| Cow::Owned(
767                                            port.to_token_stream().to_string()
768                                        )),
769                                        Cow::Borrowed("`, `")
770                                    )
771                                    .collect::<String>()
772                                ),
773                            ));
774                        }
775                    }
776
777                    emit_port_error(
778                        operator.span(),
779                        &op_name,
780                        op_constraints.ports_inn,
781                        self.flat_graph
782                            .node_predecessor_edges(node_id)
783                            .map(|edge_id| self.flat_graph.edge_ports(edge_id).1),
784                        "input",
785                        &mut self.diagnostics,
786                    );
787                    emit_port_error(
788                        operator.span(),
789                        &op_name,
790                        op_constraints.ports_out,
791                        self.flat_graph
792                            .node_successor_edges(node_id)
793                            .map(|edge_id| self.flat_graph.edge_ports(edge_id).0),
794                        "output",
795                        &mut self.diagnostics,
796                    );
797
798                    // Check that singleton references actually reference valid targets.
799                    {
800                        let singletons_resolved =
801                            self.flat_graph.node_singleton_references(node_id);
802                        for (singleton_node_id, singleton_ident) in singletons_resolved
803                            .iter()
804                            .zip_eq(&*operator.singletons_referenced)
805                        {
806                            let &Some(singleton_node_id) = singleton_node_id else {
807                                // Error already emitted by `connect_operator_links`, "Cannot find referenced name...".
808                                continue;
809                            };
810                            // HandoffKind::Option nodes are valid singleton reference targets.
811                            if matches!(
812                                self.flat_graph.node(singleton_node_id),
813                                GraphNode::Handoff {
814                                    kind: HandoffKind::Option,
815                                    ..
816                                }
817                            ) {
818                                continue;
819                            }
820                            let Some(ref_op_inst) = self.flat_graph.node_op_inst(singleton_node_id)
821                            else {
822                                // Error already emitted by `insert_node_op_insts_all`.
823                                continue;
824                            };
825                            let ref_op_constraints = ref_op_inst.op_constraints;
826                            if !ref_op_constraints.has_singleton_output {
827                                self.diagnostics.push(Diagnostic::spanned(
828                                    singleton_ident.span(),
829                                    Level::Error,
830                                    format!(
831                                        "Cannot reference operator `{}`. Only operators with singleton state can be referenced.",
832                                        ref_op_constraints.name,
833                                    ),
834                                ));
835                            }
836                        }
837                    }
838                }
839                GraphNode::Handoff { kind, src_span, .. } => {
840                    // Validate arity: handoff must have exactly 1 input and 1 output.
841                    let op_name = match kind {
842                        HandoffKind::Vec => "handoff",
843                        HandoffKind::Option => "singleton",
844                    };
845                    let inn_degree = self.flat_graph.node_degree_in(node_id);
846                    emit_arity_error(
847                        *src_span,
848                        op_name,
849                        true,
850                        true,
851                        inn_degree,
852                        &(1..=1),
853                        &mut self.diagnostics,
854                    );
855                    let out_degree = self.flat_graph.node_degree_out(node_id);
856                    let out_degree_range = match kind {
857                        HandoffKind::Vec => 1..=1,
858                        // `singleton()` may be no-output, only by ref. In the future this will also apply to vec.
859                        HandoffKind::Option => 0..=1,
860                    };
861                    emit_arity_error(
862                        *src_span,
863                        op_name,
864                        false,
865                        true,
866                        out_degree,
867                        &out_degree_range,
868                        &mut self.diagnostics,
869                    );
870                }
871                GraphNode::ModuleBoundary { .. } => {
872                    // Module boundaries don't require any checking.
873                }
874            }
875        }
876    }
877
878    /// Warns about unused port indexing referenced in [`Self::varname_ends`].
879    /// https://github.com/hydro-project/hydro/issues/1108
880    fn warn_unused_port_indexing(&mut self) {
881        for (_ident, varname_info) in self.varname_ends.iter() {
882            if !varname_info.inn_used {
883                Self::helper_check_unused_port(&mut self.diagnostics, &varname_info.ends, true);
884            }
885            if !varname_info.out_used {
886                Self::helper_check_unused_port(&mut self.diagnostics, &varname_info.ends, false);
887            }
888        }
889    }
890
891    /// Emit a warning to `diagnostics` for an unused port (i.e. if the port is specified for
892    /// reason).
893    fn helper_check_unused_port(diagnostics: &mut Diagnostics, ends: &Ends, is_in: bool) {
894        let port = if is_in { &ends.inn } else { &ends.out };
895        if let Some((port, _)) = port
896            && port.is_specified()
897        {
898            diagnostics.push(Diagnostic::spanned(
899                port.span(),
900                Level::Error,
901                format!(
902                    "{} port index is unused. (Is the port on the correct side?)",
903                    if is_in { "Input" } else { "Output" },
904                ),
905            ));
906        }
907    }
908
909    /// Helper function.
910    /// Combine the port indexing information for indexing wrapped around a name.
911    /// Because the name may already have indexing, this may introduce double indexing (i.e. `[0][0]my_var[0][0]`)
912    /// which would be an error.
913    fn helper_combine_ends(
914        diagnostics: &mut Diagnostics,
915        og_ends: Ends,
916        inn_port: PortIndexValue,
917        out_port: PortIndexValue,
918    ) -> Ends {
919        Ends {
920            inn: Self::helper_combine_end(diagnostics, og_ends.inn, inn_port, "input"),
921            out: Self::helper_combine_end(diagnostics, og_ends.out, out_port, "output"),
922        }
923    }
924
925    /// Helper function.
926    /// Combine the port indexing info for one input or output.
927    fn helper_combine_end(
928        diagnostics: &mut Diagnostics,
929        og: Option<(PortIndexValue, GraphDet)>,
930        other: PortIndexValue,
931        input_output: &'static str,
932    ) -> Option<(PortIndexValue, GraphDet)> {
933        // TODO(mingwei): minification pass over this code?
934
935        let other_span = other.span();
936
937        let (og_port, og_node) = og?;
938        match og_port.combine(other) {
939            Ok(combined_port) => Some((combined_port, og_node)),
940            Err(og_port) => {
941                // TODO(mingwei): Use `MultiSpan` once `proc_macro2` supports it.
942                diagnostics.push(Diagnostic::spanned(
943                    og_port.span(),
944                    Level::Error,
945                    format!(
946                        "Indexing on {} is overwritten below ({}) (1/2).",
947                        input_output,
948                        PrettySpan(other_span),
949                    ),
950                ));
951                diagnostics.push(Diagnostic::spanned(
952                    other_span,
953                    Level::Error,
954                    format!(
955                        "Cannot index on already-indexed {}, previously indexed above ({}) (2/2).",
956                        input_output,
957                        PrettySpan(og_port.span()),
958                    ),
959                ));
960                // When errored, just use original and ignore OTHER port to minimize
961                // noisy/extra diagnostics.
962                Some((og_port, og_node))
963            }
964        }
965    }
966
967    /// Check for loop context-related errors.
968    fn check_loop_errors(&mut self) {
969        for (node_id, node) in self.flat_graph.nodes() {
970            let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
971                continue;
972            };
973            let loop_opt = self.flat_graph.node_loop(node_id);
974
975            // Ensure no `'tick` or `'static` persistences are used WITHIN a loop context.
976            // Ensure no `'loop` persistences are used OUTSIDE a loop context.
977            for persistence in &op_inst.generics.persistence_args {
978                let span = op_inst.generics.generic_args.span();
979                match (loop_opt, persistence) {
980                    (Some(_loop_id), p @ (Persistence::Tick | Persistence::Static)) => {
981                        self.diagnostics.push(Diagnostic::spanned(
982                            span,
983                            Level::Error,
984                            format!(
985                                "Operator uses `'{}` persistence, which is not allowed within a `loop {{ ... }}` context.",
986                                p.to_str_lowercase(),
987                            ),
988                        ));
989                    }
990                    (None, p @ (Persistence::None | Persistence::Loop)) => {
991                        self.diagnostics.push(Diagnostic::spanned(
992                            span,
993                            Level::Error,
994                            format!(
995                                "Operator uses `'{}` persistence, but is not within a `loop {{ ... }}` context.",
996                                p.to_str_lowercase(),
997                            ),
998                        ));
999                    }
1000                    _ => {}
1001                }
1002            }
1003
1004            // All inputs must be declared in the root block.
1005            if let (Some(_loop_id), Some(FloType::Source)) =
1006                (loop_opt, op_inst.op_constraints.flo_type)
1007            {
1008                self.diagnostics.push(Diagnostic::spanned(
1009                    node.span(),
1010                    Level::Error,
1011                    format!(
1012                        "Source operator `{}(...)` must be at the root level, not within any `loop {{ ... }}` contexts.",
1013                        op_inst.op_constraints.name
1014                    )
1015                ));
1016            }
1017        }
1018
1019        // Check windowing and un-windowing operators, for loop inputs and outputs respectively.
1020        for (_edge_id, (pred_id, node_id)) in self.flat_graph.edges() {
1021            let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
1022                continue;
1023            };
1024            let flo_type = &op_inst.op_constraints.flo_type;
1025
1026            let pred_loop_id = self.flat_graph.node_loop(pred_id);
1027            let loop_id = self.flat_graph.node_loop(node_id);
1028
1029            let span = self.flat_graph.node(node_id).span();
1030
1031            let (is_input, is_output) = {
1032                let parent_pred_loop_id =
1033                    pred_loop_id.and_then(|lid| self.flat_graph.loop_parent(lid));
1034                let parent_loop_id = loop_id.and_then(|lid| self.flat_graph.loop_parent(lid));
1035                let is_same = pred_loop_id == loop_id;
1036                let is_input = !is_same && parent_loop_id == pred_loop_id;
1037                let is_output = !is_same && parent_pred_loop_id == loop_id;
1038                if !(is_input || is_output || is_same) {
1039                    self.diagnostics.push(Diagnostic::spanned(
1040                        span,
1041                        Level::Error,
1042                        "Operator input edge may not cross multiple loop contexts.",
1043                    ));
1044                    continue;
1045                }
1046                (is_input, is_output)
1047            };
1048
1049            match flo_type {
1050                None => {
1051                    if is_input {
1052                        self.diagnostics.push(Diagnostic::spanned(
1053                            span,
1054                            Level::Error,
1055                            format!(
1056                                "Operator `{}(...)` entering a loop context must be a windowing operator, but is not.",
1057                                op_inst.op_constraints.name
1058                            )
1059                        ));
1060                    }
1061                    if is_output {
1062                        self.diagnostics.push(Diagnostic::spanned(
1063                            span,
1064                            Level::Error,
1065                            format!(
1066                                "Operator `{}(...)` exiting a loop context must be an un-windowing operator, but is not.",
1067                                op_inst.op_constraints.name
1068                            )
1069                        ));
1070                    }
1071                }
1072                Some(FloType::Windowing) => {
1073                    if !is_input {
1074                        self.diagnostics.push(Diagnostic::spanned(
1075                            span,
1076                            Level::Error,
1077                            format!(
1078                                "Windowing operator `{}(...)` must be the first input operator into a `loop {{ ... }} context.",
1079                                op_inst.op_constraints.name
1080                            )
1081                        ));
1082                    }
1083                }
1084                Some(FloType::Unwindowing) => {
1085                    if !is_output {
1086                        self.diagnostics.push(Diagnostic::spanned(
1087                            span,
1088                            Level::Error,
1089                            format!(
1090                                "Un-windowing operator `{}(...)` must be the first output operator after exiting a `loop {{ ... }} context.",
1091                                op_inst.op_constraints.name
1092                            )
1093                        ));
1094                    }
1095                }
1096                Some(FloType::NextIteration) => {
1097                    // Must be in a loop context.
1098                    if loop_id.is_none() {
1099                        self.diagnostics.push(Diagnostic::spanned(
1100                            span,
1101                            Level::Error,
1102                            format!(
1103                                "Operator `{}(...)` must be within a `loop {{ ... }}` context.",
1104                                op_inst.op_constraints.name
1105                            ),
1106                        ));
1107                    }
1108                }
1109                Some(FloType::Source) => {
1110                    // Handled above.
1111                }
1112            }
1113        }
1114
1115        // Must be a DAG (excluding `next_iteration()` operators).
1116        // TODO(mingwei): Nested loop blocks should count as a single node.
1117        // But this doesn't cause any correctness issues because the nested loops are also DAGs.
1118        for (loop_id, loop_nodes) in self.flat_graph.loops() {
1119            // Filter out `next_iteration()` operators.
1120            let filter_next_iteration = |&node_id: &GraphNodeId| {
1121                self.flat_graph
1122                    .node_op_inst(node_id)
1123                    .map(|op_inst| Some(FloType::NextIteration) != op_inst.op_constraints.flo_type)
1124                    .unwrap_or(true)
1125            };
1126
1127            let topo_sort_result = graph_algorithms::topo_sort(
1128                loop_nodes.iter().copied().filter(filter_next_iteration),
1129                |dst| {
1130                    self.flat_graph
1131                        .node_predecessor_nodes(dst)
1132                        .filter(|&src| Some(loop_id) == self.flat_graph.node_loop(src))
1133                        .filter(filter_next_iteration)
1134                },
1135            );
1136            if let Err(cycle) = topo_sort_result {
1137                let len = cycle.len();
1138                for (i, node_id) in cycle.into_iter().enumerate() {
1139                    let span = self.flat_graph.node(node_id).span();
1140                    self.diagnostics.push(Diagnostic::spanned(
1141                        span,
1142                        Level::Error,
1143                        format!(
1144                            "Operator forms an illegal cycle within a `loop {{ ... }}` block. Use `{}()` to pass data across loop iterations. ({}/{})",
1145                            NEXT_ITERATION.name,
1146                            i + 1,
1147                            len,
1148                        ),
1149                    ));
1150                }
1151            }
1152        }
1153    }
1154}