Skip to main content

dfir_lang/graph/
flat_graph_builder.rs

//! Build a flat graph from [`DfirStatement`]s.
2
3use std::borrow::Cow;
4use std::collections::btree_map::Entry;
5use std::collections::{BTreeMap, BTreeSet};
6
7use itertools::Itertools;
8use proc_macro2::Span;
9use quote::ToTokens;
10use syn::spanned::Spanned;
11use syn::{Error, Ident, ItemUse};
12
13use crate::diagnostic::{Diagnostic, Diagnostics, Level};
14use crate::graph::meta_graph::ResolvedSingletonRef;
15use crate::graph::ops::next_iteration::NEXT_ITERATION;
16use crate::graph::ops::{FloType, Persistence, PortListSpec, RangeTrait};
17use crate::graph::{
18    DfirGraph, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId, HandoffKind, PortIndexValue,
19    graph_algorithms,
20};
21use crate::parse::{DfirCode, DfirStatement, Operator, Pipeline};
22use crate::pretty_span::PrettySpan;
23
/// The (possibly not-yet-resolved) input and output ends of a pipeline fragment.
///
/// `inn` is where edges coming into the fragment attach and `out` is where edges leaving it
/// attach; each pairs a port index with the node (or variable name) it refers to. A side is
/// `None` when there is no usable end, e.g. after `mod` is used outside of a module.
///
/// The same struct is also reused in `FlatGraphBuilder::links` to hold one queued
/// `out -> inn` connection.
#[derive(Clone, Debug)]
struct Ends {
    /// Input end: port index and target for incoming edges.
    inn: Option<(PortIndexValue, GraphDet)>,
    /// Output end: port index and source for outgoing edges.
    out: Option<(PortIndexValue, GraphDet)>,
}
29
/// What one end of a pipeline refers to.
#[derive(Clone, Debug)]
enum GraphDet {
    /// A concrete node that has already been inserted into the graph.
    Determined(GraphNodeId),
    /// A variable name, resolved to a node only after all statements have been added (which
    /// allows forward references); see `FlatGraphBuilder::helper_resolve_name`.
    Undetermined(Ident),
}
35
/// Variable name info for each ident, see [`FlatGraphBuilder::varname_ends`].
#[derive(Debug)]
struct VarnameInfo {
    /// What the variable name resolves to (the ends of the pipeline assigned to it).
    pub ends: Ends,
    /// Set to true if the varname reference creates an illegal self-referential cycle.
    /// Set by `FlatGraphBuilder::helper_resolve_name` when it detects such a cycle; later
    /// resolutions of this name then fail with a cycle error instead of being followed.
    pub illegal_cycle: bool,
    /// Set to true once the in port is used, i.e. once the name is resolved as the
    /// destination side of a link. Used to track unused ports.
    pub inn_used: bool,
    /// Set to true once the out port is used, i.e. once the name is resolved as the
    /// source side of a link. Used to track unused ports.
    pub out_used: bool,
}
48impl VarnameInfo {
49    pub fn new(ends: Ends) -> Self {
50        Self {
51            ends,
52            illegal_cycle: false,
53            inn_used: false,
54            out_used: false,
55        }
56    }
57}
58
/// Wrapper around [`DfirGraph`] to build a flat graph from AST code.
#[derive(Debug, Default)]
pub struct FlatGraphBuilder {
    /// Spanned error/warning/etc diagnostics to emit.
    diagnostics: Diagnostics,

    /// [`DfirGraph`] being built.
    flat_graph: DfirGraph,
    /// Variable names and what they resolve to. Entries are added as
    /// [`DfirStatement::Named`] statements (and named `append_assign_pipeline` calls) are
    /// processed.
    varname_ends: BTreeMap<Ident, VarnameInfo>,
    /// Each (out -> inn) link inputted. Links are only queued here; they become graph edges
    /// in `finalize_connect_operator_links`, once all names can be resolved.
    links: Vec<Ends>,

    /// Use statements.
    uses: Vec<ItemUse>,

