diff --git a/lib/queue-msg/src/optimize/passes.rs b/lib/queue-msg/src/optimize/passes.rs index fe98da41eb..fd32083350 100644 --- a/lib/queue-msg/src/optimize/passes.rs +++ b/lib/queue-msg/src/optimize/passes.rs @@ -210,7 +210,15 @@ impl PurePass for FlattenSeq { fn go(msg: QueueMsg) -> Vec> { match msg { QueueMsg::Sequence(new_seq) => new_seq.into_iter().flat_map(go).collect(), - QueueMsg::Concurrent(c) => vec![conc(c.into_iter().flat_map(go))], + QueueMsg::Concurrent(c) => vec![conc(c.into_iter().flat_map(|msg| { + let mut msgs = go(msg); + + match msgs.len() { + 0 => None, + 1 => Some(msgs.pop().unwrap()), + _ => Some(seq(msgs)), + } + }))], QueueMsg::Aggregate { queue, data, @@ -253,18 +261,16 @@ impl PurePass for FlattenConc { fn go(msg: QueueMsg) -> Vec> { match msg { QueueMsg::Concurrent(new_conc) => new_conc.into_iter().flat_map(go).collect(), - // wrap in conc again - // seq(conc(a.., conc(b..)), c..) == seq(conc(a.., b..), c..) - // seq(conc(a.., conc(b..)), c..) != seq(a.., b.., c..) - QueueMsg::Sequence(s) => vec![seq(s.into_iter().map(|msg| { + QueueMsg::Sequence(s) => vec![seq(s.into_iter().flat_map(|msg| { let mut msgs = go(msg); match msgs.len() { - // return the original empty sequence - 0 => seq([]), - // seq(a) == a - 1 => msgs.pop().unwrap(), - _ => conc(msgs), + 0 => None, + 1 => Some(msgs.pop().unwrap()), + // wrap in conc again + // seq(conc(a.., conc(b..)), c..) == seq(conc(a.., b..), c..) + // seq(conc(a.., conc(b..)), c..) != seq(a.., b.., c..) + _ => Some(conc(msgs)), } }))], QueueMsg::Aggregate { @@ -298,8 +304,10 @@ impl PurePass for FlattenConc { mod tests { use super::*; use crate::{ - data, effect, fetch, noop, - test_utils::{DataA, DataB, DataC, FetchA, PrintAbc, SimpleMessage}, + aggregate, data, defer_relative, effect, event, fetch, noop, + test_utils::{ + AggregatePrintAbc, DataA, DataB, DataC, FetchA, PrintAbc, SimpleEvent, SimpleMessage, + }, }; #[test] @@ -349,4 +357,72 @@ mod tests { assert_eq!(optimized.ready, expected_output); assert_eq!(optimized.optimize_further, []); } + + #[test] + fn seq_conc_conc() { + let msgs = vec![seq::([ + conc([ + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + ]), + conc([ + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + ]), + conc([ + repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])), + repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])), + // this seq is the only message that should be flattened + seq([ + effect(PrintAbc { + a: DataA {}, + b: DataB {}, + c: DataC {}, + }), + seq([ + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + ]), + ]), + ]), + ])]; + + let expected_output = vec![( + vec![0], + seq::([ + conc([ + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + ]), + conc([ + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + ]), + conc([ + repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])), + repeat(None, seq([event(SimpleEvent {}), defer_relative(10)])), + seq([ + effect(PrintAbc { + a: DataA {}, + b: DataB {}, + c: DataC {}, + }), + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + aggregate([], [], AggregatePrintAbc {}), + ]), + ]), + ]), + )]; + + let optimized = Normalize::default().run_pass_pure(msgs.clone()); + + assert_eq!(optimized.optimize_further, expected_output); + assert_eq!(optimized.ready, []); + + let optimized = NormalizeFinal::default().run_pass_pure(msgs); + assert_eq!(optimized.ready, expected_output); + assert_eq!(optimized.optimize_further, []); + } }