1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// Kind of delay an operator may require on one of its input ports.
///
/// Returned per input port by `OperatorConstraints::input_delaytype_fn`
/// (`None` there means no delay is required for that port).
///
/// NOTE: `PartialOrd`/`Ord` are derived, so the declaration order of the
/// variants below is significant; do not reorder them casually.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum DelayType {
    /// Presumably: the input must be fully produced within an earlier stratum — TODO confirm.
    Stratum,
    /// Presumably: a monotone accumulation delay — TODO confirm semantics.
    MonotoneAccum,
    /// Presumably: the input is delayed to the next tick — TODO confirm.
    Tick,
    /// Presumably: a lazily-scheduled variant of `Tick` — TODO confirm.
    TickLazy,
}
35
/// Describes which ports an operator exposes on one side (inputs or outputs).
///
/// Produced by `OperatorConstraints::ports_inn` / `ports_out`.
pub enum PortListSpec {
    /// Any number of ports is accepted.
    Variadic,
    /// Exactly the listed (comma-separated) port indices are accepted.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
43
/// Static description of a built-in operator: its name, arity limits,
/// argument counts, port layout, and code generator.
///
/// One constant of this type is registered per operator via `declare_ops!`.
pub struct OperatorConstraints {
    /// The operator's name; `operator_lookup()` keys on this.
    pub name: &'static str,
    /// Categories this operator belongs to.
    pub categories: &'static [OperatorCategory],

    /// Allowed number of inputs. Presumably violations are hard errors — TODO confirm.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Recommended number of inputs. Presumably violations are only warnings — TODO confirm.
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Allowed number of outputs. Presumably violations are hard errors — TODO confirm.
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Recommended number of outputs. Presumably violations are only warnings — TODO confirm.
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of expression arguments the operator takes.
    pub num_args: usize,
    /// Range of how many persistence (`'tick`, `'static`, ...) arguments are accepted.
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Range of how many generic type arguments are accepted.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// Whether the operator receives input from outside the graph — semantics TODO confirm.
    pub is_external_input: bool,
    /// Optional `FloType` classification of this operator.
    pub flo_type: Option<FloType>,

