diff --git a/Cargo.toml b/Cargo.toml index 30411c19..a5beceb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ default = ["full-opa", "arc"] arc = ["scientific/arc"] ast = [] -azure_policy = ["dep:jsonschema", "arc", "dashmap"] +azure_policy = ["dep:jsonschema", "arc", "dashmap", "regex"] base64 = ["dep:data-encoding"] base64url = ["dep:data-encoding"] coverage = [] diff --git a/src/ast.rs b/src/ast.rs index 4e16f500..7612108d 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -107,159 +107,159 @@ pub type Ref = NodeRef; pub enum Expr { // Simple items that only have a span as content. String { + eidx: u32, span: Span, value: Value, - eidx: u32, }, RawString { + eidx: u32, span: Span, value: Value, - eidx: u32, }, Number { + eidx: u32, span: Span, value: Value, - eidx: u32, }, Bool { + eidx: u32, span: Span, value: Value, - eidx: u32, }, Null { + eidx: u32, span: Span, value: Value, - eidx: u32, }, Var { + eidx: u32, span: Span, value: Value, - eidx: u32, }, // array Array { + eidx: u32, span: Span, items: Vec>, - eidx: u32, }, // set Set { + eidx: u32, span: Span, items: Vec>, - eidx: u32, }, Object { + eidx: u32, span: Span, fields: Vec<(Span, Ref, Ref)>, - eidx: u32, }, // Comprehensions ArrayCompr { + eidx: u32, span: Span, term: Ref, query: Ref, - eidx: u32, }, SetCompr { + eidx: u32, span: Span, term: Ref, query: Ref, - eidx: u32, }, ObjectCompr { + eidx: u32, span: Span, key: Ref, value: Ref, query: Ref, - eidx: u32, }, Call { + eidx: u32, span: Span, fcn: Ref, params: Vec>, - eidx: u32, }, UnaryExpr { + eidx: u32, span: Span, expr: Ref, - eidx: u32, }, // ref RefDot { + eidx: u32, span: Span, refr: Ref, field: (Span, Value), - eidx: u32, }, RefBrack { + eidx: u32, span: Span, refr: Ref, index: Ref, - eidx: u32, }, // Infix expressions BinExpr { + eidx: u32, span: Span, op: BinOp, lhs: Ref, rhs: Ref, - eidx: u32, }, BoolExpr { + eidx: u32, span: Span, op: BoolOp, lhs: Ref, rhs: Ref, - eidx: u32, }, ArithExpr { + eidx: u32, span: Span, op: ArithOp, lhs: Ref, rhs: Ref, - eidx: u32, }, AssignExpr { + eidx: u32, span: Span, op: AssignOp, lhs: Ref, rhs: Ref, - eidx: u32, }, Membership { + eidx: u32, span: Span, key: Option>, value: Ref, collection: Ref, - eidx: u32, }, #[cfg(feature = "rego-extensions")] OrExpr { + eidx: u32, span: Span, lhs: Ref, rhs: Ref, - eidx: u32, }, } @@ -364,19 +364,19 @@ pub struct WithModifier { #[derive(Debug)] #[cfg_attr(feature = "ast", derive(serde::Serialize))] pub struct LiteralStmt { + pub sidx: u32, pub span: Span, pub literal: Literal, #[cfg_attr(feature = "ast", serde(skip_serializing_if = "Vec::is_empty"))] pub with_mods: Vec, - pub sidx: u32, } #[derive(Debug)] #[cfg_attr(feature = "ast", derive(serde::Serialize))] pub struct Query { + pub qidx: u32, pub span: Span, pub stmts: Vec, - pub qidx: u32, } #[derive(Debug)] @@ -420,11 +420,13 @@ pub enum RuleHead { #[cfg_attr(feature = "ast", derive(serde::Serialize))] pub enum Rule { Spec { + ridx: u32, span: Span, head: RuleHead, bodies: Vec, }, Default { + ridx: u32, span: Span, refr: Ref, args: Vec>, @@ -439,6 +441,12 @@ impl Rule { Self::Spec { span, .. } | Self::Default { span, .. } => span, } } + + pub fn ridx(&self) -> u32 { + match self { + Self::Spec { ridx, .. } | Self::Default { ridx, .. } => *ridx, + } + } } #[derive(Debug)] @@ -468,12 +476,27 @@ pub struct Module { // Target name if specified via __target__ rule #[cfg_attr(feature = "ast", serde(skip_serializing_if = "Option::is_none"))] pub target: Option, + // Expression spans indexed by expression index (eidx) + #[cfg_attr(feature = "ast", serde(skip_serializing_if = "Vec::is_empty"))] + pub expression_spans: Vec, + // Statement spans indexed by statement index (sidx) + #[cfg_attr(feature = "ast", serde(skip_serializing_if = "Vec::is_empty"))] + pub statement_spans: Vec, + // Query spans indexed by query index (qidx) + #[cfg_attr(feature = "ast", serde(skip_serializing_if = "Vec::is_empty"))] + pub query_spans: Vec, + // Rule spans indexed by rule index (ridx) + #[cfg_attr(feature = "ast", serde(skip_serializing_if = "Vec::is_empty"))] + pub rule_spans: Vec, // Number of expressions in the module. pub num_expressions: u32, // Number of statements in the module. pub num_statements: u32, // Number of queries in the module. pub num_queries: u32, + // Number of rules in the module. + pub num_rules: u32, } pub type ExprRef = Ref; +pub type RuleRef = Ref; diff --git a/src/compiled_policy.rs b/src/compiled_policy.rs index 213e448c..9f1bc42b 100644 --- a/src/compiled_policy.rs +++ b/src/compiled_policy.rs @@ -146,7 +146,7 @@ impl CompiledPolicy { inferred_types .values() .map(|(resource_type, _schema)| resource_type.clone()) - .collect::>() // Remove duplicates + .collect::>() // Remove duplicates .into_iter() .collect() } else { diff --git a/src/engine.rs b/src/engine.rs index 19c14e45..1bdef2a0 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -912,6 +912,7 @@ impl Engine { if !for_target { // Check if any module specifies a target and warn if so #[cfg(feature = "azure_policy")] + #[cfg(feature = "std")] self.warn_if_targets_present(); } @@ -1316,6 +1317,7 @@ impl Engine { /// Emit a warning if any modules contain target specifications but we're not using target-aware compilation. #[cfg(feature = "azure_policy")] + #[cfg(feature = "std")] fn warn_if_targets_present(&self) { let mut has_target = false; let mut target_files = Vec::new(); diff --git a/src/indexchecker.rs b/src/indexchecker.rs index b6ab59e4..4a480ad3 100644 --- a/src/indexchecker.rs +++ b/src/indexchecker.rs @@ -10,14 +10,16 @@ use anyhow::{bail, Result}; // Ensures that indexes are unique and continuous, starting from 0. #[derive(Default)] -pub struct IndexChecker { +pub struct IndexChecker<'a> { eidx: BTreeSet, sidx: BTreeSet, qidx: BTreeSet, + ridx: BTreeSet, + module: Option<&'a Module>, } #[cfg(debug_assertions)] -impl IndexChecker { +impl<'a> IndexChecker<'a> { fn check_query(&mut self, query: &Query) -> Result<()> { let qidx = query.qidx; if !self.qidx.insert(qidx) { @@ -25,6 +27,7 @@ impl IndexChecker { .span .error(format!("query with qidx {qidx} already exists").as_str())); } + self.check_qidx(query)?; for stmt in &query.stmts { if !self.sidx.insert(stmt.sidx) { @@ -32,6 +35,7 @@ impl IndexChecker { .span .error(format!("statement with sidx {} already exists", stmt.sidx).as_str())); } + self.check_sidx(stmt)?; match &stmt.literal { Literal::Every { domain, query, .. } => { self.check_eidx(domain)?; @@ -73,6 +77,29 @@ impl IndexChecker { .error(format!("expression with eidx {eidx} already exists").as_str())); } + // Check that the span in expression_spans matches the expression's span + if let Some(module) = self.module { + if let Some(stored_span) = module.expression_spans.get(eidx as usize) { + let expr_span = expr.span(); + if stored_span.start != expr_span.start || stored_span.end != expr_span.end { + bail!(expr + .span() + .error(format!( + "expression span position mismatch at eidx {}: stored span positions ({}..{}) != expression span positions ({}..{})", + eidx, + stored_span.start, + stored_span.end, + expr_span.start, + expr_span.end + ).as_str())); + } + } else { + bail!(expr + .span() + .error(format!("missing span in expression_spans for eidx {}", eidx).as_str())); + } + } + match expr { String { .. } | RawString { .. } @@ -167,6 +194,99 @@ impl IndexChecker { Ok(()) } + fn check_sidx(&mut self, stmt: &LiteralStmt) -> Result<()> { + let sidx = stmt.sidx; + + // Check that the span in statement_spans matches the statement's span + if let Some(module) = self.module { + if let Some(stored_span) = module.statement_spans.get(sidx as usize) { + let stmt_span = &stmt.span; + if stored_span.start != stmt_span.start || stored_span.end != stmt_span.end { + bail!(stmt + .span + .error(format!( + "statement span position mismatch at sidx {}: stored span positions ({}..{}) != statement span positions ({}..{})", + sidx, + stored_span.start, + stored_span.end, + stmt_span.start, + stmt_span.end + ).as_str())); + } + } else { + bail!(stmt + .span + .error(format!("missing span in statement_spans for sidx {}", sidx).as_str())); + } + } + + Ok(()) + } + + fn check_qidx(&mut self, query: &Query) -> Result<()> { + let qidx = query.qidx; + + // Check that the span in query_spans matches the query's span + if let Some(module) = self.module { + if let Some(stored_span) = module.query_spans.get(qidx as usize) { + let query_span = &query.span; + if stored_span.start != query_span.start || stored_span.end != query_span.end { + bail!(query + .span + .error(format!( + "query span position mismatch at qidx {}: stored span positions ({}..{}) != query span positions ({}..{})", + qidx, + stored_span.start, + stored_span.end, + query_span.start, + query_span.end + ).as_str())); + } + } else { + bail!(query + .span + .error(format!("missing span in query_spans for qidx {}", qidx).as_str())); + } + } + + Ok(()) + } + + fn check_ridx(&mut self, rule: &Rule) -> Result<()> { + let ridx = rule.ridx(); + + // Check that the span in rule_spans matches the rule's span + if let Some(module) = self.module { + if let Some(stored_span) = module.rule_spans.get(ridx as usize) { + let rule_span = rule.span(); + if stored_span.start != rule_span.start || stored_span.end != rule_span.end { + bail!(rule + .span() + .error(format!( + "rule span position mismatch at ridx {}: stored span positions ({}..{}) != rule span positions ({}..{})", + ridx, + stored_span.start, + stored_span.end, + rule_span.start, + rule_span.end + ).as_str())); + } + } else { + bail!(rule + .span() + .error(format!("missing span in rule_spans for ridx {}", ridx).as_str())); + } + } + + if !self.ridx.insert(ridx) { + bail!(rule + .span() + .error(format!("ridx {} was seen before", ridx).as_str())); + } + + Ok(()) + } + fn check_rule_assign(&mut self, assign: &RuleAssign) -> Result<()> { self.check_eidx(&assign.value) } @@ -242,13 +362,16 @@ impl IndexChecker { Ok(()) } - pub fn check_module(&mut self, module: &Module) -> Result<()> { + pub fn check_module(&mut self, module: &'a Module) -> Result<()> { + self.module = Some(module); + self.check_eidx(module.package.refr.as_ref())?; for import in &module.imports { self.check_eidx(import.refr.as_ref())?; } for rule in &module.policy { + self.check_ridx(rule.as_ref())?; match rule.as_ref() { Rule::Spec { head, bodies, .. } => { self.check_rule_heade(head)?; @@ -275,6 +398,7 @@ impl IndexChecker { self.check_gathered_indexes(module.num_expressions, &self.eidx, "expression")?; self.check_gathered_indexes(module.num_statements, &self.sidx, "statement")?; self.check_gathered_indexes(module.num_queries, &self.qidx, "query")?; + self.check_gathered_indexes(module.num_rules, &self.ridx, "rule")?; Ok(()) } diff --git a/src/interpreter.rs b/src/interpreter.rs index 1c2052e3..d003b96c 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -3423,6 +3423,7 @@ impl Interpreter { fn eval_rule_impl(&mut self, module: &Ref, rule: &Ref) -> Result<()> { match rule.as_ref() { Rule::Spec { + ridx: _, span, head: rule_head, bodies: rule_body, diff --git a/src/parser.rs b/src/parser.rs index c46e3ac8..a9c815f2 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -28,6 +28,16 @@ pub struct Parser<'source> { sidx: u32, // The index of the last query that was parsed. qidx: u32, + // The index of the last rule that was parsed. + ridx: u32, + // Expression spans indexed by expression index + expression_spans: Vec, + // Statement spans indexed by statement index + statement_spans: Vec, + // Query spans indexed by query index + query_spans: Vec, + // Rule spans indexed by rule index + rule_spans: Vec, } const FUTURE_KEYWORDS: [&str; 4] = ["contains", "every", "if", "in"]; @@ -47,27 +57,86 @@ impl<'source> Parser<'source> { eidx: 0, sidx: 0, qidx: 0, + ridx: 0, + expression_spans: Vec::new(), + statement_spans: Vec::new(), + query_spans: Vec::new(), + rule_spans: Vec::new(), }) } - fn next_eidx(&mut self) -> u32 { + fn next_eidx_with_span(&mut self, span: Span) -> u32 { let eidx = self.eidx; self.eidx += 1; + // Store the span for this expression index + self.track_expr_span(eidx, span); eidx } - fn next_sidx(&mut self) -> u32 { + fn track_expr_span(&mut self, eidx: u32, span: Span) { + // Ensure the vector is large enough + while self.expression_spans.len() <= eidx as usize { + self.expression_spans.push(span.clone()); + } + self.expression_spans[eidx as usize] = span; + } + + fn next_sidx_with_span(&mut self, span: Span) -> u32 { let sidx = self.sidx; self.sidx += 1; + // Store the span for this statement index + self.track_stmt_span(sidx, span); sidx } + fn track_stmt_span(&mut self, sidx: u32, span: Span) { + // Ensure the vector is large enough + while self.statement_spans.len() <= sidx as usize { + self.statement_spans.push(span.clone()); + } + self.statement_spans[sidx as usize] = span; + } + fn next_qidx(&mut self) -> u32 { let qidx = self.qidx; self.qidx += 1; qidx } + fn next_qidx_with_span(&mut self, span: Span) -> u32 { + let qidx = self.next_qidx(); + self.track_query_span(qidx, span); + qidx + } + + fn track_query_span(&mut self, qidx: u32, span: Span) { + // Ensure the query_spans vector is large enough + while self.query_spans.len() <= qidx as usize { + self.query_spans.push(span.clone()); + } + self.query_spans[qidx as usize] = span; + } + + fn next_ridx(&mut self) -> u32 { + let ridx = self.ridx; + self.ridx += 1; + ridx + } + + fn next_ridx_with_span(&mut self, span: Span) -> u32 { + let ridx = self.next_ridx(); + self.track_rule_span(ridx, span); + ridx + } + + fn track_rule_span(&mut self, ridx: u32, span: Span) { + // Ensure the rule_spans vector is large enough + while self.rule_spans.len() <= ridx as usize { + self.rule_spans.push(span.clone()); + } + self.rule_spans[ridx as usize] = span; + } + pub fn enable_rego_v1(&mut self) -> Result<()> { self.turn_on_rego_v1(&None) } @@ -293,9 +362,9 @@ impl<'source> Parser<'source> { fn read_number(&mut self, span: Span) -> Result { match Number::from_str(span.text()) { Ok(v) => Ok(Expr::Number { - span, + span: span.clone(), value: Value::Number(v), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }), Err(_) => bail!(span.error("could not parse number")), } @@ -312,42 +381,42 @@ impl<'source> Parser<'source> { Err(e) => bail!(span.error(format!("invalid string literal. {e}").as_str())), }; Expr::String { - span, + span: span.clone(), value: v, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), } } TokenKind::RawString => { let v = Value::from(span.text().to_string()); Expr::RawString { - span, + span: span.clone(), value: v, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), } } TokenKind::Ident => match self.token_text() { "null" => Expr::Null { - span, + span: span.clone(), value: Value::Null, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }, "true" => Expr::Bool { - span, + span: span.clone(), value: Value::from(true), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }, "false" => Expr::Bool { - span, + span: span.clone(), value: Value::from(false), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }, _ => { let ident = self.parse_var()?; let value = Value::from(ident.text()); return Ok(Expr::Var { - span: ident, + span: ident.clone(), value, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(ident), }); } }, @@ -406,10 +475,10 @@ impl<'source> Parser<'source> { Ok((term, query)) => { span.end = self.end; Ok(Expr::ArrayCompr { - span, + span: span.clone(), term: Ref::new(term), query: Ref::new(query), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }) } Err(_) if self.end == pos => { @@ -430,9 +499,9 @@ impl<'source> Parser<'source> { self.expect("]", "while parsing array")?; span.end = self.end; Ok(Expr::Array { - span, + span: span.clone(), items, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }) } Err(err) => Err(err), @@ -448,10 +517,10 @@ impl<'source> Parser<'source> { Ok((term, query)) => { span.end = self.end; return Ok(Expr::SetCompr { - span, + span: span.clone(), term: Ref::new(term), query: Ref::new(query), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }); } Err(err) if self.end != pos => { @@ -468,9 +537,9 @@ impl<'source> Parser<'source> { self.next_token()?; span.end = self.end; return Ok(Expr::Object { - span, + span: span.clone(), fields: vec![], - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }); } @@ -491,9 +560,9 @@ impl<'source> Parser<'source> { self.expect("}", "while parsing set")?; span.end = self.end; return Ok(Expr::Set { - span, + span: span.clone(), items, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }); } @@ -505,11 +574,11 @@ impl<'source> Parser<'source> { Ok((term, query)) => { span.end = self.end; return Ok(Expr::ObjectCompr { - span, + span: span.clone(), key: Ref::new(first), value: Ref::new(term), query: Ref::new(query), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }); } Err(err) if self.end != pos => { @@ -549,9 +618,9 @@ impl<'source> Parser<'source> { span.end = self.end; Ok(Expr::Object { - span, + span: span.clone(), fields: items, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }) } @@ -561,9 +630,9 @@ impl<'source> Parser<'source> { self.expect(")", "while parsing empty set")?; span.end = self.tok.1.end; Ok(Expr::Set { - span, + span: span.clone(), items: vec![], - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }) } @@ -581,9 +650,9 @@ impl<'source> Parser<'source> { let expr = self.parse_in_expr()?; span.end = self.end; Ok(Expr::UnaryExpr { - span, + span: span.clone(), expr: Ref::new(expr), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }) } @@ -654,10 +723,10 @@ impl<'source> Parser<'source> { } let fieldv = Value::from(field.text()); term = Expr::RefDot { - span, + span: span.clone(), refr: Ref::new(term), field: (field, fieldv), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } "[" => { @@ -671,10 +740,10 @@ impl<'source> Parser<'source> { span.end = self.end; term = Expr::RefBrack { - span, + span: span.clone(), refr: Ref::new(term), index: Ref::new(index), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } "(" if possible_fcn => { @@ -694,10 +763,10 @@ impl<'source> Parser<'source> { self.expect(")", "while parsing call expr")?; span.end = self.end; term = Expr::Call { - span, + span: span.clone(), fcn: Ref::new(term), params: args, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; // The expression can no longer be a function after the call. @@ -731,11 +800,11 @@ impl<'source> Parser<'source> { let right = self.parse_term()?; span.end = self.end; expr = Expr::ArithExpr { - span, + span: span.clone(), op, lhs: Ref::new(expr), rhs: Ref::new(right), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } } @@ -767,11 +836,11 @@ impl<'source> Parser<'source> { }; span.end = self.end; expr = Expr::ArithExpr { - span, + span: span.clone(), op, lhs: Ref::new(expr), rhs: Ref::new(right), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } } @@ -787,11 +856,11 @@ impl<'source> Parser<'source> { let right = self.parse_arith_expr()?; span.end = self.end; expr = Expr::BinExpr { - span, + span: span.clone(), op: BinOp::Intersection, lhs: Ref::new(expr), rhs: Ref::new(right), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } Ok(expr) @@ -808,11 +877,11 @@ impl<'source> Parser<'source> { let right = self.parse_set_intersection_expr()?; span.end = self.end; expr = Expr::BinExpr { - span, + span: span.clone(), op: BinOp::Union, lhs: Ref::new(expr), rhs: Ref::new(right), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } Ok(expr) @@ -837,11 +906,11 @@ impl<'source> Parser<'source> { let right = self.parse_set_union_expr()?; span.end = self.end; expr = Expr::BoolExpr { - span, + span: span.clone(), op, lhs: Ref::new(expr), rhs: Ref::new(right), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } Ok(expr) @@ -864,11 +933,11 @@ impl<'source> Parser<'source> { None => (None, Ref::new(expr1)), }; expr1 = Expr::Membership { - span, + span: span.clone(), key, value, collection: Ref::new(expr3), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; expr2 = None; @@ -909,10 +978,10 @@ impl<'source> Parser<'source> { self.next_token()?; let rhs = self.parse_membership_expr()?; expr = Expr::OrExpr { - span, + span: span.clone(), lhs: Ref::new(expr), rhs: Ref::new(rhs), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } Ok(expr) @@ -966,11 +1035,11 @@ impl<'source> Parser<'source> { let right = self.parse_expr()?; span.end = self.end; Ok(Expr::AssignExpr { - span, + span: span.clone(), op, lhs: Ref::new(expr), rhs: Ref::new(right), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }) } @@ -1070,7 +1139,10 @@ impl<'source> Parser<'source> { span.end = self.end; // Since exprs are discarded, adjust the expression index counter. - self.eidx -= vars.len() as u32; + let discarded_count = vars.len() as u32; + self.eidx -= discarded_count; + // Also remove the corresponding spans from expression_spans + self.expression_spans.truncate((self.eidx) as usize); return Ok(Literal::SomeVars { span, vars }); } @@ -1134,10 +1206,10 @@ impl<'source> Parser<'source> { span.end = self.end; Ok(LiteralStmt { - span, + span: span.clone(), literal, with_mods, - sidx: self.next_sidx(), + sidx: self.next_sidx_with_span(span), }) } @@ -1212,10 +1284,11 @@ impl<'source> Parser<'source> { self.expect(end_delim, "while parsing query")?; } span.end = self.end; + let qidx = self.next_qidx_with_span(span.clone()); Ok(Query { span, stmts: literals, - qidx: self.next_qidx(), + qidx, }) } @@ -1254,9 +1327,9 @@ impl<'source> Parser<'source> { let (span, value) = Self::span_and_value(var); let mut refr = Expr::Var { - span, + span: span.clone(), value, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; loop { let mut span = self.tok.1.clone(); @@ -1291,10 +1364,10 @@ impl<'source> Parser<'source> { ); } refr = Expr::RefDot { - span, + span: span.clone(), refr: Ref::new(refr), field: Self::span_and_value(field), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } "[" => { @@ -1303,9 +1376,9 @@ impl<'source> Parser<'source> { TokenKind::String => { let (span, value) = Self::span_and_value(self.tok.1.clone()); Expr::String { - span, + span: span.clone(), value, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), } } _ => { @@ -1320,10 +1393,10 @@ impl<'source> Parser<'source> { self.expect("]", "while parsing bracketed reference")?; span.end = self.end; refr = Expr::RefBrack { - span, + span: span.clone(), refr: Ref::new(refr), index: Ref::new(index), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } _ => break, @@ -1349,9 +1422,9 @@ impl<'source> Parser<'source> { } let (span, value) = Self::span_and_value(v); Expr::Var { - span, + span: span.clone(), value, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), } } else { return Err(self.source.error( @@ -1394,10 +1467,10 @@ impl<'source> Parser<'source> { ); } term = Expr::RefDot { - span, + span: span.clone(), refr: Ref::new(term), field: Self::span_and_value(field), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } "[" => { @@ -1406,10 +1479,10 @@ impl<'source> Parser<'source> { span.end = self.end; self.expect("]", "while parsing bracketed reference")?; term = Expr::RefBrack { - span, + span: span.clone(), refr: Ref::new(term), index: Ref::new(index), - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }; } _ => break, @@ -1473,6 +1546,8 @@ impl<'source> Parser<'source> { { // Adjust the expression counter since we are discarding the RefBrack expression. self.eidx -= 1; + // Also remove the corresponding span from expression_spans + self.expression_spans.truncate((self.eidx) as usize); return Ok(RuleHead::Set { span, refr: refr.clone(), @@ -1525,11 +1600,8 @@ impl<'source> Parser<'source> { *self = state; let stmts = vec![self.parse_literal_stmt()?]; span.end = self.end; - Ok(Query { - span, - stmts, - qidx: self.next_qidx(), - }) + let qidx = self.next_qidx_with_span(span.clone()); + Ok(Query { span, stmts, qidx }) } pub fn parse_rule_bodies(&mut self) -> Result> { @@ -1649,10 +1721,11 @@ impl<'source> Parser<'source> { _ => { let mut query_span = span.clone(); query_span.end = query_span.start; + let qidx = self.next_qidx_with_span(query_span.clone()); let query = Ref::new(Query { span: query_span, stmts: vec![], - qidx: self.next_qidx(), + qidx, }); span.end = self.end; bodies.push(RuleBody { @@ -1707,6 +1780,7 @@ impl<'source> Parser<'source> { let value = Ref::new(self.parse_term()?); span.end = self.end; Ok(Rule::Default { + ridx: self.next_ridx_with_span(span.clone()), span, refr: rule_ref, args: args @@ -1714,9 +1788,9 @@ impl<'source> Parser<'source> { .map(|a| { let (span, value) = Self::span_and_value(a); Ref::new(Expr::Var { - span, + span: span.clone(), value, - eidx: self.next_eidx(), + eidx: self.next_eidx_with_span(span), }) }) .collect(), @@ -1762,7 +1836,12 @@ impl<'source> Parser<'source> { } } - Ok(Rule::Spec { span, head, bodies }) + Ok(Rule::Spec { + ridx: self.next_ridx_with_span(span.clone()), + span, + head, + bodies, + }) } pub fn parse_package(&mut self) -> Result { @@ -1955,6 +2034,11 @@ impl<'source> Parser<'source> { num_expressions: self.eidx, num_statements: self.sidx, num_queries: self.qidx, + num_rules: self.ridx, + expression_spans: self.expression_spans.clone(), + statement_spans: self.statement_spans.clone(), + query_spans: self.query_spans.clone(), + rule_spans: self.rule_spans.clone(), }; #[cfg(debug_assertions)] diff --git a/tests/parser/cases/rules/set.yaml b/tests/parser/cases/rules/set.yaml index a3fed7c5..9c72f4a9 100644 --- a/tests/parser/cases/rules/set.yaml +++ b/tests/parser/cases/rules/set.yaml @@ -176,3 +176,73 @@ cases: bool: true eidx: 20 sidx: 5 + + - note: refbrack-to-set-span-truncation + rego: | + package test + import future.keywords + + # Tests RefBrack to Set rule conversion with proper span truncation. + # When parsing "data["key"]", the parser creates a RefBrack expression, + # then converts it to a Set rule head and discards the RefBrack. + # This test ensures expression_spans are properly truncated when + # expressions are discarded during parsing. + data["key"] + num_expressions: 5 + num_statements: 0 + num_queries: 0 + policy: + - spec: + span: data["key"] + head: + set: + span: data["key"] + refr: + span: data + var: data + eidx: 3 + key: + span: "key" + string: key + eidx: 4 + bodies: [] + + - note: refbrack-to-set-span-truncation-with-body + rego: | + package test + import future.keywords + + # Tests RefBrack to Set rule conversion with body and proper span truncation. + # Similar to the above test but with a rule body { true }. + # Ensures span tracking works correctly for Set rules that have both + # a converted RefBrack head and query body statements. + data["key"] { true } + num_expressions: 6 + num_statements: 1 + num_queries: 1 + policy: + - spec: + span: data["key"] { true } + head: + set: + span: data["key"] + refr: + span: data + var: data + eidx: 3 + key: + span: "key" + string: key + eidx: 4 + bodies: + - query: + span: "{ true }" + qidx: 0 + stmts: + - span: "true" + literal: + expr: + span: "true" + bool: true + eidx: 5 + sidx: 0 diff --git a/tests/parser/mod.rs b/tests/parser/mod.rs index 00a26e63..81a21bca 100644 --- a/tests/parser/mod.rs +++ b/tests/parser/mod.rs @@ -637,7 +637,9 @@ fn match_rule_bodies(span: &Span, bodies: &[RuleBody], v: &Value) -> Result<()> fn match_rule(r: &Rule, v: &Value) -> Result<()> { match r { - Rule::Spec { span, head, bodies } => { + Rule::Spec { + span, head, bodies, .. + } => { let obj = &v["spec"]; match_span_opt(span, &obj["span"])?; match_rule_head(head, &obj["head"])?; @@ -649,6 +651,7 @@ fn match_rule(r: &Rule, v: &Value) -> Result<()> { args, op, value, + .. } => { let obj = &v["default"]; match_span_opt(span, &obj["span"])?; @@ -708,6 +711,335 @@ struct YamlTest { cases: Vec, } +fn verify_expression_spans(module: &Module) -> Result<()> { + // Recursively verify that expressions have the correct spans + fn check_expr_span(expr: &Expr, expression_spans: &[Span]) -> Result<()> { + let eidx = expr.eidx() as usize; + if eidx >= expression_spans.len() { + bail!( + "Expression eidx {} out of bounds for expression_spans length {}", + eidx, + expression_spans.len() + ); + } + + let expected_span = expr.span(); + let actual_span = &expression_spans[eidx]; + + // Check that the spans have the same text content + if expected_span.text() != actual_span.text() { + bail!( + "Expression span mismatch at eidx {}: expected '{}', got '{}'", + eidx, + expected_span.text(), + actual_span.text() + ); + } + + // Check that the spans have the same source positions + if expected_span.start != actual_span.start || expected_span.end != actual_span.end { + bail!( + "Expression span position mismatch at eidx {}: expected {}..{}, got {}..{}", + eidx, + expected_span.start, + expected_span.end, + actual_span.start, + actual_span.end + ); + } + + // Recursively check nested expressions + match expr { + Expr::Array { items, .. } | Expr::Set { items, .. } => { + for item in items { + check_expr_span(item, expression_spans)?; + } + } + Expr::Object { fields, .. } => { + for (_, key, value) in fields { + check_expr_span(key, expression_spans)?; + check_expr_span(value, expression_spans)?; + } + } + Expr::ArrayCompr { term, query, .. } | Expr::SetCompr { term, query, .. } => { + check_expr_span(term, expression_spans)?; + check_query_expr_spans(query, expression_spans)?; + } + Expr::ObjectCompr { + key, value, query, .. + } => { + check_expr_span(key, expression_spans)?; + check_expr_span(value, expression_spans)?; + check_query_expr_spans(query, expression_spans)?; + } + Expr::Call { fcn, params, .. } => { + check_expr_span(fcn, expression_spans)?; + for param in params { + check_expr_span(param, expression_spans)?; + } + } + Expr::RefDot { refr, .. } => { + check_expr_span(refr, expression_spans)?; + } + Expr::RefBrack { refr, index, .. } => { + check_expr_span(refr, expression_spans)?; + check_expr_span(index, expression_spans)?; + } + Expr::UnaryExpr { expr, .. } => { + check_expr_span(expr, expression_spans)?; + } + Expr::BinExpr { lhs, rhs, .. } + | Expr::ArithExpr { lhs, rhs, .. } + | Expr::BoolExpr { lhs, rhs, .. } + | Expr::AssignExpr { lhs, rhs, .. } => { + check_expr_span(lhs, expression_spans)?; + check_expr_span(rhs, expression_spans)?; + } + Expr::Membership { + key, + value, + collection, + .. + } => { + if let Some(key) = key { + check_expr_span(key, expression_spans)?; + } + check_expr_span(value, expression_spans)?; + check_expr_span(collection, expression_spans)?; + } + #[cfg(feature = "rego-extensions")] + Expr::OrExpr { lhs, rhs, .. } => { + check_expr_span(lhs, expression_spans)?; + check_expr_span(rhs, expression_spans)?; + } + // Leaf expressions - no nested expressions to check + Expr::String { .. } + | Expr::RawString { .. } + | Expr::Number { .. } + | Expr::Bool { .. } + | Expr::Null { .. } + | Expr::Var { .. } => {} + } + + Ok(()) + } + + fn check_query_expr_spans(query: &Query, expression_spans: &[Span]) -> Result<()> { + for stmt in &query.stmts { + check_literal_expr_spans(&stmt.literal, expression_spans)?; + } + Ok(()) + } + + fn check_literal_expr_spans(literal: &Literal, expression_spans: &[Span]) -> Result<()> { + match literal { + Literal::Expr { expr, .. } | Literal::NotExpr { expr, .. } => { + check_expr_span(expr, expression_spans)?; + } + Literal::SomeIn { + value, + key, + collection, + .. + } => { + check_expr_span(value, expression_spans)?; + if let Some(key) = key { + check_expr_span(key, expression_spans)?; + } + check_expr_span(collection, expression_spans)?; + } + Literal::Every { domain, query, .. } => { + check_expr_span(domain, expression_spans)?; + check_query_expr_spans(query, expression_spans)?; + } + Literal::SomeVars { .. } => { + // No expressions to check + } + } + Ok(()) + } + + // Check package expression + check_expr_span(&module.package.refr, &module.expression_spans)?; + + // Check import expressions + for import in &module.imports { + check_expr_span(&import.refr, &module.expression_spans)?; + } + + // Check policy expressions + for rule in &module.policy { + match rule.as_ref() { + Rule::Spec { head, bodies, .. } => { + match head { + RuleHead::Compr { refr, assign, .. } => { + check_expr_span(refr, &module.expression_spans)?; + if let Some(assign) = assign { + check_expr_span(&assign.value, &module.expression_spans)?; + } + } + RuleHead::Set { refr, key, .. } => { + check_expr_span(refr, &module.expression_spans)?; + if let Some(key) = key { + check_expr_span(key, &module.expression_spans)?; + } + } + RuleHead::Func { + refr, args, assign, .. + } => { + check_expr_span(refr, &module.expression_spans)?; + for arg in args { + check_expr_span(arg, &module.expression_spans)?; + } + if let Some(assign) = assign { + check_expr_span(&assign.value, &module.expression_spans)?; + } + } + } + + for body in bodies { + if let Some(assign) = &body.assign { + check_expr_span(&assign.value, &module.expression_spans)?; + } + check_query_expr_spans(&body.query, &module.expression_spans)?; + } + } + Rule::Default { + refr, args, value, .. + } => { + check_expr_span(refr, &module.expression_spans)?; + for arg in args { + check_expr_span(arg, &module.expression_spans)?; + } + check_expr_span(value, &module.expression_spans)?; + } + } + } + + Ok(()) +} + +fn verify_statement_spans(module: &Module) -> Result<()> { + // Recursively verify that statements have the correct spans + fn check_stmt_span(stmt: &LiteralStmt, statement_spans: &[Span]) -> Result<()> { + let sidx = stmt.sidx as usize; + if sidx >= statement_spans.len() { + bail!( + "Statement sidx {} out of bounds for statement_spans length {}", + sidx, + statement_spans.len() + ); + } + + let expected_span = &stmt.span; + let actual_span = &statement_spans[sidx]; + + // Check that the spans have the same text content + if expected_span.text() != actual_span.text() { + bail!( + "Statement span mismatch at sidx {}: expected '{}', got '{}'", + sidx, + expected_span.text(), + actual_span.text() + ); + } + + // Check that the spans have the same source positions + if expected_span.start != actual_span.start || expected_span.end != actual_span.end { + bail!( + "Statement span position mismatch at sidx {}: expected {}..{}, got {}..{}", + sidx, + expected_span.start, + expected_span.end, + actual_span.start, + actual_span.end + ); + } + + Ok(()) + } + + fn check_query_stmt_spans(query: &Query, statement_spans: &[Span]) -> Result<()> { + for stmt in &query.stmts { + check_stmt_span(stmt, statement_spans)?; + } + Ok(()) + } + + // Check statements in policy rules + for rule in &module.policy { + match rule.as_ref() { + Rule::Spec { bodies, .. } => { + for body in bodies { + check_query_stmt_spans(&body.query, &module.statement_spans)?; + } + } + Rule::Default { .. } => { + // Default rules don't have statement bodies + } + } + } + + Ok(()) +} + +fn verify_query_spans(module: &Module) -> Result<()> { + // Recursively verify that queries have the correct spans + fn check_query_span(query: &Query, query_spans: &[Span]) -> Result<()> { + let qidx = query.qidx as usize; + if qidx >= query_spans.len() { + bail!( + "Query qidx {} out of bounds for query_spans length {}", + qidx, + query_spans.len() + ); + } + + let expected_span = &query.span; + let actual_span = &query_spans[qidx]; + + // Check that the spans have the same text content + if expected_span.text() != actual_span.text() { + bail!( + "Query span mismatch at qidx {}: expected '{}', got '{}'", + qidx, + expected_span.text(), + actual_span.text() + ); + } + + // Check that the spans have the same source positions + if expected_span.start != actual_span.start || expected_span.end != actual_span.end { + bail!( + "Query span position mismatch at qidx {}: expected {}..{}, got {}..{}", + qidx, + expected_span.start, + expected_span.end, + actual_span.start, + actual_span.end + ); + } + + Ok(()) + } + + // Check queries in policy rules + for rule in &module.policy { + match rule.as_ref() { + Rule::Spec { bodies, .. } => { + for body in bodies { + check_query_span(body.query.as_ref(), &module.query_spans)?; + } + } + Rule::Default { .. } => { + // Default rules don't have query bodies + } + } + } + + Ok(()) +} + fn yaml_test_impl(file: &str) -> Result<()> { println!("\nrunning {file}"); @@ -730,6 +1062,13 @@ fn yaml_test_impl(file: &str) -> Result<()> { "mismatch in num_expressions" ); + // Verify that expression_spans count matches num_expressions + my_assert_eq!( + module.expression_spans.len() as u32, + module.num_expressions, + "expression_spans count should match num_expressions" + ); + my_assert_eq!( module.num_statements, case.num_statements, @@ -771,6 +1110,25 @@ fn yaml_test_impl(file: &str) -> Result<()> { match_rule(&module.policy[idx], policy)?; } } + + // Also verify that statement spans count matches num_statements + if module.statement_spans.len() != module.num_statements as usize { + bail!("statement_spans count should match num_statements"); + } + + // Also verify that query spans count matches num_queries + if module.query_spans.len() != module.num_queries as usize { + bail!("query_spans count should match num_queries"); + } + + // Test expression spans for specific expressions in the AST + verify_expression_spans(&module)?; + + // Test statement spans for specific statements in the AST + verify_statement_spans(&module)?; + + // Test query spans for specific queries in the AST + verify_query_spans(&module)?; } Err(actual) => match &case.error { Some(expected) => {