    /// If the flat graph is being loaded as a module, then two initial ModuleBoundary nodes are inserted into the graph. One
    /// for the input into the module and one for the output out of the module.
    /// Stored as `(input_node, output_node)`. `None` outside of a module, in which case using
    /// `mod` in a pipeline is an error.
    module_boundary_nodes: Option<(GraphNodeId, GraphNodeId)>,
}
79
/// Output of [`FlatGraphBuilder::build`].
///
/// Only produced when no error diagnostics were emitted; otherwise `build` returns `Err`.
pub struct FlatGraphBuilderOutput {
    /// The flat DFIR graph.
    pub flat_graph: DfirGraph,
    /// Any `use` statements.
    pub uses: Vec<ItemUse>,
    /// Any emitted diagnostics. Since errors cause `build` to return `Err`, these are
    /// non-error diagnostics such as warnings.
    pub diagnostics: Diagnostics,
}
89
90impl FlatGraphBuilder {
91    /// Create a new empty graph builder.
92    pub fn new() -> Self {
93        Default::default()
94    }
95
96    /// Convert the DFIR code AST into a graph builder.
97    pub fn from_dfir(input: DfirCode) -> Self {
98        let mut builder = Self::default();
99        builder.add_dfir(input, None, None);
100        builder
101    }
102
103    /// Build into an unpartitioned [`DfirGraph`], returning a struct containing the flat graph, any diagnostics, and
104    /// other outputs.
105    ///
106    /// If any diagnostics are errors, `Err` is returned and the underlying graph is lost.
107    pub fn build(mut self) -> Result<FlatGraphBuilderOutput, Diagnostics> {
108        self.finalize_connect_operator_links();
109        self.process_operator_errors();
110
111        if self.diagnostics.has_error() {
112            Err(self.diagnostics)
113        } else {
114            Ok(FlatGraphBuilderOutput {
115                flat_graph: self.flat_graph,
116                uses: self.uses,
117                diagnostics: self.diagnostics,
118            })
119        }
120    }
121
122    /// Adds all [`DfirStatement`]s within the [`DfirCode`] to this [`DfirGraph`].
123    ///
124    /// Optional configuration:
125    /// * In the given loop context `current_loop`.
126    /// * With the given operator tag `operator_tag`.
127    pub fn add_dfir(
128        &mut self,
129        dfir: DfirCode,
130        current_loop: Option<GraphLoopId>,
131        operator_tag: Option<&str>,
132    ) {
133        for stmt in dfir.statements {
134            self.add_statement_internal(stmt, current_loop, operator_tag);
135        }
136    }
137
138    /// Add a single [`DfirStatement`] line to this [`DfirGraph`] in the root context.
139    pub fn add_statement(&mut self, stmt: DfirStatement) {
140        self.add_statement_internal(stmt, None, None);
141    }
142
143    /// Add a single [`DfirStatement`] line to this [`DfirGraph`] with given configuration.
144    ///
145    /// Optional configuration:
146    /// * In the given loop context `current_loop`.
147    /// * With the given operator tag `operator_tag`.
148    fn add_statement_internal(
149        &mut self,
150        stmt: DfirStatement,
151        current_loop: Option<GraphLoopId>,
152        operator_tag: Option<&str>,
153    ) {
154        match stmt {
155            DfirStatement::Use(yuse) => {
156                self.uses.push(yuse);
157            }
158            DfirStatement::Named(named) => {
159                let stmt_span = named.span();
160                let ends = self.add_pipeline(
161                    named.pipeline,
162                    Some(&named.name),
163                    current_loop,
164                    operator_tag,
165                );
166                self.assign_varname_checked(named.name, stmt_span, ends);
167            }
168            DfirStatement::Pipeline(pipeline_stmt) => {
169                let ends =
170                    self.add_pipeline(pipeline_stmt.pipeline, None, current_loop, operator_tag);
171                Self::helper_check_unused_port(&mut self.diagnostics, &ends, true);
172                Self::helper_check_unused_port(&mut self.diagnostics, &ends, false);
173            }
174            DfirStatement::Loop(loop_statement) => {
175                let inner_loop = self.flat_graph.insert_loop(current_loop);
176                for stmt in loop_statement.statements {
177                    self.add_statement_internal(stmt, Some(inner_loop), operator_tag);
178                }
179            }
180        }
181    }
182
183    /// Programatically add an pipeline, optionally adding `pred_name` as a single predecessor and
184    /// assigning it all to `asgn_name`.
185    ///
186    /// In DFIR syntax, equivalent to [`Self::add_statement`] of (if all names are supplied):
187    /// ```text
188    /// #asgn_name = #pred_name -> #pipeline;
189    /// ```
190    ///
191    /// But with, optionally:
192    /// * A `current_loop` to put the operator in.
193    /// * An `operator_tag` to tag the operator with, for debugging/tracing.
194    pub fn append_assign_pipeline(
195        &mut self,
196        asgn_name: Option<&Ident>,
197        pred_name: Option<&Ident>,
198        pipeline: Pipeline,
199        current_loop: Option<GraphLoopId>,
200        operator_tag: Option<&str>,
201    ) {
202        let span = pipeline.span();
203        let mut ends = self.add_pipeline(pipeline, asgn_name, current_loop, operator_tag);
204
205        // Connect `pred_name` if supplied.
206        if let Some(pred_name) = pred_name {
207            if let Some(pred_varname_info) = self.varname_ends.get(pred_name) {
208                // Update ends for `asgn_name`.
209                ends = self.connect_ends(pred_varname_info.ends.clone(), ends);
210            } else {
211                self.diagnostics.push(Diagnostic::spanned(
212                    pred_name.span(),
213                    Level::Error,
214                    format!(
215                        "Cannot find referenced name `{}`; name was never assigned.",
216                        pred_name
217                    ),
218                ));
219            }
220        }
221
222        // Assign `asgn_name` if supplied.
223        if let Some(asgn_name) = asgn_name {
224            self.assign_varname_checked(asgn_name.clone(), span, ends);
225        }
226    }
227}
228
229/// Internal methods.
230impl FlatGraphBuilder {
231    /// Assign a variable name to a pipeline, checking for conflicts.
232    fn assign_varname_checked(&mut self, name: Ident, stmt_span: Span, ends: Ends) {
233        match self.varname_ends.entry(name) {
234            Entry::Vacant(vacant_entry) => {
235                vacant_entry.insert(VarnameInfo::new(ends));
236            }
237            Entry::Occupied(occupied_entry) => {
238                let prev_conflict = occupied_entry.key();
239                self.diagnostics.push(Diagnostic::spanned(
240                    prev_conflict.span(),
241                    Level::Error,
242                    format!(
243                        "Existing assignment to `{}` conflicts with later assignment: {} (1/2)",
244                        prev_conflict,
245                        PrettySpan(stmt_span),
246                    ),
247                ));
248                self.diagnostics.push(Diagnostic::spanned(
249                    stmt_span,
250                    Level::Error,
251                    format!(
252                        "Name assignment to `{}` conflicts with existing assignment: {} (2/2)",
253                        prev_conflict,
254                        PrettySpan(prev_conflict.span())
255                    ),
256                ));
257            }
258        }
259    }
260
261    /// Helper: Add a pipeline, i.e. `a -> b -> c`. Return the input and output [`Ends`] for it.
262    fn add_pipeline(
263        &mut self,
264        pipeline: Pipeline,
265        current_varname: Option<&Ident>,
266        current_loop: Option<GraphLoopId>,
267        operator_tag: Option<&str>,
268    ) -> Ends {
269        match pipeline {
270            Pipeline::Paren(ported_pipeline_paren) => {
271                let (inn_port, pipeline_paren, out_port) =
272                    PortIndexValue::from_ported(ported_pipeline_paren);
273                let og_ends = self.add_pipeline(
274                    *pipeline_paren.pipeline,
275                    current_varname,
276                    current_loop,
277                    operator_tag,
278                );
279                Self::helper_combine_ends(&mut self.diagnostics, og_ends, inn_port, out_port)
280            }
281            Pipeline::Name(pipeline_name) => {
282                let (inn_port, ident, out_port) = PortIndexValue::from_ported(pipeline_name);
283
284                // Mingwei: We could lookup non-forward references immediately, but easier to just
285                // have one consistent code path: `GraphDet::Undetermined`.
286                Ends {
287                    inn: Some((inn_port, GraphDet::Undetermined(ident.clone()))),
288                    out: Some((out_port, GraphDet::Undetermined(ident))),
289                }
290            }
291            Pipeline::ModuleBoundary(pipeline_name) => {
292                let Some((input_node, output_node)) = self.module_boundary_nodes else {
293                    self.diagnostics.push(
294                        Error::new(
295                            pipeline_name.span(),
296                            "`mod` is only usable inside of a module.",
297                        )
298                        .into(),
299                    );
300
301                    return Ends {
302                        inn: None,
303                        out: None,
304                    };
305                };
306
307                let (inn_port, _, out_port) = PortIndexValue::from_ported(pipeline_name);
308
309                Ends {
310                    inn: Some((inn_port, GraphDet::Determined(output_node))),
311                    out: Some((out_port, GraphDet::Determined(input_node))),
312                }
313            }
314            Pipeline::Link(pipeline_link) => {
315                // Add the nested LHS and RHS of this link.
316                let lhs_ends = self.add_pipeline(
317                    *pipeline_link.lhs,
318                    current_varname,
319                    current_loop,
320                    operator_tag,
321                );
322                let rhs_ends = self.add_pipeline(
323                    *pipeline_link.rhs,
324                    current_varname,
325                    current_loop,
326                    operator_tag,
327                );
328
329                self.connect_ends(lhs_ends, rhs_ends)
330            }
331            Pipeline::Operator(operator) => {
332                let op_span = Some(operator.span());
333                let (node_id, ends) =
334                    self.add_operator(current_varname, current_loop, operator, op_span);
335                if let Some(operator_tag) = operator_tag {
336                    self.flat_graph
337                        .set_operator_tag(node_id, operator_tag.to_owned());
338                }
339                ends
340            }
341        }
342    }
343
344    /// Connects two [`Ends`] together. Returns the outer [`Ends`] for the connection.
345    ///
346    /// Links the inner ends together by adding it to `self.links`.
347    fn connect_ends(&mut self, lhs_ends: Ends, rhs_ends: Ends) -> Ends {
348        // Outer (first and last) ends.
349        let outer_ends = Ends {
350            inn: lhs_ends.inn,
351            out: rhs_ends.out,
352        };
353        // Inner (link) ends.
354        let link_ends = Ends {
355            out: lhs_ends.out,
356            inn: rhs_ends.inn,
357        };
358        self.links.push(link_ends);
359        outer_ends
360    }
361
362    /// Adds an operator to the graph, returning its [`GraphNodeId`] the input and output [`Ends`] for it.
363    fn add_operator(
364        &mut self,
365        current_varname: Option<&Ident>,
366        current_loop: Option<GraphLoopId>,
367        operator: Operator,
368        op_span: Option<Span>,
369    ) -> (GraphNodeId, Ends) {
370        let node_id = self.flat_graph.insert_node(
371            GraphNode::Operator(operator),
372            current_varname.cloned(),
373            current_loop,
374        );
375        let ends = Ends {
376            inn: Some((
377                PortIndexValue::Elided(op_span),
378                GraphDet::Determined(node_id),
379            )),
380            out: Some((
381                PortIndexValue::Elided(op_span),
382                GraphDet::Determined(node_id),
383            )),
384        };
385        (node_id, ends)
386    }
387
    /// Connects operator links as a final building step. Processes all the links stored in
    /// `self.links` and actually puts them into the graph.
    ///
    /// Runs in two passes:
    /// 1. Each queued `out -> inn` link has both ends resolved through variable names (which
    ///    may be forward references). If both sides resolve, the link becomes a graph edge.
    /// 2. For each operator node, the names in `singletons_referenced` are resolved to the node
    ///    producing them and stored via `set_node_singleton_references`.
    ///
    /// Problems are reported into `self.diagnostics` rather than returned.
    fn finalize_connect_operator_links(&mut self) {
        // `->` edges
        // Drain the queued links; `self.links` is left empty.
        for Ends { out, inn } in std::mem::take(&mut self.links) {
            let out_opt = Self::helper_resolve_name(
                &mut self.varname_ends,
                out,
                false,
                &mut self.diagnostics,
            );
            let inn_opt =
                Self::helper_resolve_name(&mut self.varname_ends, inn, true, &mut self.diagnostics);
            // `None` already have errors in `self.diagnostics`, so the link is simply skipped.
            if let (Some((out_port, out_node)), Some((inn_port, inn_node))) = (out_opt, inn_opt) {
                // The new edge's id is not needed here.
                let _ = self.finalize_connect_operators(out_port, out_node, inn_port, inn_node);
            }
        }

        // Resolve the singleton references for each node.
        // Node ids are collected up front because the graph is mutated inside the loop.
        for node_id in self.flat_graph.node_ids().collect::<Vec<_>>() {
            if let GraphNode::Operator(operator) = self.flat_graph.node(node_id) {
                let singletons_referenced = operator
                    .singletons_referenced
                    .iter()
                    .map(|singleton_ref| {
                        // Start from the referenced name's output end. Names already flagged as
                        // part of an illegal cycle yield `None` here and are not followed.
                        let port_det = self
                            .varname_ends
                            .get(&singleton_ref.ident)
                            .filter(|varname_info| !varname_info.illegal_cycle)
                            .map(|varname_info| &varname_info.ends)
                            .and_then(|ends| ends.out.as_ref())
                            .cloned();
                        let resolved_node_id = if let Some((_port, node_id)) =
                            Self::helper_resolve_name(
                                &mut self.varname_ends,
                                port_det,
                                false,
                                &mut self.diagnostics,
                            ) {
                            Some(node_id)
                        } else {
                            // NOTE(review): this "never assigned" message is also emitted when the
                            // name *was* assigned but is cyclic or could not be resolved. In those
                            // cases `helper_resolve_name` may already have reported an error, so
                            // this can be misleading or duplicated. Consider distinguishing the cases.
                            self.diagnostics.push(Diagnostic::spanned(
                                singleton_ref.span(),
                                Level::Error,
                                format!(
                                    "Cannot find referenced name `{}`; name was never assigned.",
                                    singleton_ref.ident
                                ),
                            ));
                            None
                        };
                        ResolvedSingletonRef {
                            node_id: resolved_node_id,
                            is_mut: singleton_ref.token_mut.is_some(),
                            // Optional explicit access group, parsed as a `u32`. A literal that
                            // fails to parse is reported as an error and treated as no group.
                            access_group: singleton_ref.access_group.as_ref().and_then(
                                |(_, lit_int)| match lit_int.base10_parse::<u32>() {
                                    Ok(n) => Some(n),
                                    Err(e) => {
                                        self.diagnostics.push(Diagnostic::spanned(
                                            lit_int.span(),
                                            Level::Error,
                                            format!("Access group is not a valid `u32`: {}", e),
                                        ));
                                        None
                                    }
                                },
                            ),
                        }
                    })
                    .collect();

                self.flat_graph
                    .set_node_singleton_references(node_id, singletons_referenced);
            }
        }
    }
465
    /// Recursively resolve a variable name. For handling forward (and backward) name references
    /// after all names have been assigned.
    /// Returns `None` if the name is not resolvable, either because it was never assigned or
    /// because it contains a self-referential cycle.
    ///
    /// `is_in` set to `true` means the _input_ side will be returned. `false` means the _output_ side will be returned.
    ///
    /// Resolution is iterative: each step follows one name to the corresponding end of the
    /// pipeline assigned to it. Each followed name's `inn_used` / `out_used` flag is set
    /// (depending on `is_in`). A `None` end resolves to `None` without adding a new diagnostic.
    fn helper_resolve_name(
        varname_ends: &mut BTreeMap<Ident, VarnameInfo>,
        mut port_det: Option<(PortIndexValue, GraphDet)>,
        is_in: bool,
        diagnostics: &mut Diagnostics,
    ) -> Option<(PortIndexValue, GraphNodeId)> {
        // Upper bound on name hops. Cycles should be caught by the `names` check below first,
        // so hitting this indicates a bug or an extremely long chain of names.
        const BACKUP_RECURSION_LIMIT: usize = 1024;

        // Names followed so far on this resolution path, for cycle detection and error output.
        let mut names = Vec::new();
        for _ in 0..BACKUP_RECURSION_LIMIT {
            match port_det? {
                (port, GraphDet::Determined(node_id)) => {
                    return Some((port, node_id));
                }
                (port, GraphDet::Undetermined(ident)) => {
                    let Some(varname_info) = varname_ends.get_mut(&ident) else {
                        diagnostics.push(Diagnostic::spanned(
                            ident.span(),
                            Level::Error,
                            format!("Cannot find name `{}`; name was never assigned.", ident),
                        ));
                        return None;
                    };
                    // Check for a self-referential cycle.
                    let cycle_found = names.contains(&ident);
                    if !cycle_found {
                        names.push(ident);
                    };
                    // Also fail if this name was already flagged as cyclic by an earlier
                    // resolution.
                    if cycle_found || varname_info.illegal_cycle {
                        let len = names.len();
                        for (i, name) in names.into_iter().enumerate() {
                            diagnostics.push(Diagnostic::spanned(
                                name.span(),
                                Level::Error,
                                format!(
                                    "Name `{}` forms or references an illegal self-referential cycle ({}/{}).",
                                    name,
                                    i + 1,
                                    len
                                ),
                            ));
                            // Flag every name on this path so future references to it fail with
                            // a cycle error instead of being followed again. The `unwrap` cannot
                            // fail: each name was found in `varname_ends` before being pushed.
                            varname_ends.get_mut(&name).unwrap().illegal_cycle = true;
                        }
                        return None;
                    }

                    // No self-cycle.
                    let prev = if is_in {
                        varname_info.inn_used = true;
                        &varname_info.ends.inn
                    } else {
                        varname_info.out_used = true;
                        &varname_info.ends.out
                    };
                    // Combine the port given at this reference with the port on the name's
                    // assigned end via `helper_combine_end`, then keep resolving the result.
                    port_det = Self::helper_combine_end(
                        diagnostics,
                        prev.clone(),
                        port,
                        if is_in { "input" } else { "output" },
                    );
                }
            }
        }
        diagnostics.push(Diagnostic::spanned(
            Span::call_site(),
            Level::Error,
            format!(
                "Reached the recursion limit {} while resolving names. This is either a dfir bug or you have an absurdly long chain of names: `{}`.",
                BACKUP_RECURSION_LIMIT,
                names.iter().map(ToString::to_string).collect::<Vec<_>>().join("` -> `"),
            )
        ));
        None
    }
548
549    /// Connect two operators on the given port indexes.
550    fn finalize_connect_operators(
551        &mut self,
552        src_port: PortIndexValue,
553        src: GraphNodeId,
554        dst_port: PortIndexValue,
555        dst: GraphNodeId,
556    ) -> GraphEdgeId {
557        {
558            /// Helper to emit conflicts when a port is used twice.
559            fn emit_conflict(
560                inout: &str,
561                old: &PortIndexValue,
562                new: &PortIndexValue,
563                diagnostics: &mut Diagnostics,
564            ) {
565                // TODO(mingwei): Use `MultiSpan` once `proc_macro2` supports it.
566                diagnostics.push(Diagnostic::spanned(
567                    old.span(),
568                    Level::Error,
569                    format!(
570                        "{} connection conflicts with below ({}) (1/2)",
571                        inout,
572                        PrettySpan(new.span()),
573                    ),
574                ));
575                diagnostics.push(Diagnostic::spanned(
576                    new.span(),
577                    Level::Error,
578                    format!(
579                        "{} connection conflicts with above ({}) (2/2)",
580                        inout,
581                        PrettySpan(old.span()),
582                    ),
583                ));
584            }
585
586            // Handle src's successor port conflicts:
587            if src_port.is_specified() {
588                for conflicting_port in self
589                    .flat_graph
590                    .node_successor_edges(src)
591                    .map(|edge_id| self.flat_graph.edge_ports(edge_id).0)
592                    .filter(|&port| port == &src_port)
593                {
594                    emit_conflict("Output", conflicting_port, &src_port, &mut self.diagnostics);
595                }
596            }
597
598            // Handle dst's predecessor port conflicts:
599            if dst_port.is_specified() {
600                for conflicting_port in self
601                    .flat_graph
602                    .node_predecessor_edges(dst)
603                    .map(|edge_id| self.flat_graph.edge_ports(edge_id).1)
604                    .filter(|&port| port == &dst_port)
605                {
606                    emit_conflict("Input", conflicting_port, &dst_port, &mut self.diagnostics);
607                }
608            }
609        }
610        self.flat_graph.insert_edge(src, src_port, dst, dst_port)
611    }
612
    /// Process operators and emit operator errors.
    ///
    /// Operator instances are created first because the following checks (e.g.
    /// `check_operator_errors`) read each node's operator instance. Nodes without one are
    /// skipped there, since the failure was already reported.
    fn process_operator_errors(&mut self) {
        self.make_operator_instances();
        self.check_operator_errors();
        self.warn_unused_port_indexing();
        self.check_loop_errors();
    }
620
621    /// Make `OperatorInstance`s for each operator node.
622    fn make_operator_instances(&mut self) {
623        self.flat_graph
624            .insert_node_op_insts_all(&mut self.diagnostics);
625    }
626
627    /// Validates that operators have valid number of inputs, outputs, & arguments.
628    /// Adds errors (and warnings) to `self.diagnostics`.
629    fn check_operator_errors(&mut self) {
630        /// Returns true if an error was found.
631        fn emit_arity_error(
632            op_span: Span,
633            op_name: &str,
634            is_in: bool,
635            is_hard: bool,
636            degree: usize,
637            range: &dyn RangeTrait<usize>,
638            diagnostics: &mut Diagnostics,
639        ) -> bool {
640            let message = format!(
641                "`{}` {} have {} {}, actually has {}.",
642                op_name,
643                if is_hard { "must" } else { "should" },
644                range.human_string(),
645                if is_in { "input(s)" } else { "output(s)" },
646                degree,
647            );
648            let out_of_range = !range.contains(&degree);
649            if out_of_range {
650                diagnostics.push(Diagnostic::spanned(
651                    op_span,
652                    if is_hard {
653                        Level::Error
654                    } else {
655                        Level::Warning
656                    },
657                    message,
658                ));
659            }
660            out_of_range
661        }
662
663        for (node_id, node) in self.flat_graph.nodes() {
664            match node {
665                GraphNode::Operator(operator) => {
666                    let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
667                        // Error already emitted by `insert_node_op_insts_all`.
668                        continue;
669                    };
670                    let op_constraints = op_inst.op_constraints;
671                    let op_name = operator.name_string();
672
673                    // Check number of args
674                    if op_constraints.num_args != operator.args.len() {
675                        self.diagnostics.push(Diagnostic::spanned(
676                            operator.span(),
677                            Level::Error,
678                            format!(
679                                "`{}` expects {} argument(s), received {}.",
680                                op_name,
681                                op_constraints.num_args,
682                                operator.args.len()
683                            ),
684                        ));
685                    }
686
687                    // Check input/output (port) arity
688                    let inn_degree = self.flat_graph.node_degree_in(node_id);
689                    let _ = emit_arity_error(
690                        operator.span(),
691                        &op_name,
692                        true,
693                        true,
694                        inn_degree,
695                        op_constraints.hard_range_inn,
696                        &mut self.diagnostics,
697                    ) || emit_arity_error(
698                        operator.span(),
699                        &op_name,
700                        true,
701                        false,
702                        inn_degree,
703                        op_constraints.soft_range_inn,
704                        &mut self.diagnostics,
705                    );
706
707                    let out_degree = self.flat_graph.node_degree_out(node_id);
708                    let _ = emit_arity_error(
709                        operator.span(),
710                        &op_name,
711                        false,
712                        true,
713                        out_degree,
714                        op_constraints.hard_range_out,
715                        &mut self.diagnostics,
716                    ) || emit_arity_error(
717                        operator.span(),
718                        &op_name,
719                        false,
720                        false,
721                        out_degree,
722                        op_constraints.soft_range_out,
723                        &mut self.diagnostics,
724                    );
725
                    // Checks the port names used on ONE side (input or output) of an operator
                    // against the operator's declared port list, pushing error diagnostics.
                    //
                    // - `op_span`: operator span, used for the "missing ports" error.
                    // - `op_name`: operator name, interpolated into messages.
                    // - `expected_ports_fn`: declared port spec for this side. If `None`, or if it
                    //   does not produce `PortListSpec::Fixed`, nothing is checked here.
                    // - `actual_ports_iter`: the port values actually used on this side's edges.
                    // - `input_output`: side label interpolated into messages (callers pass
                    //   `"input"` / `"output"`).
                    //
                    // Emits one error per used port that is not expected (at that port's span),
                    // plus a single error at `op_span` listing expected ports that were never used.
                    fn emit_port_error<'a>(
                        op_span: Span,
                        op_name: &str,
                        expected_ports_fn: Option<fn() -> PortListSpec>,
                        actual_ports_iter: impl Iterator<Item = &'a PortIndexValue>,
                        input_output: &'static str,
                        diagnostics: &mut Diagnostics,
                    ) {
                        let Some(expected_ports_fn) = expected_ports_fn else {
                            return;
                        };
                        let PortListSpec::Fixed(expected_ports) = (expected_ports_fn)() else {
                            // Separate check inside of `demux` special case.
                            return;
                        };
                        // Materialize once: iterated per actual port below, then consumed when
                        // computing the missing ports.
                        let expected_ports: Vec<_> = expected_ports.into_iter().collect();

                        // Reject unexpected ports.
                        // `ports` is a `BTreeSet<&PortIndexValue>` of the distinct ports in use.
                        let ports: BTreeSet<_> = actual_ports_iter
                            // Use `inspect` before collecting into `BTreeSet` to ensure we get
                            // both error messages on duplicated port names.
                            .inspect(|actual_port_iv| {
                                // For each actually used port `port_index_value`, check if it is expected.
                                let is_expected = expected_ports.iter().any(|port_index| {
                                    actual_port_iv == &&port_index.clone().into()
                                });
                                // If it is not expected, emit a diagnostic error.
                                if !is_expected {
                                    diagnostics.push(Diagnostic::spanned(
                                        actual_port_iv.span(),
                                        Level::Error,
                                        format!(
                                            "`{}` received unexpected {} port: {}. Expected one of: `{}`",
                                            op_name,
                                            input_output,
                                            actual_port_iv.as_error_message_string(),
                                            Itertools::intersperse(
                                                expected_ports
                                                    .iter()
                                                    .map(|port| port.to_token_stream().to_string())
                                                    .map(Cow::Owned),
                                                Cow::Borrowed("`, `"),
                                            ).collect::<String>()
                                        ),
                                    ))
                                }
                            })
                            .collect();

                        // List missing expected ports.
                        let missing: Vec<_> = expected_ports
                            .into_iter()
                            .filter_map(|expected_port| {
                                // Render tokens first: `.into()` below consumes `expected_port`.
                                let tokens = expected_port.to_token_stream();
                                // `&&`: the set stores `&PortIndexValue`, so look up by reference.
                                if !ports.contains(&&expected_port.into()) {
                                    Some(tokens)
                                } else {
                                    None
                                }
                            })
                            .collect();
                        if !missing.is_empty() {
                            diagnostics.push(Diagnostic::spanned(
                                op_span,
                                Level::Error,
                                format!(
                                    "`{}` missing expected {} port(s): `{}`.",
                                    op_name,
                                    input_output,
                                    Itertools::intersperse(
                                        missing.into_iter().map(|port| Cow::Owned(
                                            port.to_token_stream().to_string()
                                        )),
                                        Cow::Borrowed("`, `")
                                    )
                                    .collect::<String>()
                                ),
                            ));
                        }
                    }
806
807                    emit_port_error(
808                        operator.span(),
809                        &op_name,
810                        op_constraints.ports_inn,
811                        self.flat_graph
812                            .node_predecessor_edges(node_id)
813                            .map(|edge_id| self.flat_graph.edge_ports(edge_id).1),
814                        "input",
815                        &mut self.diagnostics,
816                    );
817                    emit_port_error(
818                        operator.span(),
819                        &op_name,
820                        op_constraints.ports_out,
821                        self.flat_graph
822                            .node_successor_edges(node_id)
823                            .map(|edge_id| self.flat_graph.edge_ports(edge_id).0),
824                        "output",
825                        &mut self.diagnostics,
826                    );
827
828                    // Check that singleton references actually reference valid targets.
829                    {
830                        let singletons_resolved =
831                            self.flat_graph.node_singleton_references(node_id);
832                        for (resolved_ref, singleton_ref_token) in singletons_resolved
833                            .iter()
834                            .zip_eq(&*operator.singletons_referenced)
835                        {
836                            let Some(singleton_node_id) = resolved_ref.node_id else {
837                                // Error already emitted by `connect_operator_links`, "Cannot find referenced name...".
838                                continue;
839                            };
840                            // Handoff nodes are valid reference targets.
841                            if matches!(
842                                self.flat_graph.node(singleton_node_id),
843                                GraphNode::Handoff { .. },
844                            ) {
845                                continue;
846                            }
847                            let Some(ref_op_inst) = self.flat_graph.node_op_inst(singleton_node_id)
848                            else {
849                                // Error already emitted by `insert_node_op_insts_all`.
850                                continue;
851                            };
852                            let ref_op_constraints = ref_op_inst.op_constraints;
853                            self.diagnostics.push(Diagnostic::spanned(
854                                singleton_ref_token.span(),
855                                Level::Error,
856                                format!(
857                                    "Cannot reference operator `{}`. Use `singleton()`, `optional()`, or `handoff()` to create a referenceable name.",
858                                    ref_op_constraints.name,
859                                ),
860                            ));
861                        }
862                    }
863                }
864                GraphNode::Handoff { kind, src_span, .. } => {
865                    // Validate arity: handoff must have exactly 1 input and 1 output.
866                    let op_name = match kind {
867                        HandoffKind::Vec => "handoff",
868                        HandoffKind::Singleton => "singleton",
869                        HandoffKind::Optional => "optional",
870                    };
871                    let inn_degree = self.flat_graph.node_degree_in(node_id);
872                    emit_arity_error(
873                        *src_span,
874                        op_name,
875                        true,
876                        true,
877                        inn_degree,
878                        &(1..=1),
879                        &mut self.diagnostics,
880                    );
881                    let out_degree = self.flat_graph.node_degree_out(node_id);
882                    emit_arity_error(
883                        *src_span,
884                        op_name,
885                        false,
886                        true,
887                        out_degree,
888                        &(0..=1), // Handoffs may be no-output, for use only by ref.
889                        &mut self.diagnostics,
890                    );
891                }
892                GraphNode::ModuleBoundary { .. } => {
893                    // Module boundaries don't require any checking.
894                }
895            }
896        }
897
898        // Validate singleton references.
899        // All singleton references must have unambiguous group orderings.
900        // Rules:
901        // 1. If any singleton reference has an explicit group number, they all must have one.
902        // 2. Every `#mut` must be in its own group.
903        {
904            let refs_by_target = self.flat_graph.node_singleton_reference_groups();
905            // For each singleton, check the groups.
906            for (_singleton, groups) in refs_by_target {
907                // Rule 1. If any singleton reference has an explicit group number, they all must have one.
908                if 1 < groups.len()
909                    && let Some(ungrouped) = groups.get(&None)
910                {
911                    for &(_src_node, r, span) in ungrouped {
912                        self.diagnostics.push(Diagnostic::spanned(
913                            span,
914                            Level::Error,
915                            format!(
916                                "Must use an explicit group `#{{N}}{}` to reference a singleton when other references use explicit groups.",
917                                if r.is_mut { " mut" } else { "" },
918                            ),
919                        ));
920                    }
921                }
922                // Rule 2. Every `#mut` must be in its own group.
923                for (group_idx, group) in groups {
924                    if 1 < group.len() && group.iter().any(|(_, r, _)| r.is_mut) {
925                        let group_str = if let Some(n) = group_idx {
926                            format!("`#{{{}}}`", n)
927                        } else {
928                            "<default>".to_owned()
929                        };
930                        for (_src_node, _mut_r, span) in
931                            group.into_iter().filter(|(_, r, _)| r.is_mut)
932                        {
933                            self.diagnostics.push(Diagnostic::spanned(
934                                span,
935                                Level::Error,
936                                format!("Mutable singleton references must be the only one in their access group, but group {} has multiple.", group_str),
937                            ));
938                        }
939                    }
940                }
941            }
942        }
943    }
944
945    /// Warns about unused port indexing referenced in [`Self::varname_ends`].
946    /// https://github.com/hydro-project/hydro/issues/1108
947    fn warn_unused_port_indexing(&mut self) {
948        for (_ident, varname_info) in self.varname_ends.iter() {
949            if !varname_info.inn_used {
950                Self::helper_check_unused_port(&mut self.diagnostics, &varname_info.ends, true);
951            }
952            if !varname_info.out_used {
953                Self::helper_check_unused_port(&mut self.diagnostics, &varname_info.ends, false);
954            }
955        }
956    }
957
958    /// Emit a warning to `diagnostics` for an unused port (i.e. if the port is specified for
959    /// reason).
960    fn helper_check_unused_port(diagnostics: &mut Diagnostics, ends: &Ends, is_in: bool) {
961        let port = if is_in { &ends.inn } else { &ends.out };
962        if let Some((port, _)) = port
963            && port.is_specified()
964        {
965            diagnostics.push(Diagnostic::spanned(
966                port.span(),
967                Level::Error,
968                format!(
969                    "{} port index is unused. (Is the port on the correct side?)",
970                    if is_in { "Input" } else { "Output" },
971                ),
972            ));
973        }
974    }
975
976    /// Helper function.
977    /// Combine the port indexing information for indexing wrapped around a name.
978    /// Because the name may already have indexing, this may introduce double indexing (i.e. `[0][0]my_var[0][0]`)
979    /// which would be an error.
980    fn helper_combine_ends(
981        diagnostics: &mut Diagnostics,
982        og_ends: Ends,
983        inn_port: PortIndexValue,
984        out_port: PortIndexValue,
985    ) -> Ends {
986        Ends {
987            inn: Self::helper_combine_end(diagnostics, og_ends.inn, inn_port, "input"),
988            out: Self::helper_combine_end(diagnostics, og_ends.out, out_port, "output"),
989        }
990    }
991
992    /// Helper function.
993    /// Combine the port indexing info for one input or output.
994    fn helper_combine_end(
995        diagnostics: &mut Diagnostics,
996        og: Option<(PortIndexValue, GraphDet)>,
997        other: PortIndexValue,
998        input_output: &'static str,
999    ) -> Option<(PortIndexValue, GraphDet)> {
1000        // TODO(mingwei): minification pass over this code?
1001
1002        let other_span = other.span();
1003
1004        let (og_port, og_node) = og?;
1005        match og_port.combine(other) {
1006            Ok(combined_port) => Some((combined_port, og_node)),
1007            Err(og_port) => {
1008                // TODO(mingwei): Use `MultiSpan` once `proc_macro2` supports it.
1009                diagnostics.push(Diagnostic::spanned(
1010                    og_port.span(),
1011                    Level::Error,
1012                    format!(
1013                        "Indexing on {} is overwritten below ({}) (1/2).",
1014                        input_output,
1015                        PrettySpan(other_span),
1016                    ),
1017                ));
1018                diagnostics.push(Diagnostic::spanned(
1019                    other_span,
1020                    Level::Error,
1021                    format!(
1022                        "Cannot index on already-indexed {}, previously indexed above ({}) (2/2).",
1023                        input_output,
1024                        PrettySpan(og_port.span()),
1025                    ),
1026                ));
1027                // When errored, just use original and ignore OTHER port to minimize
1028                // noisy/extra diagnostics.
1029                Some((og_port, og_node))
1030            }
1031        }
1032    }
1033
1034    /// Check for loop context-related errors.
1035    fn check_loop_errors(&mut self) {
1036        for (node_id, node) in self.flat_graph.nodes() {
1037            let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
1038                continue;
1039            };
1040            let loop_opt = self.flat_graph.node_loop(node_id);
1041
1042            // Ensure no `'tick` or `'static` persistences are used WITHIN a loop context.
1043            // Ensure no `'loop` persistences are used OUTSIDE a loop context.
1044            for persistence in &op_inst.generics.persistence_args {
1045                let span = op_inst.generics.generic_args.span();
1046                match (loop_opt, persistence) {
1047                    (Some(_loop_id), p @ (Persistence::Tick | Persistence::Static)) => {
1048                        self.diagnostics.push(Diagnostic::spanned(
1049                            span,
1050                            Level::Error,
1051                            format!(
1052                                "Operator uses `'{}` persistence, which is not allowed within a `loop {{ ... }}` context.",
1053                                p.to_str_lowercase(),
1054                            ),
1055                        ));
1056                    }
1057                    (None, p @ (Persistence::None | Persistence::Loop)) => {
1058                        self.diagnostics.push(Diagnostic::spanned(
1059                            span,
1060                            Level::Error,
1061                            format!(
1062                                "Operator uses `'{}` persistence, but is not within a `loop {{ ... }}` context.",
1063                                p.to_str_lowercase(),
1064                            ),
1065                        ));
1066                    }
1067                    _ => {}
1068                }
1069            }
1070
1071            // All inputs must be declared in the root block.
1072            if let (Some(_loop_id), Some(FloType::Source)) =
1073                (loop_opt, op_inst.op_constraints.flo_type)
1074            {
1075                self.diagnostics.push(Diagnostic::spanned(
1076                    node.span(),
1077                    Level::Error,
1078                    format!(
1079                        "Source operator `{}(...)` must be at the root level, not within any `loop {{ ... }}` contexts.",
1080                        op_inst.op_constraints.name
1081                    )
1082                ));
1083            }
1084        }
1085
1086        // Check windowing and un-windowing operators, for loop inputs and outputs respectively.
1087        for (_edge_id, (pred_id, node_id)) in self.flat_graph.edges() {
1088            let Some(op_inst) = self.flat_graph.node_op_inst(node_id) else {
1089                continue;
1090            };
1091            let flo_type = &op_inst.op_constraints.flo_type;
1092
1093            let pred_loop_id = self.flat_graph.node_loop(pred_id);
1094            let loop_id = self.flat_graph.node_loop(node_id);
1095
1096            let span = self.flat_graph.node(node_id).span();
1097
1098            let (is_input, is_output) = {
1099                let parent_pred_loop_id =
1100                    pred_loop_id.and_then(|lid| self.flat_graph.loop_parent(lid));
1101                let parent_loop_id = loop_id.and_then(|lid| self.flat_graph.loop_parent(lid));
1102                let is_same = pred_loop_id == loop_id;
1103                let is_input = !is_same && parent_loop_id == pred_loop_id;
1104                let is_output = !is_same && parent_pred_loop_id == loop_id;
1105                if !(is_input || is_output || is_same) {
1106                    self.diagnostics.push(Diagnostic::spanned(
1107                        span,
1108                        Level::Error,
1109                        "Operator input edge may not cross multiple loop contexts.",
1110                    ));
1111                    continue;
1112                }
1113                (is_input, is_output)
1114            };
1115
1116            match flo_type {
1117                None => {
1118                    if is_input {
1119                        self.diagnostics.push(Diagnostic::spanned(
1120                            span,
1121                            Level::Error,
1122                            format!(
1123                                "Operator `{}(...)` entering a loop context must be a windowing operator, but is not.",
1124                                op_inst.op_constraints.name
1125                            )
1126                        ));
1127                    }
1128                    if is_output {
1129                        self.diagnostics.push(Diagnostic::spanned(
1130                            span,
1131                            Level::Error,
1132                            format!(
1133                                "Operator `{}(...)` exiting a loop context must be an un-windowing operator, but is not.",
1134                                op_inst.op_constraints.name
1135                            )
1136                        ));
1137                    }
1138                }
1139                Some(FloType::Windowing) => {
1140                    if !is_input {
1141                        self.diagnostics.push(Diagnostic::spanned(
1142                            span,
1143                            Level::Error,
1144                            format!(
1145                                "Windowing operator `{}(...)` must be the first input operator into a `loop {{ ... }} context.",
1146                                op_inst.op_constraints.name
1147                            )
1148                        ));
1149                    }
1150                }
1151                Some(FloType::Unwindowing) => {
1152                    if !is_output {
1153                        self.diagnostics.push(Diagnostic::spanned(
1154                            span,
1155                            Level::Error,
1156                            format!(
1157                                "Un-windowing operator `{}(...)` must be the first output operator after exiting a `loop {{ ... }} context.",
1158                                op_inst.op_constraints.name
1159                            )
1160                        ));
1161                    }
1162                }
1163                Some(FloType::NextIteration) => {
1164                    // Must be in a loop context.
1165                    if loop_id.is_none() {
1166                        self.diagnostics.push(Diagnostic::spanned(
1167                            span,
1168                            Level::Error,
1169                            format!(
1170                                "Operator `{}(...)` must be within a `loop {{ ... }}` context.",
1171                                op_inst.op_constraints.name
1172                            ),
1173                        ));
1174                    }
1175                }
1176                Some(FloType::Source) => {
1177                    // Handled above.
1178                }
1179            }
1180        }
1181
1182        // Must be a DAG (excluding `next_iteration()` operators).
1183        // TODO(mingwei): Nested loop blocks should count as a single node.
1184        // But this doesn't cause any correctness issues because the nested loops are also DAGs.
1185        for (loop_id, loop_nodes) in self.flat_graph.loops() {
1186            // Filter out `next_iteration()` operators.
1187            let filter_next_iteration = |&node_id: &GraphNodeId| {
1188                self.flat_graph
1189                    .node_op_inst(node_id)
1190                    .map(|op_inst| Some(FloType::NextIteration) != op_inst.op_constraints.flo_type)
1191                    .unwrap_or(true)
1192            };
1193
1194            let topo_sort_result = graph_algorithms::topo_sort(
1195                loop_nodes.iter().copied().filter(filter_next_iteration),
1196                |dst| {
1197                    self.flat_graph
1198                        .node_predecessor_nodes(dst)
1199                        .filter(|&src| Some(loop_id) == self.flat_graph.node_loop(src))
1200                        .filter(filter_next_iteration)
1201                },
1202            );
1203            if let Err(cycle) = topo_sort_result {
1204                let len = cycle.len();
1205                for (i, node_id) in cycle.into_iter().enumerate() {
1206                    let span = self.flat_graph.node(node_id).span();
1207                    self.diagnostics.push(Diagnostic::spanned(
1208                        span,
1209                        Level::Error,
1210                        format!(
1211                            "Operator forms an illegal cycle within a `loop {{ ... }}` block. Use `{}()` to pass data across loop iterations. ({}/{})",
1212                            NEXT_ITERATION.name,
1213                            i + 1,
1214                            len,
1215                        ),
1216                    ));
1217                }
1218            }
1219        }
1220    }
1221}