    /// Input port layout, if the operator declares one.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Output port layout, if the operator declares one.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// For a given input port, the delay (if any) that input requires.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// Code generator for this operator; see `WriteFn`.
    pub write_fn: WriteFn,
}
84
/// Signature of an operator's code generator.
///
/// Receives the per-instance `WriteContextArgs` and a diagnostics sink, and
/// returns the generated code fragments. `Err(())` carries no detail itself;
/// presumably the reason is pushed to the `Diagnostics` — TODO confirm per operator.
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
87
88impl Debug for OperatorConstraints {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("OperatorConstraints")
91 .field("name", &self.name)
92 .field("hard_range_inn", &self.hard_range_inn)
93 .field("soft_range_inn", &self.soft_range_inn)
94 .field("hard_range_out", &self.hard_range_out)
95 .field("soft_range_out", &self.soft_range_out)
96 .field("num_args", &self.num_args)
97 .field("persistence_args", &self.persistence_args)
98 .field("type_args", &self.type_args)
99 .field("is_external_input", &self.is_external_input)
100 .field("ports_inn", &self.ports_inn)
101 .field("ports_out", &self.ports_out)
102 .finish()
106 }
107}
108
/// Code fragments produced by an operator's `WriteFn`.
///
/// Each field defaults to an empty `TokenStream`; simple operators (e.g. the
/// identity and null generators in this file) only fill `write_iterator`.
#[derive(Default)]
pub struct OperatorWriteOutput {
    /// Setup code. Presumably emitted once before the subgraph runs — TODO confirm.
    pub write_prologue: TokenStream,
    /// Code binding the operator's pull/push pipeline element (`let #ident = ...;`).
    pub write_iterator: TokenStream,
    /// Code emitted after the pipeline. Presumably runs after the subgraph's iterator code — TODO confirm.
    pub write_iterator_after: TokenStream,
    /// Code presumably emitted at the end of each tick — TODO confirm.
    pub write_tick_end: TokenStream,
}
130
/// Accepts any count (`0..`, unbounded above).
pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
/// Accepts exactly zero (`0..=0`).
pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
/// Accepts exactly one (`1..=1`).
pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
137
138pub fn identity_write_iterator_fn(
141 &WriteContextArgs {
142 root,
143 op_span,
144 ident,
145 inputs,
146 outputs,
147 is_pull,
148 op_inst:
149 OperatorInstance {
150 generics: OpInstGenerics { type_args, .. },
151 ..
152 },
153 ..
154 }: &WriteContextArgs,
155) -> TokenStream {
156 let generic_type = type_args
157 .first()
158 .map(quote::ToTokens::to_token_stream)
159 .unwrap_or(quote_spanned!(op_span=> _));
160
161 if is_pull {
162 let input = &inputs[0];
163 quote_spanned! {op_span=>
164 let #ident = {
165 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
166 where
167 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
168 {
169 pull
170 }
171 check_input::<_, #generic_type>(#input)
172 };
173 }
174 } else {
175 let output = &outputs[0];
176 quote_spanned! {op_span=>
177 let #ident = {
178 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
179 where
180 Psh: #root::dfir_pipes::push::Push<Item, ()>,
181 {
182 push
183 }
184 check_output::<_, #generic_type>(#output)
185 };
186 }
187 }
188}
189
190pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
192 let write_iterator = identity_write_iterator_fn(write_context_args);
193 Ok(OperatorWriteOutput {
194 write_iterator,
195 ..Default::default()
196 })
197};
198
199pub fn null_write_iterator_fn(
202 &WriteContextArgs {
203 root,
204 op_span,
205 ident,
206 inputs,
207 outputs,
208 is_pull,
209 op_inst:
210 OperatorInstance {
211 generics: OpInstGenerics { type_args, .. },
212 ..
213 },
214 ..
215 }: &WriteContextArgs,
216) -> TokenStream {
217 let default_type = parse_quote_spanned! {op_span=> _};
218 let iter_type = type_args.first().unwrap_or(&default_type);
219
220 if is_pull {
221 quote_spanned! {op_span=>
222 let #ident = #root::dfir_pipes::pull::poll_fn({
223 #(
224 let mut #inputs = ::std::boxed::Box::pin(#inputs);
225 )*
226 move |_cx| {
227 #(
231 let #inputs = #root::dfir_pipes::pull::Pull::pull(
232 ::std::pin::Pin::as_mut(&mut #inputs),
233 <_ as #root::dfir_pipes::Context>::from_task(_cx),
234 );
235 )*
236 #(
237 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
238 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
239 }
240 )*
241 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
242 }
243 });
244 }
245 } else {
246 quote_spanned! {op_span=>
247 #[allow(clippy::let_unit_value)]
248 let _ = (#(#outputs),*);
249 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
250 }
251 }
252}
253
254pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
257 let write_iterator = null_write_iterator_fn(write_context_args);
258 Ok(OperatorWriteOutput {
259 write_iterator,
260 ..Default::default()
261 })
262};
263
// Declares one `pub(crate) mod` per listed operator module and gathers each
// module's `OperatorConstraints` constant into the `OPERATORS` table, in the
// order listed. Every `module::CONSTANT` entry must be followed by a comma.
macro_rules! declare_ops {
    ( $( $mod:ident :: $op:ident, )* ) => {
        $( pub(crate) mod $mod; )*
        /// Constraints of every built-in operator, in `declare_ops!` invocation order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $mod :: $op, )*
        ];
    };
}
// The built-in operator registry. Each `module::CONSTANT` entry names an
// operator module and its `OperatorConstraints` constant. The order here is
// the order of `OPERATORS` (it is not strictly alphabetical). Name-based lookup
// goes through the `HashMap` built by `operator_lookup()`, so order does not
// affect lookup.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flat_map_stream_blocking::FLAT_MAP_STREAM_BLOCKING,
    flatten::FLATTEN,
    flatten_stream_blocking::FLATTEN_STREAM_BLOCKING,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    iter_ref::ITER_REF,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    join_multiset_half::JOIN_MULTISET_HALF,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    scan_async_blocking::SCAN_ASYNC_BLOCKING,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
