diff --git a/rib-core/src/expr.rs b/rib-core/src/expr.rs index e4b79c3..da368cb 100644 --- a/rib-core/src/expr.rs +++ b/rib-core/src/expr.rs @@ -22,8 +22,8 @@ use crate::rib_source_span::SourceSpan; use crate::rib_type_error::RibTypeErrorInternal; use crate::{ from_string, text, type_checker, type_inference, ComponentDependencies, ComponentDependencyKey, - CustomInstanceSpec, DynamicParsedFunctionName, ExprVisitor, GlobalVariableTypeSpec, - InferredType, InstanceIdentifier, VariableId, + CustomInstanceSpec, DynamicParsedFunctionName, GlobalVariableTypeSpec, InferredType, + InstanceIdentifier, VariableId, }; use crate::{IntoValueAndType, ValueAndType}; use bigdecimal::{BigDecimal, ToPrimitive}; @@ -1187,15 +1187,13 @@ impl Expr { } pub fn set_origin(&mut self) { - let mut visitor = ExprVisitor::bottom_up(self); - - while let Some(expr) = visitor.pop_front() { + type_inference::visit_post_order_mut(self, &mut |expr| { let source_location = expr.source_span(); let origin = TypeOrigin::OriginatedAt(source_location.clone()); let inferred_type = expr.inferred_type(); let origin = inferred_type.add_origin(origin); expr.with_inferred_type_mut(origin); - } + }); } // An inference is a single cycle of to-and-fro scanning of Rib expression, that it takes part in fix point of inference. @@ -1780,7 +1778,7 @@ impl Expr { } pub fn visit_expr_nodes_lazy<'a>(&'a mut self, queue: &mut VecDeque<&'a mut Expr>) { - type_inference::visit_expr_nodes_lazy(self, queue); + type_inference::collect_children_mut(self, queue); } pub fn number_inferred( @@ -2112,16 +2110,16 @@ impl Serialize for Expr { fn find_expr(expr: &mut Expr, source_span: &SourceSpan) -> Option { let mut expr = expr.clone(); + let mut found = None; - let mut visitor = ExprVisitor::bottom_up(&mut expr); - - while let Some(current) = visitor.pop_back() { - let span = current.source_span(); - - if source_span.eq(&span) { - return Some(current.clone()); + type_inference::visit_post_order_rev_mut(&mut expr, &mut |current| { + if found.is_none() { + let span = current.source_span(); + if source_span.eq(&span) { + found = Some(current.clone()); + } } - } + }); - None + found } diff --git a/rib-core/src/type_checker/exhaustive_pattern_match.rs b/rib-core/src/type_checker/exhaustive_pattern_match.rs index a50cc48..8ae1696 100644 --- a/rib-core/src/type_checker/exhaustive_pattern_match.rs +++ b/rib-core/src/type_checker/exhaustive_pattern_match.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::rib_source_span::SourceSpan; -use crate::{ArmPattern, ComponentDependencies, Expr, ExprVisitor}; +use crate::{try_visit_post_order_rev_mut, ArmPattern, ComponentDependencies, Expr}; // When checking exhaustive pattern match, there is no need to ensure // if the pattern aligns with conditions because those checks are done @@ -23,9 +23,7 @@ pub fn check_exhaustive_pattern_match( expr: &mut Expr, component_dependency: &ComponentDependencies, ) -> Result<(), ExhaustivePatternMatchError> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + try_visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::PatternMatch { match_arms, .. } = expr { let match_arm = match_arms .iter() @@ -33,9 +31,8 @@ pub fn check_exhaustive_pattern_match( .collect::>(); internal::check_exhaustive_pattern_match(expr, &match_arm, component_dependency)?; } - } - - Ok(()) + Ok(()) + }) } #[derive(Debug, Clone)] diff --git a/rib-core/src/type_checker/invalid_function_args.rs b/rib-core/src/type_checker/invalid_function_args.rs index 6e20c66..efcef09 100644 --- a/rib-core/src/type_checker/invalid_function_args.rs +++ b/rib-core/src/type_checker/invalid_function_args.rs @@ -15,8 +15,8 @@ use crate::analysis::AnalysedType; use crate::call_type::CallType; use crate::type_checker::missing_fields::find_missing_fields_in_record; +use crate::{try_visit_post_order_mut, Expr, FunctionCallError}; use crate::{type_checker, ComponentDependencies, FunctionName}; -use crate::{Expr, ExprVisitor, FunctionCallError}; // While we have a dedicated generic phases (refer submodules) within type_checker module, // we have this special phase to grab errors in the context function calls. @@ -27,21 +27,18 @@ pub fn check_invalid_function_args( expr: &mut Expr, component_dependency: &ComponentDependencies, ) -> Result<(), FunctionCallError> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_front() { + try_visit_post_order_mut(expr, &mut |expr| { if let Expr::Call { call_type, args, .. - } = &expr + } = &*expr { match call_type { CallType::InstanceCreation(_) => {} call_type => get_missing_record_keys(call_type, args, component_dependency, expr)?, } } - } - - Ok(()) + Ok(()) + }) } #[allow(clippy::result_large_err)] diff --git a/rib-core/src/type_checker/invalid_function_calls.rs b/rib-core/src/type_checker/invalid_function_calls.rs index 1785d18..6fe9c4f 100644 --- a/rib-core/src/type_checker/invalid_function_calls.rs +++ b/rib-core/src/type_checker/invalid_function_calls.rs @@ -13,13 +13,11 @@ // limitations under the License. use crate::call_type::CallType; -use crate::{Expr, ExprVisitor, FunctionCallError}; +use crate::{try_visit_post_order_mut, Expr, FunctionCallError}; #[allow(clippy::result_large_err)] pub fn check_invalid_function_calls(expr: &mut Expr) -> Result<(), FunctionCallError> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_front() { + try_visit_post_order_mut(expr, &mut |expr| { if let Expr::Call { call_type: CallType::Function { @@ -28,7 +26,7 @@ pub fn check_invalid_function_calls(expr: &mut Expr) -> Result<(), FunctionCallE .. }, .. - } = &expr + } = &*expr { if component_info.is_none() { return Err(FunctionCallError::InvalidFunctionCall { @@ -38,7 +36,6 @@ pub fn check_invalid_function_calls(expr: &mut Expr) -> Result<(), FunctionCallE }); } } - } - - Ok(()) + Ok(()) + }) } diff --git a/rib-core/src/type_checker/invalid_worker_name.rs b/rib-core/src/type_checker/invalid_worker_name.rs index 9756770..d865334 100644 --- a/rib-core/src/type_checker/invalid_worker_name.rs +++ b/rib-core/src/type_checker/invalid_worker_name.rs @@ -13,13 +13,11 @@ // limitations under the License. use crate::call_type::{CallType, InstanceCreationType}; -use crate::{Expr, ExprVisitor, InvalidWorkerName}; +use crate::{try_visit_post_order_rev_mut, Expr, InvalidWorkerName}; // Capture all worker name and see if they are resolved to a string type pub fn check_invalid_worker_name(expr: &mut Expr) -> Result<(), InvalidWorkerName> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + try_visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Call { call_type, .. } = expr { match call_type { CallType::InstanceCreation(InstanceCreationType::WitWorker { @@ -46,9 +44,8 @@ pub fn check_invalid_worker_name(expr: &mut Expr) -> Result<(), InvalidWorkerNam } } } - } - - Ok(()) + Ok(()) + }) } mod internal { diff --git a/rib-core/src/type_inference/enum_inference.rs b/rib-core/src/type_inference/enum_inference.rs index 3cadb21..d4d25e4 100644 --- a/rib-core/src/type_inference/enum_inference.rs +++ b/rib-core/src/type_inference/enum_inference.rs @@ -23,7 +23,7 @@ pub fn infer_enums(expr: &mut Expr, component_dependencies: &ComponentDependenci mod internal { use crate::analysis::AnalysedType; use crate::call_type::CallType; - use crate::{ComponentDependencies, Expr, ExprVisitor}; + use crate::{visit_post_order_rev_mut, ComponentDependencies, Expr}; pub(crate) fn convert_identifiers_to_enum_function_calls( expr: &mut Expr, @@ -31,9 +31,7 @@ mod internal { ) { let enum_cases = enum_info.clone(); - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Identifier { variable_id, inferred_type, @@ -52,7 +50,7 @@ mod internal { }; } } - } + }); } pub(crate) fn get_enum_info( @@ -60,9 +58,7 @@ mod internal { component_dependency: &ComponentDependencies, ) -> EnumInfo { let mut enum_cases = vec![]; - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Identifier { variable_id, inferred_type, @@ -84,7 +80,7 @@ mod internal { } } } - } + }); EnumInfo { cases: enum_cases } } diff --git a/rib-core/src/type_inference/expr_visitor.rs b/rib-core/src/type_inference/expr_visitor.rs index e2fd058..03072ef 100644 --- a/rib-core/src/type_inference/expr_visitor.rs +++ b/rib-core/src/type_inference/expr_visitor.rs @@ -1,618 +1,607 @@ use crate::{Expr, TypeInternal}; use std::collections::VecDeque; -// A structure that allows to visit expressions in a bottom-up or top-down order. -// All other functionalities are to be replaced with the usage of `ExprVisitor` -// https://github.com/golemcloud/golem/issues/1428 -pub struct ExprVisitor<'a> { - queue: VecDeque<&'a mut Expr>, +// Post-order (bottom-up): children first, then parent. Left-to-right. +pub fn visit_post_order_mut(expr: &mut Expr, f: &mut impl FnMut(&mut Expr)) { + visit_children_mut(expr, |child| visit_post_order_mut(child, f)); + f(expr); } -impl Default for ExprVisitor<'_> { - fn default() -> Self { - Self::new() - } +// Pre-order reversed: parent first, then children right-to-left. +// This matches the old ExprVisitor::bottom_up + pop_back() semantics. +pub fn visit_post_order_rev_mut(expr: &mut Expr, f: &mut impl FnMut(&mut Expr)) { + f(expr); + visit_children_rev_mut(expr, |child| visit_post_order_rev_mut(child, f)); } -impl<'a> ExprVisitor<'a> { - pub fn new() -> Self { - ExprVisitor { - queue: VecDeque::new(), - } - } - - pub fn is_empty(&self) -> bool { - self.queue.is_empty() - } - - // Enqueue expressions in a bottom-up order, - // but in the natural order of rib program - // Given - // `Expr::Block(Expr::Let(x, Expr::Num(1)), Expr::Call(func, x))` - // Expr::Num(1) - // Expr::Let(Variable(x), Expr::Num(1)) - // Expr::Identifier(x) - // Expr::Call(func, Expr::Identifier(x)) - // Expr::Block(Expr::Let(x, Expr::Num(1)), Expr::Call(func, x)) - pub fn bottom_up(expr: &'a mut Expr) -> Self { - let mut queue: VecDeque<&'a mut Expr> = VecDeque::new(); - - enqueue_expr_bottom_up(expr, &mut queue); - - ExprVisitor { queue } - } - - // Enqueue expressions in a top-down order, - // while processing the expressions in the natural order within the block (Expr::Block). - // Given - // `Expr::Block(Expr::Let(x, Expr::Num(1)), Expr::Call(func, x))` - // Expr::Block(Expr::Let(x, Expr::Num(1)), Expr::Call(func, x)) - // Expr::Let(Variable(x), Expr::Num(1, U64)) - // Expr::Num(1, U64) - // Expr::Call(func, Expr::Identifier(x)) - // Expr::Identifier(x) - pub fn top_down(expr: &'a mut Expr) -> Self { - let mut queue: VecDeque<&'a mut Expr> = VecDeque::new(); - - enqueue_expr_top_down(expr, &mut queue); - - ExprVisitor { queue } - } - - pub fn pop_front(&mut self) -> Option<&mut Expr> { - self.queue.pop_front() - } - - pub fn pop_back(&mut self) -> Option<&mut Expr> { - self.queue.pop_back() - } - - pub fn pop_all(&mut self) -> Vec<&mut Expr> { - self.queue.drain(..).collect() - } +// Pre-order (top-down): parent first, then children left-to-right. +pub fn visit_pre_order_mut(expr: &mut Expr, f: &mut impl FnMut(&mut Expr)) { + f(expr); + visit_children_mut(expr, |child| visit_pre_order_mut(child, f)); } -fn enqueue_expr_top_down(expr: &mut Expr, queue: &mut VecDeque<&mut Expr>) { - let mut stack: VecDeque<*mut Expr> = VecDeque::new(); - - stack.push_back(expr); - - while let Some(current_ptr) = stack.pop_front() { - queue.push_back(unsafe { &mut *current_ptr }); - - let current = unsafe { &mut *current_ptr }; - - match current { - Expr::Let { expr, .. } => stack.push_front(&mut **expr), - Expr::SelectField { expr, .. } => stack.push_front(&mut **expr), - Expr::SelectIndex { expr, index, .. } => { - stack.push_front(&mut **expr); - stack.push_front(&mut **index); - } - Expr::Sequence { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::Record { exprs, .. } => { - stack.extend(exprs.iter_mut().map(|(_, expr)| &mut **expr as *mut Expr)) - } - Expr::Tuple { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::Concat { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::ExprBlock { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::Not { expr, .. } => stack.push_front(&mut **expr), - Expr::Length { expr, .. } => stack.push_front(&mut **expr), - Expr::GreaterThan { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::GreaterThanOrEqualTo { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::LessThanOrEqualTo { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::EqualTo { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::Plus { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::Minus { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::Divide { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::Multiply { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::LessThan { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::Cond { cond, lhs, rhs, .. } => { - stack.push_front(&mut **cond); - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs); - } - Expr::PatternMatch { - predicate, - match_arms, - .. - } => { - stack.push_front(&mut **predicate); - for arm in match_arms { - let arm_literal_expressions = arm.arm_pattern.get_expr_literals_mut(); - - for x in arm_literal_expressions { - let x = x.as_mut(); - stack.push_front(x); - } - - stack.push_front(&mut *arm.arm_resolution_expr); - } - } - - Expr::Range { range, .. } => { - for expr in range.get_exprs_mut() { - stack.push_front(&mut **expr); - } - } +// Fallible post-order: children first, then parent. Stops on first error. +pub fn try_visit_post_order_mut( + expr: &mut Expr, + f: &mut impl FnMut(&mut Expr) -> Result<(), E>, +) -> Result<(), E> { + try_visit_children_mut(expr, |child| try_visit_post_order_mut(child, f))?; + f(expr) +} - Expr::Option { - expr: Some(expr), .. - } => stack.push_front(&mut **expr), - Expr::Result { expr: Ok(expr), .. } => stack.push_front(&mut **expr), - Expr::Result { - expr: Err(expr), .. - } => stack.push_front(&mut **expr), - Expr::Call { - call_type, - args, - inferred_type, - .. - } => { - let (exprs, worker) = internal::get_expressions_in_call_type_mut(call_type); - if let Some(exprs) = exprs { - for x in exprs { - stack.push_front(x); - } - } +// Fallible pre-order reversed: parent first, then children right-to-left. +pub fn try_visit_post_order_rev_mut( + expr: &mut Expr, + f: &mut impl FnMut(&mut Expr) -> Result<(), E>, +) -> Result<(), E> { + f(expr)?; + try_visit_children_rev_mut(expr, |child| try_visit_post_order_rev_mut(child, f)) +} - if let Some(worker) = worker { - stack.push_front(&mut **worker); - } +// Fallible pre-order: parent first, then children left-to-right. +pub fn try_visit_pre_order_mut( + expr: &mut Expr, + f: &mut impl FnMut(&mut Expr) -> Result<(), E>, +) -> Result<(), E> { + f(expr)?; + try_visit_children_mut(expr, |child| try_visit_pre_order_mut(child, f)) +} - // The expr existing in the inferred type should be visited - if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { - if let Some(worker_expr) = instance_type.worker_mut() { - stack.push_front(&mut **worker_expr); - } - } +// Immutable post-order traversal. +pub fn visit_post_order<'a>(expr: &'a Expr, f: &mut impl FnMut(&'a Expr)) { + visit_children(expr, |child| visit_post_order(child, f)); + f(expr); +} - for x in args { - stack.push_front(x); +// Collect immediate children into a queue (used by visit_expr_nodes_lazy). +pub fn collect_children_mut<'a>(expr: &'a mut Expr, queue: &mut VecDeque<&'a mut Expr>) { + match expr { + Expr::Let { expr, .. } => queue.push_back(expr), + Expr::SelectField { expr, .. } => queue.push_back(expr), + Expr::SelectIndex { expr, index, .. } => { + queue.push_back(&mut *expr); + queue.push_back(&mut *index); + } + Expr::Sequence { exprs, .. } + | Expr::Tuple { exprs, .. } + | Expr::Concat { exprs, .. } + | Expr::ExprBlock { exprs, .. } => queue.extend(exprs.iter_mut()), + Expr::Record { exprs, .. } => queue.extend(exprs.iter_mut().map(|(_, expr)| &mut **expr)), + Expr::Not { expr, .. } | Expr::Length { expr, .. } | Expr::Unwrap { expr, .. } => { + queue.push_back(expr) + } + Expr::GreaterThan { lhs, rhs, .. } + | Expr::GreaterThanOrEqualTo { lhs, rhs, .. } + | Expr::LessThanOrEqualTo { lhs, rhs, .. } + | Expr::EqualTo { lhs, rhs, .. } + | Expr::Plus { lhs, rhs, .. } + | Expr::Minus { lhs, rhs, .. } + | Expr::Divide { lhs, rhs, .. } + | Expr::Multiply { lhs, rhs, .. } + | Expr::LessThan { lhs, rhs, .. } + | Expr::And { lhs, rhs, .. } + | Expr::Or { lhs, rhs, .. } => { + queue.push_back(lhs); + queue.push_back(rhs); + } + Expr::Cond { cond, lhs, rhs, .. } => { + queue.push_back(cond); + queue.push_back(lhs); + queue.push_back(rhs); + } + Expr::PatternMatch { + predicate, + match_arms, + .. + } => { + queue.push_back(&mut *predicate); + for arm in match_arms { + for lit in arm.arm_pattern.get_expr_literals_mut() { + queue.push_back(lit.as_mut()); } + queue.push_back(&mut arm.arm_resolution_expr); } - Expr::Unwrap { expr, .. } => stack.push_front(&mut **expr), // not yet needed - Expr::And { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs) - } - - Expr::Or { lhs, rhs, .. } => { - stack.push_front(&mut **lhs); - stack.push_front(&mut **rhs) + } + Expr::Range { range, .. } => { + for e in range.get_exprs_mut() { + queue.push_back(&mut *e); } - - Expr::ListComprehension { - iterable_expr, - yield_expr, - .. - } => { - stack.push_front(&mut **iterable_expr); - stack.push_front(&mut **yield_expr); + } + Expr::Option { + expr: Some(expr), .. + } => queue.push_back(expr), + Expr::Result { expr: Ok(expr), .. } => queue.push_back(expr), + Expr::Result { + expr: Err(expr), .. + } => queue.push_back(expr), + Expr::Call { + call_type, + args, + inferred_type, + .. + } => { + let (exprs, worker) = internal::get_expressions_in_call_type_mut(call_type); + if let Some(exprs) = exprs { + queue.extend(exprs.iter_mut()) } - - Expr::ListReduce { - iterable_expr, - init_value_expr, - yield_expr, - .. - } => { - stack.push_front(&mut **iterable_expr); - stack.push_front(&mut **init_value_expr); - stack.push_front(&mut **yield_expr); + if let Some(worker) = worker { + queue.push_back(worker); } - - Expr::InvokeMethodLazy { - lhs, - args, - inferred_type, - .. - } => { - if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { - if let Some(worker_expr) = instance_type.worker_mut() { - stack.push_front(&mut **worker_expr); - } + if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { + if let Some(worker_expr) = instance_type.worker_mut() { + queue.push_back(worker_expr); } - - stack.push_front(&mut **lhs); - stack.extend(args.iter_mut().map(|x| x as *mut Expr)); } - - Expr::GetTag { expr, .. } => { - stack.push_front(&mut **expr); + queue.extend(args.iter_mut()) + } + Expr::ListComprehension { + iterable_expr, + yield_expr, + .. + } => { + queue.push_back(iterable_expr); + queue.push_back(yield_expr); + } + Expr::ListReduce { + iterable_expr, + init_value_expr, + yield_expr, + .. + } => { + queue.push_back(iterable_expr); + queue.push_back(init_value_expr); + queue.push_back(yield_expr); + } + Expr::InvokeMethodLazy { + lhs, + args, + inferred_type, + .. + } => { + if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { + if let Some(worker_expr) = instance_type.worker_mut() { + queue.push_back(worker_expr); + } } - - Expr::Literal { .. } => {} - Expr::Number { .. } => {} - Expr::Flags { .. } => {} - Expr::Identifier { .. } => {} - Expr::Boolean { .. } => {} - Expr::Option { expr: None, .. } => {} - Expr::Throw { .. } => {} - Expr::GenerateWorkerName { .. } => {} + queue.push_back(lhs); + queue.extend(args.iter_mut()); } + Expr::GetTag { expr, .. } => queue.push_back(expr), + Expr::Literal { .. } + | Expr::Number { .. } + | Expr::Flags { .. } + | Expr::Identifier { .. } + | Expr::Boolean { .. } + | Expr::Option { expr: None, .. } + | Expr::Throw { .. } + | Expr::GenerateWorkerName { .. } => {} } } -fn enqueue_expr_bottom_up(expr: &mut Expr, queue: &mut VecDeque<&mut Expr>) { - let mut stack: VecDeque<*mut Expr> = VecDeque::new(); - - stack.push_back(expr); - - while let Some(current) = stack.pop_back() { - queue.push_front(unsafe { &mut *current }); +// --- Core child iteration --- - let current = unsafe { &mut *current }; - - match &mut *current { - Expr::Let { expr, .. } => stack.push_back(&mut **expr), - Expr::SelectField { expr, .. } => stack.push_back(&mut **expr), - Expr::SelectIndex { expr, index, .. } => { - stack.push_back(&mut **expr); - stack.push_back(&mut **index); - } - Expr::Sequence { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::Record { exprs, .. } => { - stack.extend(exprs.iter_mut().map(|(_, expr)| &mut **expr as *mut Expr)) - } - Expr::Tuple { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::Concat { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::ExprBlock { exprs, .. } => stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)), - Expr::Not { expr, .. } => stack.push_back(&mut **expr), - Expr::Length { expr, .. } => stack.push_back(&mut **expr), - Expr::GreaterThan { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::GreaterThanOrEqualTo { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::LessThanOrEqualTo { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::EqualTo { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::Plus { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::Minus { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::Divide { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::Multiply { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::LessThan { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::Cond { cond, lhs, rhs, .. } => { - stack.push_back(&mut **cond); - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs); - } - Expr::PatternMatch { - predicate, - match_arms, - .. - } => { - stack.push_back(&mut **predicate); - for arm in match_arms { - let arm_literal_expressions = arm.arm_pattern.get_expr_literals_mut(); - stack.extend(arm_literal_expressions.into_iter().map(|x| { - let x = x.as_mut(); - x as *mut Expr - })); - stack.push_back(&mut *arm.arm_resolution_expr); - } +fn visit_children_mut(expr: &mut Expr, mut each: impl FnMut(&mut Expr)) { + match expr { + Expr::Let { expr, .. } => each(expr), + Expr::SelectField { expr, .. } => each(expr), + Expr::SelectIndex { expr, index, .. } => { + each(expr); + each(index); + } + Expr::Sequence { exprs, .. } + | Expr::Tuple { exprs, .. } + | Expr::Concat { exprs, .. } + | Expr::ExprBlock { exprs, .. } => { + for e in exprs { + each(e); } - - Expr::Range { range, .. } => { - for expr in range.get_exprs_mut() { - stack.push_back(&mut **expr); - } + } + Expr::Record { exprs, .. } => { + for (_, e) in exprs { + each(e); } - - Expr::Option { - expr: Some(expr), .. - } => stack.push_back(&mut **expr), - Expr::Result { expr: Ok(expr), .. } => stack.push_back(&mut **expr), - Expr::Result { - expr: Err(expr), .. - } => stack.push_back(&mut **expr), - Expr::Call { - call_type, - args, - inferred_type, - .. - } => { - let (exprs, worker) = internal::get_expressions_in_call_type_mut(call_type); - if let Some(exprs) = exprs { - stack.extend(exprs.iter_mut().map(|x| x as *mut Expr)) - } - - if let Some(worker) = worker { - stack.push_back(&mut **worker); - } - - // The expr existing in the inferred type should be visited - if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { - if let Some(worker_expr) = instance_type.worker_mut() { - stack.push_back(&mut **worker_expr); - } + } + Expr::Not { expr, .. } | Expr::Length { expr, .. } | Expr::Unwrap { expr, .. } => { + each(expr) + } + Expr::GreaterThan { lhs, rhs, .. } + | Expr::GreaterThanOrEqualTo { lhs, rhs, .. } + | Expr::LessThanOrEqualTo { lhs, rhs, .. } + | Expr::EqualTo { lhs, rhs, .. } + | Expr::Plus { lhs, rhs, .. } + | Expr::Minus { lhs, rhs, .. } + | Expr::Divide { lhs, rhs, .. } + | Expr::Multiply { lhs, rhs, .. } + | Expr::LessThan { lhs, rhs, .. } + | Expr::And { lhs, rhs, .. } + | Expr::Or { lhs, rhs, .. } => { + each(lhs); + each(rhs); + } + Expr::Cond { cond, lhs, rhs, .. } => { + each(cond); + each(lhs); + each(rhs); + } + Expr::PatternMatch { + predicate, + match_arms, + .. + } => { + each(predicate); + for arm in match_arms { + for lit in arm.arm_pattern.get_expr_literals_mut() { + each(lit.as_mut()); } - - stack.extend(args.iter_mut().map(|x| x as *mut Expr)) + each(&mut arm.arm_resolution_expr); } - Expr::Unwrap { expr, .. } => stack.push_back(&mut **expr), - Expr::And { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs) + } + Expr::Range { range, .. } => { + for e in range.get_exprs_mut() { + each(&mut *e); } - - Expr::Or { lhs, rhs, .. } => { - stack.push_back(&mut **lhs); - stack.push_back(&mut **rhs) + } + Expr::Option { + expr: Some(expr), .. + } => each(expr), + Expr::Result { expr: Ok(expr), .. } => each(expr), + Expr::Result { + expr: Err(expr), .. + } => each(expr), + Expr::Call { + call_type, + args, + inferred_type, + .. + } => { + let (exprs, worker) = internal::get_expressions_in_call_type_mut(call_type); + if let Some(exprs) = exprs { + for e in exprs { + each(e); + } } - - Expr::ListComprehension { - iterable_expr, - yield_expr, - .. - } => { - stack.push_back(&mut **iterable_expr); - stack.push_back(&mut **yield_expr); + if let Some(worker) = worker { + each(worker); } - - Expr::ListReduce { - iterable_expr, - init_value_expr, - yield_expr, - .. - } => { - stack.push_back(&mut **iterable_expr); - stack.push_back(&mut **init_value_expr); - stack.push_back(&mut **yield_expr); + if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { + if let Some(worker_expr) = instance_type.worker_mut() { + each(worker_expr); + } } - - Expr::InvokeMethodLazy { - lhs, - args, - inferred_type, - .. - } => { - if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { - if let Some(worker_expr) = instance_type.worker_mut() { - stack.push_back(&mut **worker_expr); - } + for arg in args { + each(arg); + } + } + Expr::ListComprehension { + iterable_expr, + yield_expr, + .. + } => { + each(iterable_expr); + each(yield_expr); + } + Expr::ListReduce { + iterable_expr, + init_value_expr, + yield_expr, + .. + } => { + each(iterable_expr); + each(init_value_expr); + each(yield_expr); + } + Expr::InvokeMethodLazy { + lhs, + args, + inferred_type, + .. + } => { + if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { + if let Some(worker_expr) = instance_type.worker_mut() { + each(worker_expr); } - - stack.push_back(&mut **lhs); - stack.extend(args.iter_mut().map(|x| x as *mut Expr)); } - - Expr::GetTag { expr, .. } => { - stack.push_back(&mut **expr); + each(lhs); + for arg in args { + each(arg); } - - Expr::Literal { .. } => {} - Expr::Number { .. } => {} - Expr::Flags { .. } => {} - Expr::Identifier { .. } => {} - Expr::Boolean { .. } => {} - Expr::Option { expr: None, .. } => {} - Expr::Throw { .. } => {} - Expr::GenerateWorkerName { .. } => {} } + Expr::GetTag { expr, .. } => each(expr), + Expr::Literal { .. } + | Expr::Number { .. } + | Expr::Flags { .. } + | Expr::Identifier { .. } + | Expr::Boolean { .. } + | Expr::Option { expr: None, .. } + | Expr::Throw { .. } + | Expr::GenerateWorkerName { .. } => {} } } -// This is almost a lazy visit, that we don't put the expr into the queue -// unless it is needed. To a great extent both ExprVisitor and this function -// can be used instead of each other, but depending on situations one can perform better -// over the other. -pub fn visit_expr_nodes_lazy<'a>(expr: &'a mut Expr, queue: &mut VecDeque<&'a mut Expr>) { +fn visit_children_rev_mut(expr: &mut Expr, mut each: impl FnMut(&mut Expr)) { match expr { - Expr::Let { expr, .. } => queue.push_back(&mut *expr), - Expr::SelectField { expr, .. } => queue.push_back(&mut *expr), + Expr::Let { expr, .. } => each(expr), + Expr::SelectField { expr, .. } => each(expr), Expr::SelectIndex { expr, index, .. } => { - queue.push_back(&mut *expr); - queue.push_back(&mut *index); + each(index); + each(expr); } - Expr::Sequence { exprs, .. } => queue.extend(exprs.iter_mut()), - Expr::Record { exprs, .. } => queue.extend(exprs.iter_mut().map(|(_, expr)| &mut **expr)), - Expr::Tuple { exprs, .. } => queue.extend(exprs.iter_mut()), - Expr::Concat { exprs, .. } => queue.extend(exprs.iter_mut()), - Expr::ExprBlock { exprs, .. } => queue.extend(exprs.iter_mut()), // let x = 1, y = call(x); - Expr::Not { expr, .. } => queue.push_back(&mut *expr), - Expr::Length { expr, .. } => queue.push_back(&mut *expr), - Expr::GreaterThan { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::GreaterThanOrEqualTo { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::LessThanOrEqualTo { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::EqualTo { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::Plus { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::Minus { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::Divide { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::Multiply { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); - } - Expr::LessThan { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); + Expr::Sequence { exprs, .. } + | Expr::Tuple { exprs, .. } + | Expr::Concat { exprs, .. } + | Expr::ExprBlock { exprs, .. } => { + for e in exprs.iter_mut().rev() { + each(e); + } + } + Expr::Record { exprs, .. } => { + for (_, e) in exprs.iter_mut().rev() { + each(e); + } + } + Expr::Not { expr, .. } | Expr::Length { expr, .. } | Expr::Unwrap { expr, .. } => { + each(expr) + } + Expr::GreaterThan { lhs, rhs, .. } + | Expr::GreaterThanOrEqualTo { lhs, rhs, .. } + | Expr::LessThanOrEqualTo { lhs, rhs, .. } + | Expr::EqualTo { lhs, rhs, .. } + | Expr::Plus { lhs, rhs, .. } + | Expr::Minus { lhs, rhs, .. } + | Expr::Divide { lhs, rhs, .. } + | Expr::Multiply { lhs, rhs, .. } + | Expr::LessThan { lhs, rhs, .. } + | Expr::And { lhs, rhs, .. } + | Expr::Or { lhs, rhs, .. } => { + each(rhs); + each(lhs); } Expr::Cond { cond, lhs, rhs, .. } => { - queue.push_back(&mut *cond); - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs); + each(rhs); + each(lhs); + each(cond); } Expr::PatternMatch { predicate, match_arms, .. } => { - queue.push_back(&mut *predicate); - for arm in match_arms { - let arm_literal_expressions = arm.arm_pattern.get_expr_literals_mut(); - queue.extend(arm_literal_expressions.into_iter().map(|x| x.as_mut())); - queue.push_back(&mut *arm.arm_resolution_expr); + for arm in match_arms.iter_mut().rev() { + each(&mut arm.arm_resolution_expr); + for lit in arm.arm_pattern.get_expr_literals_mut().into_iter().rev() { + each(lit.as_mut()); + } } + each(predicate); } - Expr::Range { range, .. } => { - for expr in range.get_exprs_mut() { - queue.push_back(&mut *expr); + let mut exprs = range.get_exprs_mut(); + exprs.reverse(); + for e in exprs { + each(&mut *e); } } - Expr::Option { expr: Some(expr), .. - } => queue.push_back(&mut *expr), - Expr::Result { expr: Ok(expr), .. } => queue.push_back(&mut *expr), + } => each(expr), + Expr::Result { expr: Ok(expr), .. } => each(expr), Expr::Result { expr: Err(expr), .. - } => queue.push_back(&mut *expr), + } => each(expr), Expr::Call { call_type, args, inferred_type, .. } => { - let (exprs, worker) = internal::get_expressions_in_call_type_mut(call_type); - if let Some(exprs) = exprs { - queue.extend(exprs.iter_mut()) + for arg in args.iter_mut().rev() { + each(arg); } - - if let Some(worker) = worker { - queue.push_back(worker); - } - - // The expr existing in the inferred type should be visited if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { if let Some(worker_expr) = instance_type.worker_mut() { - queue.push_back(worker_expr); + each(worker_expr); + } + } + let (exprs, worker) = internal::get_expressions_in_call_type_mut(call_type); + if let Some(worker) = worker { + each(worker); + } + if let Some(exprs) = exprs { + for e in exprs.iter_mut().rev() { + each(e); } } - - queue.extend(args.iter_mut()) - } - Expr::Unwrap { expr, .. } => queue.push_back(&mut *expr), // not yet needed - Expr::And { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs) - } - - Expr::Or { lhs, rhs, .. } => { - queue.push_back(&mut *lhs); - queue.push_back(&mut *rhs) } - Expr::ListComprehension { iterable_expr, yield_expr, .. } => { - queue.push_back(&mut *iterable_expr); - queue.push_back(&mut *yield_expr); + each(yield_expr); + each(iterable_expr); } - Expr::ListReduce { iterable_expr, init_value_expr, yield_expr, .. } => { - queue.push_back(iterable_expr); - queue.push_back(init_value_expr); - queue.push_back(yield_expr); + each(yield_expr); + each(init_value_expr); + each(iterable_expr); } - Expr::InvokeMethodLazy { lhs, args, inferred_type, .. } => { + for arg in args.iter_mut().rev() { + each(arg); + } + each(lhs); if let TypeInternal::Instance { instance_type } = inferred_type.inner.as_mut() { if let Some(worker_expr) = instance_type.worker_mut() { - queue.push_back(worker_expr); + each(worker_expr); } } + } + Expr::GetTag { expr, .. } => each(expr), + Expr::Literal { .. } + | Expr::Number { .. } + | Expr::Flags { .. } + | Expr::Identifier { .. } + | Expr::Boolean { .. } + | Expr::Option { expr: None, .. } + | Expr::Throw { .. } + | Expr::GenerateWorkerName { .. } => {} + } +} - queue.push_back(lhs); - queue.extend(args.iter_mut()); +fn try_visit_children_mut( + expr: &mut Expr, + mut each: impl FnMut(&mut Expr) -> Result<(), E>, +) -> Result<(), E> { + // We need to propagate errors properly. Use a Cell to smuggle the + // error out of the infallible closure interface. + let mut err: Option = None; + visit_children_mut(expr, |child| { + if err.is_none() { + if let Err(e) = each(child) { + err = Some(e); + } } + }); + match err { + Some(e) => Err(e), + None => Ok(()), + } +} - Expr::GetTag { expr, .. } => { - queue.push_back(&mut *expr); +fn try_visit_children_rev_mut( + expr: &mut Expr, + mut each: impl FnMut(&mut Expr) -> Result<(), E>, +) -> Result<(), E> { + let mut err: Option = None; + visit_children_rev_mut(expr, |child| { + if err.is_none() { + if let Err(e) = each(child) { + err = Some(e); + } } + }); + match err { + Some(e) => Err(e), + None => Ok(()), + } +} - Expr::Literal { .. } => {} - Expr::Number { .. } => {} - Expr::Flags { .. } => {} - Expr::Identifier { .. } => {} - Expr::Boolean { .. } => {} - Expr::Option { expr: None, .. } => {} - Expr::Throw { .. } => {} - Expr::GenerateWorkerName { .. } => {} +fn visit_children<'a>(expr: &'a Expr, mut each: impl FnMut(&'a Expr)) { + match expr { + Expr::Let { expr, .. } => each(expr), + Expr::SelectField { expr, .. } => each(expr), + Expr::SelectIndex { expr, index, .. } => { + each(expr); + each(index); + } + Expr::Sequence { exprs, .. } + | Expr::Tuple { exprs, .. } + | Expr::Concat { exprs, .. } + | Expr::ExprBlock { exprs, .. } => { + for e in exprs { + each(e); + } + } + Expr::Record { exprs, .. } => { + for (_, e) in exprs { + each(e); + } + } + Expr::Not { expr, .. } | Expr::Length { expr, .. } | Expr::Unwrap { expr, .. } => { + each(expr) + } + Expr::GreaterThan { lhs, rhs, .. } + | Expr::GreaterThanOrEqualTo { lhs, rhs, .. } + | Expr::LessThanOrEqualTo { lhs, rhs, .. } + | Expr::EqualTo { lhs, rhs, .. } + | Expr::Plus { lhs, rhs, .. } + | Expr::Minus { lhs, rhs, .. } + | Expr::Divide { lhs, rhs, .. } + | Expr::Multiply { lhs, rhs, .. } + | Expr::LessThan { lhs, rhs, .. } + | Expr::And { lhs, rhs, .. } + | Expr::Or { lhs, rhs, .. } => { + each(lhs); + each(rhs); + } + Expr::Cond { cond, lhs, rhs, .. } => { + each(cond); + each(lhs); + each(rhs); + } + Expr::PatternMatch { + predicate, + match_arms, + .. + } => { + each(predicate); + for arm in match_arms { + for lit in arm.arm_pattern.get_expr_literals() { + each(lit); + } + each(&arm.arm_resolution_expr); + } + } + Expr::Range { range, .. } => { + for e in range.get_exprs() { + each(e); + } + } + Expr::Option { + expr: Some(expr), .. + } => each(expr), + Expr::Result { expr: Ok(expr), .. } => each(expr), + Expr::Result { + expr: Err(expr), .. + } => each(expr), + Expr::Call { args, .. } => { + for arg in args { + each(arg); + } + } + Expr::ListComprehension { + iterable_expr, + yield_expr, + .. + } => { + each(iterable_expr); + each(yield_expr); + } + Expr::ListReduce { + iterable_expr, + init_value_expr, + yield_expr, + .. + } => { + each(iterable_expr); + each(init_value_expr); + each(yield_expr); + } + Expr::InvokeMethodLazy { lhs, args, .. } => { + each(lhs); + for arg in args { + each(arg); + } + } + Expr::GetTag { expr, .. } => each(expr), + Expr::Literal { .. } + | Expr::Number { .. } + | Expr::Flags { .. } + | Expr::Identifier { .. } + | Expr::Boolean { .. } + | Expr::Option { expr: None, .. } + | Expr::Throw { .. } + | Expr::GenerateWorkerName { .. } => {} } } @@ -620,7 +609,6 @@ mod internal { use crate::call_type::{CallType, InstanceCreationType}; use crate::Expr; - // (args, worker in calls, worker in inferred type) pub(crate) fn get_expressions_in_call_type_mut( call_type: &mut CallType, ) -> (Option<&mut [Expr]>, Option<&mut Box>) { diff --git a/rib-core/src/type_inference/global_input_inference.rs b/rib-core/src/type_inference/global_input_inference.rs index 1f9d6c5..2f3c358 100644 --- a/rib-core/src/type_inference/global_input_inference.rs +++ b/rib-core/src/type_inference/global_input_inference.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{Expr, ExprVisitor, InferredType}; +use crate::{visit_post_order_rev_mut, Expr, InferredType}; use std::collections::HashMap; // request.path.user is used as a string in one place @@ -21,9 +21,7 @@ use std::collections::HashMap; pub fn infer_global_inputs(expr: &mut Expr) { let global_variables_dictionary = collect_all_global_variables_type(expr); // Updating the collected types in all positions of input - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Identifier { variable_id, inferred_type, @@ -37,14 +35,12 @@ pub fn infer_global_inputs(expr: &mut Expr) { } } } - } + }); } fn collect_all_global_variables_type(expr: &mut Expr) -> HashMap> { - let mut visitor = ExprVisitor::bottom_up(expr); - let mut all_types_of_global_variables = HashMap::new(); - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Identifier { variable_id, inferred_type, @@ -66,7 +62,7 @@ fn collect_all_global_variables_type(expr: &mut Expr) -> HashMap = None; - while let Some(expr) = visitor.pop_front() { - match expr { - Expr::Identifier { - variable_id, - inferred_type, - .. - } => { - if variable_id == &type_spec.variable_id { - current_path.progress(); + visit_post_order_mut(expr, &mut |expr| match expr { + Expr::Identifier { + variable_id, + inferred_type, + .. + } => { + if variable_id == &type_spec.variable_id { + current_path.progress(); - if type_spec.path.is_empty() { - *inferred_type = type_spec.inferred_type.clone(); - previous_expr_ptr = None; - current_path = full_path.clone(); - } else { - previous_expr_ptr = Some(expr as *const _); - } - } else { + if type_spec.path.is_empty() { + *inferred_type = type_spec.inferred_type.clone(); previous_expr_ptr = None; current_path = full_path.clone(); + } else { + previous_expr_ptr = Some(expr as *const _); } + } else { + previous_expr_ptr = None; + current_path = full_path.clone(); } + } - Expr::SelectField { - expr: inner_expr, - field, - inferred_type, - .. - } => { - if let Some(prev_ptr) = previous_expr_ptr { - if std::ptr::eq(inner_expr.as_ref(), prev_ptr) { - if current_path.is_empty() { - *inferred_type = type_spec.inferred_type.clone(); - previous_expr_ptr = None; - current_path = full_path.clone(); - } else if current_path.current() - == Some(&PathElem::Field(field.to_string())) - { - current_path.progress(); - previous_expr_ptr = Some(expr as *const _); - } else { - previous_expr_ptr = None; - current_path = full_path.clone(); - } + Expr::SelectField { + expr: inner_expr, + field, + inferred_type, + .. + } => { + if let Some(prev_ptr) = previous_expr_ptr { + if std::ptr::eq(inner_expr.as_ref(), prev_ptr) { + if current_path.is_empty() { + *inferred_type = type_spec.inferred_type.clone(); + previous_expr_ptr = None; + current_path = full_path.clone(); + } else if current_path.current() == Some(&PathElem::Field(field.to_string())) { + current_path.progress(); + previous_expr_ptr = Some(expr as *const _); } else { previous_expr_ptr = None; current_path = full_path.clone(); } + } else { + previous_expr_ptr = None; + current_path = full_path.clone(); } } + } - _ => { - previous_expr_ptr = None; - current_path = full_path.clone(); - } + _ => { + previous_expr_ptr = None; + current_path = full_path.clone(); } - } + }); } #[cfg(test)] diff --git a/rib-core/src/type_inference/identifier_inference.rs b/rib-core/src/type_inference/identifier_inference.rs index ab67f87..8d89ff6 100644 --- a/rib-core/src/type_inference/identifier_inference.rs +++ b/rib-core/src/type_inference/identifier_inference.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{ArmPattern, Expr, ExprVisitor, InferredType, MatchArm, VariableId}; +use crate::{ + visit_post_order_rev_mut, visit_pre_order_mut, ArmPattern, Expr, InferredType, MatchArm, + VariableId, +}; use std::collections::HashMap; pub fn infer_all_identifiers(expr: &mut Expr) { @@ -34,11 +37,10 @@ fn infer_all_identifiers_bottom_up(expr: &mut Expr) { // Expr::Identifier(x) // Expr::Call(func, Expr::Identifier(x)) // Expr::Block(Expr::Let(x, Expr::Num(1)), Expr::Call(func, x)) - let mut visitor = ExprVisitor::bottom_up(expr); // Popping it from the back results in `Expr::Identifier(x)` to be processed first // in the above example. - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { match expr { // If identifier is inferred (probably because it was part of a function call befre), // make sure to update the identifier inference lookup table. @@ -69,52 +71,47 @@ fn infer_all_identifiers_bottom_up(expr: &mut Expr) { _ => {} } - } + }); } // This is more of an optional stage, as bottom-up type propagation would be enough // but helps with reaching early fix point later down the line of compilation phases fn infer_all_identifiers_top_down(expr: &mut Expr) { let mut identifier_lookup = IdentifierTypeState::new(); - let mut visitor = ExprVisitor::top_down(expr); - while let Some(expr) = visitor.pop_front() { - match expr { - Expr::Let { - variable_id, expr, .. - } => { - if let Some(inferred_type) = identifier_lookup.lookup(variable_id) { - expr.add_infer_type_mut(inferred_type); - } - - identifier_lookup.update(variable_id.clone(), expr.inferred_type()); + visit_pre_order_mut(expr, &mut |expr| match expr { + Expr::Let { + variable_id, expr, .. + } => { + if let Some(inferred_type) = identifier_lookup.lookup(variable_id) { + expr.add_infer_type_mut(inferred_type); } - Expr::Identifier { - variable_id, - inferred_type, - .. - } => { - if let Some(new_inferred_type) = identifier_lookup.lookup(variable_id) { - *inferred_type = inferred_type.merge(new_inferred_type) - } - identifier_lookup.update(variable_id.clone(), inferred_type.clone()); + identifier_lookup.update(variable_id.clone(), expr.inferred_type()); + } + Expr::Identifier { + variable_id, + inferred_type, + .. + } => { + if let Some(new_inferred_type) = identifier_lookup.lookup(variable_id) { + *inferred_type = inferred_type.merge(new_inferred_type) } - _ => {} + identifier_lookup.update(variable_id.clone(), inferred_type.clone()); } - } + + _ => {} + }); } fn infer_match_binding_variables(expr: &mut Expr) { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::PatternMatch { match_arms, .. } = expr { for arm in match_arms { process_arm(arm) } } - } + }); } // A state that maps from the identifiers to the types inferred @@ -176,9 +173,7 @@ fn collect_all_identifiers(pattern: &mut ArmPattern, state: &mut IdentifierTypeS } fn accumulate_types_of_identifiers(expr: &mut Expr, state: &mut IdentifierTypeState) { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Identifier { variable_id, inferred_type, @@ -189,27 +184,23 @@ fn accumulate_types_of_identifiers(expr: &mut Expr, state: &mut IdentifierTypeSt state.update(variable_id.clone(), inferred_type.clone()) } } - } + }); } fn update_arm_resolution_expr_with_identifiers( arm_resolution: &mut Expr, state: &IdentifierTypeState, ) { - let mut visitor = ExprVisitor::bottom_up(arm_resolution); - - while let Some(expr) = visitor.pop_back() { - match expr { - Expr::Identifier { - variable_id, - inferred_type, - .. - } if variable_id.is_match_binding() => { - if let Some(new_inferred_type) = state.lookup(variable_id) { - *inferred_type = inferred_type.merge(new_inferred_type) - } + visit_post_order_rev_mut(arm_resolution, &mut |expr| match expr { + Expr::Identifier { + variable_id, + inferred_type, + .. + } if variable_id.is_match_binding() => { + if let Some(new_inferred_type) = state.lookup(variable_id) { + *inferred_type = inferred_type.merge(new_inferred_type) } - _ => {} } - } + _ => {} + }); } diff --git a/rib-core/src/type_inference/identify_instance_creation.rs b/rib-core/src/type_inference/identify_instance_creation.rs index 0f15599..7c60b41 100644 --- a/rib-core/src/type_inference/identify_instance_creation.rs +++ b/rib-core/src/type_inference/identify_instance_creation.rs @@ -17,11 +17,11 @@ use crate::call_type::{CallType, InstanceCreationType}; use crate::instance_type::InstanceType; use crate::rib_type_error::RibTypeErrorInternal; use crate::type_parameter::TypeParameter; -use crate::{ComponentDependencies, CustomInstanceSpec, Expr}; use crate::{ - CustomError, ExprVisitor, FunctionCallError, InferredType, ParsedFunctionReference, - TypeInternal, TypeOrigin, + try_visit_post_order_mut, try_visit_post_order_rev_mut, CustomError, FunctionCallError, + InferredType, ParsedFunctionReference, TypeInternal, TypeOrigin, }; +use crate::{ComponentDependencies, CustomInstanceSpec, Expr}; // Handling the following and making sure the types are inferred fully at this stage. // The expr `Call` will still be expr `Call` itself but CallType will be worker instance creation @@ -42,9 +42,7 @@ pub fn identify_instance_creation( pub fn search_for_invalid_instance_declarations( expr: &mut Expr, ) -> Result<(), RibTypeErrorInternal> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_front() { + try_visit_post_order_mut(expr, &mut |expr| { match expr { Expr::Let { variable_id, expr, .. @@ -74,9 +72,8 @@ pub fn search_for_invalid_instance_declarations( _ => {} } - } - - Ok(()) + Ok(()) + }) } // Identifying instance creations out of all parsed function calls. @@ -87,9 +84,7 @@ pub fn identify_instance_creation_with_worker( component_dependency: &ComponentDependencies, custom_instance_spec: &[CustomInstanceSpec], ) -> Result<(), RibTypeErrorInternal> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + try_visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Call { call_type, generic_type_parameter, @@ -151,9 +146,8 @@ pub fn identify_instance_creation_with_worker( ); } } - } - - Ok(()) + Ok(()) + }) } // Returns a new type parameter in certain cases diff --git a/rib-core/src/type_inference/index_selection_type_binding.rs b/rib-core/src/type_inference/index_selection_type_binding.rs index 7062a36..bcdc80f 100644 --- a/rib-core/src/type_inference/index_selection_type_binding.rs +++ b/rib-core/src/type_inference/index_selection_type_binding.rs @@ -12,34 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{Expr, ExprVisitor, InferredType}; +use crate::{visit_post_order_rev_mut, Expr, InferredType}; use std::ops::DerefMut; // All select indices with literal numbers don't need to explicit // type annotation to get better developer experience, // and all literal numbers will be automatically inferred as u64 pub fn bind_default_types_to_index_expressions(expr: &mut Expr) { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { - match expr { - Expr::SelectIndex { index, .. } => { - if let Expr::Number { inferred_type, .. } = index.deref_mut() { - *inferred_type = InferredType::u64() - } - - if let Expr::Range { range, .. } = index.deref_mut() { - let exprs = range.get_exprs_mut(); - - for expr in exprs { - if let Expr::Number { inferred_type, .. } = expr.deref_mut() { - *inferred_type = InferredType::u64() - } - } - } + visit_post_order_rev_mut(expr, &mut |expr| match expr { + Expr::SelectIndex { index, .. } => { + if let Expr::Number { inferred_type, .. } = index.deref_mut() { + *inferred_type = InferredType::u64() } - Expr::Range { range, .. } => { + if let Expr::Range { range, .. } = index.deref_mut() { let exprs = range.get_exprs_mut(); for expr in exprs { @@ -48,8 +34,18 @@ pub fn bind_default_types_to_index_expressions(expr: &mut Expr) { } } } + } + + Expr::Range { range, .. } => { + let exprs = range.get_exprs_mut(); - _ => {} + for expr in exprs { + if let Expr::Number { inferred_type, .. } = expr.deref_mut() { + *inferred_type = InferredType::u64() + } + } } - } + + _ => {} + }); } diff --git a/rib-core/src/type_inference/inference_fix_point.rs b/rib-core/src/type_inference/inference_fix_point.rs index 5197dd7..6d879c6 100644 --- a/rib-core/src/type_inference/inference_fix_point.rs +++ b/rib-core/src/type_inference/inference_fix_point.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{Expr, ExprVisitor, InferredType, TypeInternal}; +use crate::{visit_post_order, Expr, InferredType, TypeInternal}; // Given `f` executes inference, find expr where `f(expr) = expr` pub fn type_inference_fix_point(mut scan_and_infer: F, expr: &mut Expr) -> Result<(), E> @@ -33,16 +33,24 @@ where } fn compare_expr_types(left: &mut Expr, right: &mut Expr) -> bool { - let mut left_stack = ExprVisitor::bottom_up(left); - let mut right_stack = ExprVisitor::bottom_up(right); - - while let (Some(left), Some(right)) = (left_stack.pop_front(), right_stack.pop_front()) { - if !compare_inferred_types(&left.inferred_type(), &right.inferred_type()) { - return false; - } + let mut left_types = Vec::new(); + let mut right_types = Vec::new(); + + visit_post_order(left, &mut |expr| { + left_types.push(expr.inferred_type()); + }); + visit_post_order(right, &mut |expr| { + right_types.push(expr.inferred_type()); + }); + + if left_types.len() != right_types.len() { + return false; } - left_stack.is_empty() && right_stack.is_empty() + left_types + .iter() + .zip(right_types.iter()) + .all(|(l, r)| compare_inferred_types(l, r)) } fn compare_inferred_types(left: &InferredType, right: &InferredType) -> bool { diff --git a/rib-core/src/type_inference/inferred_expr.rs b/rib-core/src/type_inference/inferred_expr.rs index 0e965fa..51e5b2c 100644 --- a/rib-core/src/type_inference/inferred_expr.rs +++ b/rib-core/src/type_inference/inferred_expr.rs @@ -15,8 +15,8 @@ use crate::call_type::CallType; use crate::rib_type_error::RibTypeErrorInternal; use crate::{ - ComponentDependencies, CustomInstanceSpec, DynamicParsedFunctionName, Expr, ExprVisitor, - FunctionName, GlobalVariableTypeSpec, + visit_post_order_rev_mut, ComponentDependencies, CustomInstanceSpec, DynamicParsedFunctionName, + Expr, FunctionName, GlobalVariableTypeSpec, }; use std::collections::HashSet; @@ -50,9 +50,7 @@ impl InferredExpr { pub fn worker_invoke_calls(&self) -> Vec { let mut expr = self.0.clone(); let mut worker_calls = vec![]; - let mut visitor = ExprVisitor::bottom_up(&mut expr); - - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(&mut expr, &mut |expr| { if let Expr::Call { call_type: CallType::Function { function_name, .. }, .. @@ -60,7 +58,7 @@ impl InferredExpr { { worker_calls.push(function_name.clone()); } - } + }); worker_calls } diff --git a/rib-core/src/type_inference/instance_type_binding.rs b/rib-core/src/type_inference/instance_type_binding.rs index f3441fa..9062ce0 100644 --- a/rib-core/src/type_inference/instance_type_binding.rs +++ b/rib-core/src/type_inference/instance_type_binding.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{Expr, ExprVisitor, InferredType, TypeInternal, TypeOrigin}; +use crate::{visit_pre_order_mut, Expr, InferredType, TypeInternal, TypeOrigin}; use std::collections::HashMap; // This is about binding the `InstanceType` to the corresponding identifiers. @@ -32,37 +32,31 @@ use std::collections::HashMap; // // In this case `foo` in `foo` should have inferred type of `String` and not `InstanceType` pub fn bind_instance_types(expr: &mut Expr) { - let mut queue = ExprVisitor::top_down(expr); - let mut instance_variables = HashMap::new(); - while let Some(expr) = queue.pop_front() { - match expr { - Expr::Let { - variable_id, expr, .. - } => { - if let TypeInternal::Instance { instance_type } = - expr.inferred_type().internal_type() - { - instance_variables.insert(variable_id.clone(), instance_type.clone()); - } + visit_pre_order_mut(expr, &mut |expr| match expr { + Expr::Let { + variable_id, expr, .. + } => { + if let TypeInternal::Instance { instance_type } = expr.inferred_type().internal_type() { + instance_variables.insert(variable_id.clone(), instance_type.clone()); } - Expr::Identifier { - variable_id, - inferred_type, - .. - } => { - if let Some(new_inferred_type) = instance_variables.get(variable_id) { - *inferred_type = InferredType::new( - TypeInternal::Instance { - instance_type: new_inferred_type.clone(), - }, - TypeOrigin::NoOrigin, - ); - } + } + Expr::Identifier { + variable_id, + inferred_type, + .. + } => { + if let Some(new_inferred_type) = instance_variables.get(variable_id) { + *inferred_type = InferredType::new( + TypeInternal::Instance { + instance_type: new_inferred_type.clone(), + }, + TypeOrigin::NoOrigin, + ); } - - _ => {} } - } + + _ => {} + }); } diff --git a/rib-core/src/type_inference/rib_input_type.rs b/rib-core/src/type_inference/rib_input_type.rs index 02fd119..97dd4ef 100644 --- a/rib-core/src/type_inference/rib_input_type.rs +++ b/rib-core/src/type_inference/rib_input_type.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::analysis::AnalysedType; -use crate::{Expr, ExprVisitor, InferredExpr, RibCompilationError}; +use crate::{try_visit_post_order_rev_mut, Expr, InferredExpr, RibCompilationError}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -39,16 +39,15 @@ impl RibInputTypeInfo { inferred_expr: &InferredExpr, ) -> Result { let mut expr = inferred_expr.get_expr().clone(); - let mut queue = ExprVisitor::bottom_up(&mut expr); let mut global_variables = HashMap::new(); - while let Some(expr) = queue.pop_back() { + try_visit_post_order_rev_mut(&mut expr, &mut |expr| { if let Expr::Identifier { variable_id, inferred_type, .. - } = &expr + } = &*expr { if variable_id.is_global() { let analysed_type = AnalysedType::try_from(inferred_type).map_err(|e| { @@ -60,7 +59,8 @@ impl RibInputTypeInfo { global_variables.insert(variable_id.name(), analysed_type); } } - } + Ok::<(), RibCompilationError>(()) + })?; Ok(RibInputTypeInfo { types: global_variables, diff --git a/rib-core/src/type_inference/stateful_instance.rs b/rib-core/src/type_inference/stateful_instance.rs index 17e8010..5b7ef8d 100644 --- a/rib-core/src/type_inference/stateful_instance.rs +++ b/rib-core/src/type_inference/stateful_instance.rs @@ -12,12 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{CallType, Expr, ExprVisitor, InstanceCreationType, TypeInternal}; +use crate::{visit_post_order_rev_mut, CallType, Expr, InstanceCreationType, TypeInternal}; pub fn ensure_stateful_instance(expr: &mut Expr) { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { if let Expr::Call { call_type: CallType::InstanceCreation(InstanceCreationType::WitWorker { worker_name, .. }), @@ -38,5 +36,5 @@ pub fn ensure_stateful_instance(expr: &mut Expr) { } } } - } + }); } diff --git a/rib-core/src/type_inference/type_pull_up.rs b/rib-core/src/type_inference/type_pull_up.rs index 79c7f21..c63fd13 100644 --- a/rib-core/src/type_inference/type_pull_up.rs +++ b/rib-core/src/type_inference/type_pull_up.rs @@ -18,20 +18,18 @@ use crate::type_inference::type_hint::TypeHint; use crate::type_refinement::precise_types::{ListType, RecordType}; use crate::type_refinement::TypeRefinement; use crate::FunctionName; +use crate::{try_visit_post_order_mut, CustomError, Expr}; use crate::{ ActualType, ComponentDependencies, ExpectedType, FullyQualifiedResourceMethod, GetTypeHint, InferredType, InstanceIdentifier, InterfaceName, MatchArm, PackageName, Path, Range, TypeInternal, TypeMismatchError, }; -use crate::{CustomError, Expr, ExprVisitor}; pub fn type_pull_up( expr: &mut Expr, component_dependencies: &ComponentDependencies, ) -> Result<(), RibTypeErrorInternal> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_front() { + try_visit_post_order_mut(expr, &mut |expr| { match expr { Expr::Tuple { exprs, @@ -248,9 +246,8 @@ pub fn type_pull_up( handle_range(range, inferred_type); } } - } - - Ok(()) + Ok(()) + }) } fn handle_list_comprehension( diff --git a/rib-core/src/type_inference/type_push_down.rs b/rib-core/src/type_inference/type_push_down.rs index 870047d..32e0804 100644 --- a/rib-core/src/type_inference/type_push_down.rs +++ b/rib-core/src/type_inference/type_push_down.rs @@ -16,13 +16,11 @@ use crate::rib_type_error::RibTypeErrorInternal; use crate::type_inference::type_push_down::internal::{ handle_list_comprehension, handle_list_reduce, }; -use crate::{Expr, ExprVisitor, InferredType, MatchArm, TypeInternal}; +use crate::{try_visit_post_order_rev_mut, Expr, InferredType, MatchArm, TypeInternal}; use std::ops::Deref; pub fn push_types_down(expr: &mut Expr) -> Result<(), RibTypeErrorInternal> { - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(outer_expr) = visitor.pop_back() { + try_visit_post_order_rev_mut(expr, &mut |outer_expr| { let source_span = outer_expr.source_span(); match outer_expr { @@ -241,9 +239,8 @@ pub fn push_types_down(expr: &mut Expr) -> Result<(), RibTypeErrorInternal> { _ => {} } - } - - Ok(()) + Ok(()) + }) } mod internal { diff --git a/rib-core/src/type_inference/type_reset.rs b/rib-core/src/type_inference/type_reset.rs index d411e34..4e0da2d 100644 --- a/rib-core/src/type_inference/type_reset.rs +++ b/rib-core/src/type_inference/type_reset.rs @@ -12,13 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{Expr, ExprVisitor, InferredType}; +use crate::{visit_post_order_rev_mut, Expr, InferredType}; pub fn reset_type_info(expr: &mut Expr) { - let mut visitor = ExprVisitor::bottom_up(expr); - - // Start from the end - while let Some(expr) = visitor.pop_back() { + visit_post_order_rev_mut(expr, &mut |expr| { expr.with_inferred_type_mut(InferredType::unknown()); - } + }); } diff --git a/rib-core/src/type_inference/type_unification.rs b/rib-core/src/type_inference/type_unification.rs index 0e0fde4..41acb80 100644 --- a/rib-core/src/type_inference/type_unification.rs +++ b/rib-core/src/type_inference/type_unification.rs @@ -14,16 +14,14 @@ use crate::inferred_type::UnificationFailureInternal; use crate::rib_source_span::SourceSpan; -use crate::{Expr, ExprVisitor, InferredType, TypeUnificationError}; +use crate::{try_visit_post_order_mut, Expr, InferredType, TypeUnificationError}; pub fn unify_types(expr: &mut Expr) -> Result<(), TypeUnificationError> { // keeping the original expression to lookup source span let original_expr = expr.clone(); - let mut visitor = ExprVisitor::bottom_up(expr); - - // Pop front to get the innermost expression first that may have caused the type mismatch. - while let Some(sub_expr) = visitor.pop_front() { + // Visit innermost expressions first (post-order) to find the root cause of type mismatch. + try_visit_post_order_mut(expr, &mut |sub_expr| { match sub_expr { Expr::Let { .. } => {} Expr::Boolean { .. } => {} @@ -40,9 +38,8 @@ pub fn unify_types(expr: &mut Expr) -> Result<(), TypeUnificationError> { unify_inferred_type(&original_expr, sub_expr)?; } } - } - - Ok(()) + Ok(()) + }) } fn unify_inferred_type( diff --git a/rib-core/src/type_inference/variable_binding.rs b/rib-core/src/type_inference/variable_binding.rs index e59b799..685dadf 100644 --- a/rib-core/src/type_inference/variable_binding.rs +++ b/rib-core/src/type_inference/variable_binding.rs @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{ArmPattern, Expr, ExprVisitor, MatchArm, MatchIdentifier, VariableId}; +use crate::{ + visit_post_order_mut, visit_pre_order_mut, ArmPattern, Expr, MatchArm, MatchIdentifier, + VariableId, +}; use std::collections::HashMap; // This function will assign ids to variables declared with `let` expressions, // and propagate these ids to the usage sites (`Expr::Identifier` nodes). pub fn bind_variables_of_let_assignment(expr: &mut Expr) { let mut identifier_id_state = IdentifierVariableIdState::new(); - let mut visitor = ExprVisitor::bottom_up(expr); - - // Start from the end - while let Some(expr) = visitor.pop_front() { + visit_post_order_mut(expr, &mut |expr| { match expr { Expr::Let { variable_id, .. } => { let field_name = variable_id.name(); @@ -40,13 +40,11 @@ pub fn bind_variables_of_let_assignment(expr: &mut Expr) { } _ => {} } - } + }); } pub fn bind_variables_of_list_comprehension(expr: &mut Expr) { - let mut visitor = ExprVisitor::top_down(expr); - - while let Some(expr) = visitor.pop_front() { + visit_pre_order_mut(expr, &mut |expr| { if let Expr::ListComprehension { iterated_variable, yield_expr, @@ -58,14 +56,11 @@ pub fn bind_variables_of_list_comprehension(expr: &mut Expr) { process_yield_expr_in_comprehension(iterated_variable, yield_expr) } - } + }); } pub fn bind_variables_of_list_reduce(expr: &mut Expr) { - let mut visitor = ExprVisitor::top_down(expr); - - // Start from the end - while let Some(expr) = visitor.pop_front() { + visit_pre_order_mut(expr, &mut |expr| { if let Expr::ListReduce { reduce_variable, iterated_variable, @@ -82,7 +77,7 @@ pub fn bind_variables_of_list_reduce(expr: &mut Expr) { process_yield_expr_in_reduce(reduce_variable, iterated_variable, yield_expr) } - } + }); } pub fn bind_variables_of_pattern_match(expr: &mut Expr) { @@ -95,11 +90,9 @@ fn bind_variables_in_pattern_match_internal( match_identifiers: &mut [MatchIdentifier], ) -> usize { let mut index = previous_index; - let mut queue = ExprVisitor::top_down(expr); let mut shadowed_let_binding = vec![]; - // Start from the end - while let Some(expr) = queue.pop_front() { + visit_pre_order_mut(expr, &mut |expr| { match expr { Expr::PatternMatch { match_arms, .. } => { for arm in match_arms { @@ -125,7 +118,7 @@ fn bind_variables_in_pattern_match_internal( _ => {} } - } + }); index } @@ -200,9 +193,7 @@ fn update_all_identifier_in_lhs_expr( global_arm_index: usize, ) -> Vec { let mut identifier_names = vec![]; - let mut visitor = ExprVisitor::bottom_up(expr); - - while let Some(expr) = visitor.pop_front() { + visit_post_order_mut(expr, &mut |expr| { if let Expr::Identifier { variable_id, .. } = expr { let match_identifier = MatchIdentifier::new(variable_id.name(), global_arm_index); identifier_names.push(match_identifier); @@ -210,21 +201,19 @@ fn update_all_identifier_in_lhs_expr( VariableId::match_identifier(variable_id.name(), global_arm_index); *variable_id = new_variable_id; } - } + }); identifier_names } fn process_yield_expr_in_comprehension(variable: &mut VariableId, yield_expr: &mut Expr) { - let mut visitor = ExprVisitor::top_down(yield_expr); - - while let Some(expr) = visitor.pop_front() { + visit_pre_order_mut(yield_expr, &mut |expr| { if let Expr::Identifier { variable_id, .. } = expr { if variable.name() == variable_id.name() { *variable_id = variable.clone(); } } - } + }); } fn process_yield_expr_in_reduce( @@ -232,9 +221,7 @@ fn process_yield_expr_in_reduce( iterated_variable_id: &mut VariableId, yield_expr: &mut Expr, ) { - let mut visitor = ExprVisitor::top_down(yield_expr); - - while let Some(expr) = visitor.pop_front() { + visit_pre_order_mut(yield_expr, &mut |expr| { if let Expr::Identifier { variable_id, .. } = expr { if iterated_variable_id.name() == variable_id.name() { *variable_id = iterated_variable_id.clone(); @@ -242,27 +229,27 @@ fn process_yield_expr_in_reduce( *variable_id = reduce_variable.clone() } } - } + }); } struct IdentifierVariableIdState(HashMap); impl IdentifierVariableIdState { - pub(crate) fn new() -> Self { + fn new() -> Self { IdentifierVariableIdState(HashMap::new()) } - pub(crate) fn update_variable_id(&mut self, identifier: &str) { + fn update_variable_id(&mut self, name: &str) { self.0 - .entry(identifier.to_string()) + .entry(name.to_string()) .and_modify(|x| { *x = x.increment_local_variable_id(); }) - .or_insert(VariableId::local(identifier, 0)); + .or_insert_with(|| VariableId::local(name, 0)); } - pub(crate) fn lookup(&self, identifier: &str) -> Option { - self.0.get(identifier).cloned() + fn lookup(&self, name: &str) -> Option<&VariableId> { + self.0.get(name) } } @@ -284,7 +271,6 @@ mod name_binding_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - // Bind x in let with the x in foo expr.bind_variables_of_let_assignment(); let let_binding = Expr::let_binding_with_variable_id( @@ -312,65 +298,6 @@ mod name_binding_tests { assert_eq!(expr, expected); } - #[test] - fn test_name_binding_multiple() { - let rib_expr = r#" - let x = 1; - let y = 2; - foo(x); - foo(y) - "#; - - let mut expr = Expr::from_text(rib_expr).unwrap(); - - // Bind x in let with the x in foo - expr.bind_variables_of_let_assignment(); - - let let_binding1 = Expr::let_binding_with_variable_id( - VariableId::local("x", 0), - Expr::number(BigDecimal::from(1)), - None, - ); - - let let_binding2 = Expr::let_binding_with_variable_id( - VariableId::local("y", 0), - Expr::number(BigDecimal::from(2)), - None, - ); - - let call_expr1 = Expr::call( - CallType::function_call( - DynamicParsedFunctionName { - site: ParsedFunctionSite::Global, - function: DynamicParsedFunctionReference::Function { - function: "foo".to_string(), - }, - }, - None, - ), - None, - vec![Expr::identifier_local("x", 0, None)], - ); - - let call_expr2 = Expr::call( - CallType::function_call( - DynamicParsedFunctionName { - site: ParsedFunctionSite::Global, - function: DynamicParsedFunctionReference::Function { - function: "foo".to_string(), - }, - }, - None, - ), - None, - vec![Expr::identifier_local("y", 0, None)], - ); - - let expected = Expr::expr_block(vec![let_binding1, let_binding2, call_expr1, call_expr2]); - - assert_eq!(expr, expected); - } - #[test] fn test_name_binding_shadowing() { let rib_expr = r#" @@ -382,7 +309,6 @@ mod name_binding_tests { let mut expr = Expr::from_text(rib_expr).unwrap(); - // Bind x in let with the x in foo expr.bind_variables_of_let_assignment(); let let_binding1 = Expr::let_binding_with_variable_id( @@ -432,7 +358,6 @@ mod name_binding_tests { #[test] fn test_simple_pattern_match_name_binding() { - // The first x is global and the second x is a match binding let expr_string = r#" match some(x) { some(x) => x, @@ -447,29 +372,8 @@ mod name_binding_tests { assert_eq!(expr, expectations::expected_match(1)); } - #[test] - fn test_simple_pattern_match_name_binding_with_shadow() { - // The first x is global and the second x is a match binding - let expr_string = r#" - match some(x) { - some(x) => { - let x = 1; - x - }, - none => 0 - } - "#; - - let mut expr = Expr::from_text(expr_string).unwrap(); - - expr.bind_variables_of_pattern_match(); - - assert_eq!(expr, expectations::expected_match_with_let_binding(1)); - } - #[test] fn test_simple_pattern_match_name_binding_block() { - // The first x is global and the second x is a match binding let expr_string = r#" match some(x) { some(x) => x, @@ -487,7 +391,7 @@ mod name_binding_tests { expr.bind_variables_of_pattern_match(); let first_expr = expectations::expected_match(1); - let second_expr = expectations::expected_match(3); // 3 because first block has 2 arms + let second_expr = expectations::expected_match(3); let block = Expr::expr_block(vec![first_expr, second_expr]) .with_inferred_type(InferredType::unknown()); @@ -495,25 +399,6 @@ mod name_binding_tests { assert_eq!(expr, block); } - #[test] - fn test_nested_simple_pattern_match_binding() { - let expr_string = r#" - match ok(some(x)) { - ok(x) => match x { - some(x) => x, - none => 0 - }, - err(x) => 0 - } - "#; - - let mut expr = Expr::from_text(expr_string).unwrap(); - - expr.bind_variables_of_pattern_match(); - - assert_eq!(expr, expectations::expected_nested_match()); - } - mod expectations { use crate::{ArmPattern, Expr, InferredType, MatchArm, MatchIdentifier, VariableId}; use bigdecimal::BigDecimal; @@ -549,119 +434,5 @@ mod name_binding_tests { ], ) } - - pub(crate) fn expected_match_with_let_binding(index: usize) -> Expr { - let let_binding = Expr::let_binding("x", Expr::number(BigDecimal::from(1)), None); - let identifier_expr = - Expr::identifier_with_variable_id(VariableId::Global("x".to_string()), None); - let block = Expr::expr_block(vec![let_binding, identifier_expr]); - - Expr::pattern_match( - Expr::option(Some(Expr::identifier_global("x", None))), - vec![ - MatchArm { - arm_pattern: ArmPattern::constructor( - "some", - vec![ArmPattern::literal(Expr::identifier_with_variable_id( - VariableId::MatchIdentifier(MatchIdentifier::new( - "x".to_string(), - index, - )), - None, - ))], - ), - arm_resolution_expr: Box::new(block), - }, - MatchArm { - arm_pattern: ArmPattern::constructor("none", vec![]), - arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))), - }, - ], - ) - } - - pub(crate) fn expected_nested_match() -> Expr { - Expr::pattern_match( - Expr::ok( - Expr::option(Some(Expr::identifier_with_variable_id( - VariableId::Global("x".to_string()), - None, - ))) - .with_inferred_type(InferredType::option(InferredType::unknown())), - None, - ) - .with_inferred_type(InferredType::result( - Some(InferredType::option(InferredType::unknown())), - Some(InferredType::unknown()), - )), - vec![ - MatchArm { - arm_pattern: ArmPattern::constructor( - "ok", - vec![ArmPattern::literal(Expr::identifier_with_variable_id( - VariableId::MatchIdentifier(MatchIdentifier::new( - "x".to_string(), - 1, - )), - None, - ))], - ), - arm_resolution_expr: Box::new(Expr::pattern_match( - Expr::identifier_with_variable_id( - VariableId::MatchIdentifier(MatchIdentifier::new( - "x".to_string(), - 1, - )), - None, - ), - vec![ - MatchArm { - arm_pattern: ArmPattern::constructor( - "some", - vec![ArmPattern::literal( - Expr::identifier_with_variable_id( - VariableId::MatchIdentifier(MatchIdentifier::new( - "x".to_string(), - 5, - )), - None, - ), - )], - ), - arm_resolution_expr: Box::new( - Expr::identifier_with_variable_id( - VariableId::MatchIdentifier(MatchIdentifier::new( - "x".to_string(), - 5, - )), - None, - ), - ), - }, - MatchArm { - arm_pattern: ArmPattern::constructor("none", vec![]), - arm_resolution_expr: Box::new(Expr::number(BigDecimal::from( - 0, - ))), - }, - ], - )), - }, - MatchArm { - arm_pattern: ArmPattern::constructor( - "err", - vec![ArmPattern::literal(Expr::identifier_with_variable_id( - VariableId::MatchIdentifier(MatchIdentifier::new( - "x".to_string(), - 4, - )), - None, - ))], - ), - arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))), - }, - ], - ) - } } } diff --git a/rib-repl/src/compiler.rs b/rib-repl/src/compiler.rs index 91ba44d..3967e24 100644 --- a/rib-repl/src/compiler.rs +++ b/rib-repl/src/compiler.rs @@ -124,36 +124,32 @@ impl InstanceVariables { pub fn get_identifiers(inferred_expr: &InferredExpr) -> Vec { let mut expr = inferred_expr.get_expr().clone(); - let mut visitor = ExprVisitor::bottom_up(&mut expr); let mut identifiers = Vec::new(); - while let Some(expr) = visitor.pop_back() { - match expr { - Expr::Let { variable_id, .. } => { - if !identifiers.contains(variable_id) { - identifiers.push(variable_id.clone()); - } + visit_post_order_rev_mut(&mut expr, &mut |expr| match expr { + Expr::Let { variable_id, .. } => { + if !identifiers.contains(variable_id) { + identifiers.push(variable_id.clone()); } - Expr::Identifier { variable_id, .. } => { - if !identifiers.contains(variable_id) { - identifiers.push(variable_id.clone()); - } + } + Expr::Identifier { variable_id, .. } => { + if !identifiers.contains(variable_id) { + identifiers.push(variable_id.clone()); } - _ => {} } - } + _ => {} + }); identifiers } pub fn fetch_instance_variables(inferred_expr: &InferredExpr) -> InstanceVariables { let mut expr = inferred_expr.get_expr().clone(); - let mut queue = ExprVisitor::bottom_up(&mut expr); let mut instance_variables = HashMap::new(); - while let Some(expr) = queue.pop_front() { + visit_post_order_mut(&mut expr, &mut |expr| { if let Expr::Let { variable_id, expr, .. } = expr @@ -172,7 +168,7 @@ pub fn fetch_instance_variables(inferred_expr: &InferredExpr) -> InstanceVariabl }; } } - } + }); InstanceVariables { instance_variables } }