1use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::{Expr, ExprPath, GenericArgument, Token, Type};
12
13use self::ops::{OperatorConstraints, Persistence};
14use crate::diagnostic::{Diagnostic, Diagnostics, Level};
15use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported};
16use crate::pretty_span::PrettySpan;
17
18mod di_mul_graph;
19mod eliminate_extra_unions_tees;
20mod flat_graph_builder;
21mod flat_to_partitioned;
22mod graph_write;
23mod meta_graph;
24mod meta_graph_debugging;
25
26use std::fmt::Display;
27
28pub use di_mul_graph::DiMulGraph;
29pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
30pub use flat_graph_builder::{FlatGraphBuilder, FlatGraphBuilderOutput};
31pub use flat_to_partitioned::partition_graph;
32pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
33
34pub use crate::graph_ids::{GraphEdgeId, GraphLoopId, GraphNodeId, GraphSubgraphId};
35
36pub mod graph_algorithms;
37pub mod ops;
38
39impl GraphSubgraphId {
40 pub fn as_ident(self, span: Span) -> Ident {
42 use slotmap::Key;
43 Ident::new(&format!("sgid_{:?}", self.data()), span)
44 }
45}
46
47impl GraphLoopId {
48 pub fn as_ident(self, span: Span) -> Ident {
50 use slotmap::Key;
51 Ident::new(&format!("loop_{:?}", self.data()), span)
52 }
53}
54
/// The identifier text `context` (presumably the name of the context variable in generated
/// code — confirm against its users).
const CONTEXT: &str = "context";
/// The identifier text `df` (presumably the name of the graph variable in generated code —
/// confirm against its users).
const GRAPH: &str = "df";