359
360pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
362 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
363 OnceLock::new();
364 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
365}
366pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
368 if let GraphNode::Operator(operator) = node {
369 find_op_op_constraints(operator)
370 } else {
371 None
372 }
373}
374pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
376 let name = &*operator.name_string();
377 operator_lookup().get(name).copied()
378}
379
/// Everything an operator's `WriteFn` needs to generate code for one operator
/// instance.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Path prefix for runtime items in generated code (e.g. `#root::dfir_pipes`).
    pub root: &'a TokenStream,
    /// Identifier of the context value in generated code — TODO confirm semantics.
    pub context: &'a Ident,
    /// Identifier of the dataflow graph value in generated code — TODO confirm semantics.
    pub df_ident: &'a Ident,
    /// Subgraph containing this operator; used by `make_ident`.
    pub subgraph_id: GraphSubgraphId,
    /// This operator's node; used by `make_ident`.
    pub node_id: GraphNodeId,
    /// Enclosing loop, if any. When `Some`, default persistence is `None` rather than `Tick`.
    pub loop_id: Option<GraphLoopId>,
    /// Span for generated code and diagnostics.
    pub op_span: Span,
    /// Optional user-facing tag for this operator — TODO confirm usage.
    pub op_tag: Option<String>,
    /// Identifier of a (sync) work helper in generated code — TODO confirm semantics.
    pub work_fn: &'a Ident,
    /// Identifier of an async work helper in generated code — TODO confirm semantics.
    pub work_fn_async: &'a Ident,

    /// Name the generated code binds this operator's pipeline element to (`let #ident = ...`).
    pub ident: &'a Ident,
    /// `true` to generate pull-side code, `false` for push-side code.
    pub is_pull: bool,
    /// Identifiers of the upstream pipeline elements (used on the pull side).
    pub inputs: &'a [Ident],
    /// Identifiers of the downstream pipeline elements (used on the push side).
    pub outputs: &'a [Ident],

    /// Operator name, used in diagnostic messages.
    pub op_name: &'static str,
    /// The operator instance, including its generics (type and persistence arguments).
    pub op_inst: &'a OperatorInstance,
    /// The operator's comma-separated expression arguments.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
}
427impl WriteContextArgs<'_> {
428 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
434 Ident::new(
435 &format!(
436 "sg_{:?}_node_{:?}_{}",
437 self.subgraph_id.data(),
438 self.node_id.data(),
439 suffix.as_ref(),
440 ),
441 self.op_span,
442 )
443 }
444
445 pub fn persistence_args_disallow_mutable<const N: usize>(
447 &self,
448 diagnostics: &mut Diagnostics,
449 ) -> [Persistence; N] {
450 let len = self.op_inst.generics.persistence_args.len();
451 if 0 != len && 1 != len && N != len {
452 diagnostics.push(Diagnostic::spanned(
453 self.op_span,
454 Level::Error,
455 format!(
456 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
457 self.op_name, N
458 ),
459 ));
460 }
461
462 let default_persistence = if self.loop_id.is_some() {
463 Persistence::None
464 } else {
465 Persistence::Tick
466 };
467 let mut out = [default_persistence; N];
468 self.op_inst
469 .generics
470 .persistence_args
471 .iter()
472 .copied()
473 .cycle() .take(N)
475 .enumerate()
476 .filter(|&(_i, p)| {
477 if p == Persistence::Mutable {
478 diagnostics.push(Diagnostic::spanned(
479 self.op_span,
480 Level::Error,
481 format!(
482 "An implementation of `'{}` does not exist",
483 p.to_str_lowercase()
484 ),
485 ));
486 false
487 } else {
488 true
489 }
490 })
491 .for_each(|(i, p)| {
492 out[i] = p;
493 });
494 out
495 }
496}
497
/// Object-safe view of a range over `T`.
///
/// Operator constraints use it (as `&'static dyn RangeTrait<usize>`) to describe
/// how many inputs, outputs, or arguments are accepted. It mirrors the bound
/// accessors and `contains` of [`RangeBounds`], and adds a human-readable rendering.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;

    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;

    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Renders the range as English text.
    ///
    /// Examples: `"exactly 1"`, `"at least 2"`, `"more than 0 and less than 4"`,
    /// or `"any number of"` when both ends are unbounded.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let (start, end) = (self.start_bound(), self.end_bound());

        // A closed range with equal ends collapses to a single value.
        if let (Bound::Included(low), Bound::Included(high)) = (start, end) {
            if low == high {
                return format!("exactly {}", low);
            }
        }

        let lower = match start {
            Bound::Included(n) => Some(format!("at least {}", n)),
            Bound::Excluded(n) => Some(format!("more than {}", n)),
            Bound::Unbounded => None,
        };
        let upper = match end {
            Bound::Included(x) => Some(format!("at most {}", x)),
            Bound::Excluded(x) => Some(format!("less than {}", x)),
            Bound::Unbounded => None,
        };

        match (lower, upper) {
            (Some(lower), Some(upper)) => format!("{} and {}", lower, upper),
            (Some(only), None) | (None, Some(only)) => only,
            (None, None) => "any number of".to_owned(),
        }
    }
}
542
543impl<R, T> RangeTrait<T> for R
544where
545 R: RangeBounds<T> + Send + Sync + Debug,
546{
547 fn start_bound(&self) -> Bound<&T> {
548 self.start_bound()
549 }
550
551 fn end_bound(&self) -> Bound<&T> {
552 self.end_bound()
553 }
554
555 fn contains(&self, item: &T) -> bool
556 where
557 T: PartialOrd<T>,
558 {
559 self.contains(item)
560 }
561}
562
/// Persistence of an operator's state, written as a lifetime-like argument
/// (`'tick`, `'static`, ...; see `Persistence::to_str_lowercase`).
///
/// NOTE: `PartialOrd`/`Ord` are derived, so variant declaration order is
/// significant. Serialized forms may also depend on variant names/order — confirm
/// before renaming or reordering.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// No persistence; the default inside a loop.
    None,
    /// Persistence scoped to a loop — exact semantics TODO confirm.
    Loop,
    /// Persistence scoped to a tick; the default outside a loop.
    Tick,
    /// Persistence across ticks — exact semantics TODO confirm.
    Static,
    /// Mutable persistence; rejected by `persistence_args_disallow_mutable`.
    Mutable,
}
577impl Persistence {
578 pub fn to_str_lowercase(self) -> &'static str {
580 match self {
581 Persistence::None => "none",
582 Persistence::Tick => "tick",
583 Persistence::Loop => "loop",
584 Persistence::Static => "static",
585 Persistence::Mutable => "mutable",
586 }
587 }
588}
589
590fn make_missing_runtime_msg(op_name: &str) -> Literal {
592 Literal::string(&format!(
593 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
594 op_name
595 ))
596}
597
// Category an operator belongs to (see `OperatorConstraints::categories`).
//
// NOTE(review): this enum derives `DocumentedVariants`, and `name()` /
// `description()` split each variant's `///` doc string at its first `:`.
// Variant doc comments are therefore runtime data and must have the shape
// `Name: description`. None are visible on the variants in this view; confirm
// they exist. Plain `//` comments (like these) are not read by the derive.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
632impl OperatorCategory {
633 pub fn name(self) -> &'static str {
635 self.get_variant_docs().split_once(":").unwrap().0
636 }
637 pub fn description(self) -> &'static str {
639 self.get_variant_docs().split_once(":").unwrap().1
640 }
641}
642
/// Classification stored in `OperatorConstraints::flo_type`.
///
/// The semantics of each kind are not visible in this file — TODO document.
/// `PartialOrd`/`Ord` are derived, so variant declaration order is significant.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    /// Presumably an operator that introduces data — TODO confirm.
    Source,
    /// Presumably an operator that batches data into windows — TODO confirm.
    Windowing,
    /// Presumably an operator that flattens windows back out — TODO confirm.
    Unwindowing,
    /// Presumably an operator that feeds data to the next loop iteration — TODO confirm.
    NextIteration,
}
654}