/// Label used for a `HandoffKind::Vec` handoff node by `GraphNode::to_pretty_string`/`to_name_string`.
const HANDOFF_NODE_STR: &str = "handoff";
/// Label used for a `HandoffKind::Option` handoff node by `GraphNode::to_pretty_string`/`to_name_string`.
const SINGLETON_SLOT_NODE_STR: &str = "singleton";
/// Label used for a `GraphNode::ModuleBoundary` node by `GraphNode::to_pretty_string`/`to_name_string`.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
63
64mod serde_syn {
65 use serde::{Deserialize, Deserializer, Serializer};
66
67 pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
68 where
69 S: Serializer,
70 T: quote::ToTokens,
71 {
72 serializer.serialize_str(&value.to_token_stream().to_string())
73 }
74
75 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
76 where
77 D: Deserializer<'de>,
78 T: syn::parse::Parse,
79 {
80 let s = String::deserialize(deserializer)?;
81 syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
82 }
83}
84
/// Newtype wrapper around an [`Ident`] used as a variable name.
///
/// (De)serialized as the identifier's token string via `serde_syn`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct Varname(#[serde(with = "serde_syn")] pub Ident);
90
/// The kind of a `GraphNode::Handoff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum HandoffKind {
    /// Labeled `"handoff"` when a node is rendered to a string.
    Vec,
    /// Labeled `"singleton"` when a node is rendered to a string.
    Option,
}
99
/// A node in the DFIR graph.
///
/// Spans are not serialized; on deserialization they are reset to `Span::call_site()`.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator, (de)serialized as its token string via `serde_syn`.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff node.
    Handoff {
        /// Which kind of handoff this is.
        kind: HandoffKind,
        /// Span on the source side of the handoff (skipped by serde).
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// Span on the destination side of the handoff (skipped by serde).
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// A module boundary node.
    ModuleBoundary {
        /// Whether this is the input side of the boundary (presumably `false` means the
        /// output side — confirm).
        input: bool,

        /// Span of the module's import expression (skipped by serde).
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
129impl GraphNode {
130 pub fn to_pretty_string(&self) -> Cow<'static, str> {
132 match self {
133 GraphNode::Operator(op) => op.to_pretty_string().into(),
134 GraphNode::Handoff {
135 kind: HandoffKind::Vec,
136 ..
137 } => HANDOFF_NODE_STR.into(),
138 GraphNode::Handoff {
139 kind: HandoffKind::Option,
140 ..
141 } => SINGLETON_SLOT_NODE_STR.into(),
142 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
143 }
144 }
145
146 pub fn to_name_string(&self) -> Cow<'static, str> {
148 match self {
149 GraphNode::Operator(op) => op.name_string().into(),
150 GraphNode::Handoff {
151 kind: HandoffKind::Vec,
152 ..
153 } => HANDOFF_NODE_STR.into(),
154 GraphNode::Handoff {
155 kind: HandoffKind::Option,
156 ..
157 } => SINGLETON_SLOT_NODE_STR.into(),
158 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
159 }
160 }
161
162 pub fn span(&self) -> Span {
164 match self {
165 Self::Operator(op) => op.span(),
166 &Self::Handoff {
167 src_span, dst_span, ..
168 } => src_span.join(dst_span).unwrap_or(src_span),
169 Self::ModuleBoundary { import_expr, .. } => *import_expr,
170 }
171 }
172}
173impl std::fmt::Debug for GraphNode {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 match self {
176 Self::Operator(operator) => {
177 write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
178 }
179 Self::Handoff { kind, .. } => write!(f, "Node::Handoff({kind:?})"),
180 Self::ModuleBoundary { input, .. } => {
181 write!(f, "Node::ModuleBoundary{{input: {}}}", input)
182 }
183 }
184 }
185}
186
/// A single operator occurrence in the graph, with its ports, generics and arguments.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// Constraints/definition shared by all instances of this operator kind.
    pub op_constraints: &'static OperatorConstraints,
    /// Port index values of this operator's inputs.
    pub input_ports: Vec<PortIndexValue>,
    /// Port index values of this operator's outputs.
    pub output_ports: Vec<PortIndexValue>,
    /// Identifiers of singletons referenced by this operator (presumably found in its
    /// arguments — confirm).
    pub singletons_referenced: Vec<Ident>,

    /// Parsed generic arguments (see `get_operator_generics`).
    pub generics: OpInstGenerics,
    /// Arguments parsed as a comma-separated list of expressions (the `_pre` suffix suggests a
    /// form before some later processing — confirm).
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// The raw, unparsed argument tokens.
    pub arguments_raw: TokenStream,
}
217
/// Generic arguments of an operator instance, both raw and parsed (see `get_operator_generics`).
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// The raw generic arguments as written, or `None` if the operator has none.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Persistence annotations parsed from the leading run of lifetime arguments.
    pub persistence_args: Vec<Persistence>,
    /// Type arguments immediately following the persistence arguments.
    pub type_args: Vec<Type>,
}
228
229impl OpInstGenerics {
230 fn join_spans<I>(mut spans: I) -> Option<Span>
235 where
236 I: Iterator<Item = Span>,
237 {
238 let mut span = spans.next()?;
239 for s in spans {
240 span = span.join(s)?;
241 }
242 Some(span)
243 }
244
245 pub fn persistence_args_span(&self) -> Option<Span> {
247 self.generic_args.as_ref().and_then(|args| {
248 Self::join_spans(
249 args.iter()
250 .filter(|a| matches!(a, GenericArgument::Lifetime(_)))
251 .map(|a| a.span()),
252 )
253 })
254 }
255
256 pub fn type_args_span(&self) -> Option<Span> {
258 self.generic_args.as_ref().and_then(|args| {
259 Self::join_spans(
260 args.iter()
261 .filter(|a| matches!(a, GenericArgument::Type(_)))
262 .map(|a| a.span()),
263 )
264 })
265 }
266}
267
268pub fn get_operator_generics(diagnostics: &mut Diagnostics, operator: &Operator) -> OpInstGenerics {
273 let generic_args = operator.type_arguments().cloned();
275 let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
276 GenericArgument::Lifetime(lifetime) => {
277 match &*lifetime.ident.to_string() {
278 "none" => Some(Persistence::None),
279 "loop" => Some(Persistence::Loop),
280 "tick" => Some(Persistence::Tick),
281 "static" => Some(Persistence::Static),
282 "mutable" => Some(Persistence::Mutable),
283 _ => {
284 diagnostics.push(Diagnostic::spanned(
285 generic_arg.span(),
286 Level::Error,
287 format!("Unknown lifetime generic argument `'{}`, expected `'none`, `'loop`, `'tick`, `'static`, or `'mutable`.", lifetime.ident),
288 ));
289 None
291 }
292 }
293 },
294 _ => None,
295 }).collect::<Vec<_>>();
296 let type_args = generic_args
297 .iter()
298 .flatten()
299 .skip(persistence_args.len())
300 .map_while(|generic_arg| match generic_arg {
301 GenericArgument::Type(typ) => Some(typ),
302 _ => None,
303 })
304 .cloned()
305 .collect::<Vec<_>>();
306
307 OpInstGenerics {
308 generic_args,
309 persistence_args,
310 type_args,
311 }
312}
313
/// Color classification of a graph node (presumably push/pull execution classes used when
/// partitioning — confirm against the partitioning code).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Pull-based.
    Pull,
    /// Push-based.
    Push,
    /// NOTE(review): meaning of `Comp` not established in this file — confirm.
    Comp,
    /// Presumably a handoff node — confirm.
    Hoff,
}
326
/// A port on an operator: an integer index, a path, or elided (not written).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// Integer port index, (de)serialized via its token string.
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// Path port name, (de)serialized via its token string.
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// No port was specified; optionally carries a span to point diagnostics at.
    /// The span is skipped by serde.
    Elided(#[serde(skip)] Option<Span>),
}
338impl PortIndexValue {
339 pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
342 where
343 Inner: Spanned,
344 {
345 let ported_span = Some(ported.inner.span());
346 let port_inn = ported
347 .inn
348 .map(|idx| idx.index.into())
349 .unwrap_or_else(|| Self::Elided(ported_span));
350 let inner = ported.inner;
351 let port_out = ported
352 .out
353 .map(|idx| idx.index.into())
354 .unwrap_or_else(|| Self::Elided(ported_span));
355 (port_inn, inner, port_out)
356 }
357
358 pub fn is_specified(&self) -> bool {
360 !matches!(self, Self::Elided(_))
361 }
362
363 #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
367 #[allow(
368 clippy::result_large_err,
369 reason = "variants are same size, error isn't to be propagated."
370 )]
371 pub fn combine(self, other: Self) -> Result<Self, Self> {
372 match (self.is_specified(), other.is_specified()) {
373 (false, _other) => Ok(other),
374 (true, false) => Ok(self),
375 (true, true) => Err(self),
376 }
377 }
378
379 pub fn as_error_message_string(&self) -> String {
381 match self {
382 PortIndexValue::Int(n) => format!("`{}`", n.value),
383 PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
384 PortIndexValue::Elided(_) => "<elided>".to_owned(),
385 }
386 }
387
388 pub fn span(&self) -> Span {
390 match self {
391 PortIndexValue::Int(x) => x.span(),
392 PortIndexValue::Path(x) => x.span(),
393 PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
394 }
395 }
396}
397impl From<PortIndex> for PortIndexValue {
398 fn from(value: PortIndex) -> Self {
399 match value {
400 PortIndex::Int(x) => Self::Int(x),
401 PortIndex::Path(x) => Self::Path(x),
402 }
403 }
404}
405impl PartialEq for PortIndexValue {
406 fn eq(&self, other: &Self) -> bool {
407 match (self, other) {
408 (Self::Int(l0), Self::Int(r0)) => l0 == r0,
409 (Self::Path(l0), Self::Path(r0)) => l0 == r0,
410 (Self::Elided(_), Self::Elided(_)) => true,
411 _else => false,
412 }
413 }
414}
415impl Eq for PortIndexValue {}
416impl PartialOrd for PortIndexValue {
417 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
418 Some(self.cmp(other))
419 }
420}
421impl Ord for PortIndexValue {
422 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
423 match (self, other) {
424 (Self::Int(s), Self::Int(o)) => s.cmp(o),
425 (Self::Path(s), Self::Path(o)) => s
426 .to_token_stream()
427 .to_string()
428 .cmp(&o.to_token_stream().to_string()),
429 (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
430 (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
431 (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
432 (_, Self::Elided(_)) => std::cmp::Ordering::Less,
433 (Self::Elided(_), _) => std::cmp::Ordering::Greater,
434 }
435 }
436}
437
438impl Display for PortIndexValue {
439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440 match self {
441 PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
442 PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
443 PortIndexValue::Elided(_) => write!(f, "[]"),
444 }
445 }
446}
447
/// Successful result of `build_dfir_code`.
pub struct BuildDfirCodeOutput {
    /// The graph after partitioning.
    pub partitioned_graph: DfirGraph,
    /// Generated code produced by `DfirGraph::as_code` for the partitioned graph.
    pub code: TokenStream,
    /// Non-fatal diagnostics accumulated while building.
    pub diagnostics: Diagnostics,
}
457
458pub fn build_dfir_code(
460 dfir_code: DfirCode,
461 root: &TokenStream,
462) -> Result<BuildDfirCodeOutput, Diagnostics> {
463 let flat_graph_builder = FlatGraphBuilder::from_dfir(dfir_code);
464
465 let FlatGraphBuilderOutput {
466 mut flat_graph,
467 uses,
468 mut diagnostics,
469 } = flat_graph_builder.build()?;
470
471 let () = match flat_graph.merge_modules() {
472 Ok(()) => (),
473 Err(d) => {
474 diagnostics.push(d);
475 return Err(diagnostics);
476 }
477 };
478
479 eliminate_extra_unions_tees(&mut flat_graph);
480
481 for (_loop_id, nodes) in flat_graph.loops() {
485 let span = nodes
486 .first()
487 .map_or_else(Span::call_site, |&n| flat_graph.node(n).span());
488 diagnostics.push(Diagnostic::spanned(
489 span,
490 Level::Error,
491 "`loop { }` blocks are not (yet) supported in `dfir_syntax!`.",
492 ));
493 }
494 if diagnostics.has_error() {
495 return Err(diagnostics);
496 }
497
498 let partitioned_graph = match partition_graph(flat_graph) {
499 Ok(partitioned_graph) => partitioned_graph,
500 Err(d) => {
501 diagnostics.push(d);
502 return Err(diagnostics);
503 }
504 };
505
506 let code =
507 partitioned_graph.as_code(root, true, quote::quote! { #( #uses )* }, &mut diagnostics)?;
508
509 Ok(BuildDfirCodeOutput {
510 partitioned_graph,
511 code,
512 diagnostics,
513 })
514}
515
516fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
518 use proc_macro2::{Group, TokenTree};
519 tokens
520 .into_iter()
521 .map(|token| match token {
522 TokenTree::Group(mut group) => {
523 group.set_span(span);
524 TokenTree::Group(Group::new(
525 group.delimiter(),
526 change_spans(group.stream(), span),
527 ))
528 }
529 TokenTree::Ident(mut ident) => {
530 ident.set_span(span.resolved_at(ident.span()));
531 TokenTree::Ident(ident)
532 }
533 TokenTree::Punct(mut punct) => {
534 punct.set_span(span);
535 TokenTree::Punct(punct)
536 }
537 TokenTree::Literal(mut literal) => {
538 literal.set_span(span);
539 TokenTree::Literal(literal)
540 }
541 })
542 .collect()
543}