Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions src/models/capture_group_patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ Copyright (c) 2023 Uber Technologies, Inc.
use crate::{
models::Validator,
utilities::{
tree_sitter_utilities::{get_ts_query_parser, number_of_errors},
tree_sitter_utilities::{get_all_matches_for_query, get_ts_query_parser, number_of_errors},
Instantiate,
},
};
use pyo3::prelude::pyclass;
use regex::Regex;
use serde_derive::Deserialize;
use std::collections::HashMap;
use tree_sitter::{Node, Query};

use super::matches::Match;

#[pyclass]
#[derive(Deserialize, Debug, Clone, Default, PartialEq, Hash, Eq)]
Expand All @@ -38,12 +42,18 @@ impl CGPattern {

impl Validator for CGPattern {
fn validate(&self) -> Result<(), String> {
if self.pattern().starts_with("rgx ") {
panic!("Regex not supported")
}
let mut parser = get_ts_query_parser();
parser
.parse(self.pattern(), None)
.filter(|x| number_of_errors(&x.root_node()) == 0)
.map(|_| Ok(()))
.unwrap_or(Err(format!("Cannot parse - {}", self.pattern())))
.unwrap_or(Err(format!(
"Cannot parse the tree-sitter query - {}",
self.pattern()
)))
}
}

Expand All @@ -56,3 +66,44 @@ impl Instantiate for CGPattern {
CGPattern::new(self.pattern().instantiate(&substitutions))
}
}

#[derive(Debug)]
pub(crate) enum CompiledCGPattern {
Q(Query),
R(Regex), // Regex is not yet supported
}

impl CompiledCGPattern {
/// Applies the CGPattern (self) upon the input `node`, and returns the first match
/// # Arguments
/// * `node` - the root node to apply the query upon
/// * `source_code` - the corresponding source code string for the node.
/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`.
pub(crate) fn get_match(&self, node: &Node, source_code: &str, recursive: bool) -> Option<Match> {
if let Some(m) = self
.get_matches(node, source_code.to_string(), recursive, None, None)
.first()
{
return Some(m.clone());
}
None
}

/// Applies the pattern upon the given `node`, and gets all the matches
pub(crate) fn get_matches(
&self, node: &Node, source_code: String, recursive: bool, replace_node: Option<String>,
replace_node_idx: Option<u8>,
) -> Vec<Match> {
match self {
CompiledCGPattern::Q(query) => get_all_matches_for_query(
node,
source_code,
query,
recursive,
replace_node,
replace_node_idx,
),
CompiledCGPattern::R(_) => panic!("Regex is not yet supported!!!"),
}
}
}
21 changes: 5 additions & 16 deletions src/models/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ use pyo3::prelude::{pyclass, pymethods};
use serde_derive::Deserialize;
use tree_sitter::Node;

use crate::utilities::{
gen_py_str_methods,
tree_sitter_utilities::{get_all_matches_for_query, get_match_for_query, get_node_for_range},
};
use crate::utilities::{gen_py_str_methods, tree_sitter_utilities::get_node_for_range};

use super::{
capture_group_patterns::CGPattern, default_configs::default_child_count,
Expand Down Expand Up @@ -415,9 +412,8 @@ impl SourceCodeUnit {
}

while let Some(parent) = current_node.parent() {
if let Some(p_match) =
get_match_for_query(&parent, self.code(), rule_store.query(ts_query), false)
{
let pattern = rule_store.query(ts_query);
if let Some(p_match) = pattern.get_match(&parent, self.code(), false) {
let matched_ancestor = get_node_for_range(
self.root_node(),
p_match.range().start_byte,
Expand All @@ -442,14 +438,7 @@ impl SourceCodeUnit {

// Retrieve all matches within the ancestor node
let contains_query = &rule_store.query(filter.contains());
let matches = get_all_matches_for_query(
ancestor,
self.code().to_string(),
contains_query,
true,
None,
None,
);
let matches = contains_query.get_matches(ancestor, self.code().to_string(), true, None, None);
let at_least = filter.at_least as usize;
let at_most = filter.at_most as usize;
// Validate if the count of matches falls within the expected range
Expand All @@ -464,7 +453,7 @@ impl SourceCodeUnit {
// Check if there's a match within the scope node
// If one of the filters is not satisfied, return false
let query = &rule_store.query(ts_query);
if get_match_for_query(ancestor, self.code(), query, true).is_some() {
if query.get_match(ancestor, self.code(), true).is_some() {
return false;
}
}
Expand Down
10 changes: 4 additions & 6 deletions src/models/matches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ use pyo3::prelude::{pyclass, pymethods};
use serde_derive::{Deserialize, Serialize};
use tree_sitter::Node;

use crate::utilities::{
gen_py_str_methods,
tree_sitter_utilities::{get_all_matches_for_query, get_node_for_range},
};
use crate::utilities::{gen_py_str_methods, tree_sitter_utilities::get_node_for_range};

use super::{
piranha_arguments::PiranhaArguments, rule::InstantiatedRule, rule_store::RuleStore,
Expand Down Expand Up @@ -291,10 +288,11 @@ impl SourceCodeUnit {
} else {
(rule.replace_node(), rule.replace_idx())
};
let mut all_query_matches = get_all_matches_for_query(

let pattern = rule_store.query(&rule.query());
let mut all_query_matches = pattern.get_matches(
&node,
self.code().to_string(),
rule_store.query(&rule.query()),
recursive,
replace_node_tag,
replace_node_idx,
Expand Down
20 changes: 13 additions & 7 deletions src/models/rule_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@ use itertools::Itertools;
use jwalk::WalkDir;
use log::{debug, trace};
use regex::Regex;
use tree_sitter::Query;

use crate::{
models::capture_group_patterns::CGPattern, models::piranha_arguments::PiranhaArguments,
models::scopes::ScopeQueryGenerator, utilities::read_file,
};

use super::{language::PiranhaLanguage, rule::InstantiatedRule};
use super::{
capture_group_patterns::CompiledCGPattern, language::PiranhaLanguage, rule::InstantiatedRule,
};
use glob::Pattern;

/// This maintains the state for Piranha.
#[derive(Debug, Getters, Default)]
pub(crate) struct RuleStore {
// Caches the compiled tree-sitter queries.
rule_query_cache: HashMap<String, Query>,
rule_query_cache: HashMap<String, CompiledCGPattern>,
// Current global rules to be applied.
#[get = "pub"]
global_rules: Vec<InstantiatedRule>,
Expand Down Expand Up @@ -75,11 +76,16 @@ impl RuleStore {

/// Get the compiled query for the `query_str` from the cache
/// else compile it, add it to the cache and return it.
pub(crate) fn query(&mut self, query_str: &CGPattern) -> &Query {
self
pub(crate) fn query(&mut self, cg_pattern: &CGPattern) -> &CompiledCGPattern {
let pattern = cg_pattern.pattern();
if pattern.starts_with("rgx ") {
panic!("Regex not supported.")
}

&*self
.rule_query_cache
.entry(query_str.pattern())
.or_insert_with(|| self.language.create_query(query_str.pattern()))
.entry(pattern.to_string())
.or_insert_with(|| CompiledCGPattern::Q(self.language.create_query(pattern)))
}

// For the given scope level, get the ScopeQueryGenerator from the `scope_config.toml` file
Expand Down
9 changes: 2 additions & 7 deletions src/models/scopes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Copyright (c) 2023 Uber Technologies, Inc.

use super::capture_group_patterns::CGPattern;
use super::{rule_store::RuleStore, source_code_unit::SourceCodeUnit};
use crate::utilities::tree_sitter_utilities::get_match_for_query;
use crate::utilities::tree_sitter_utilities::get_node_for_range;
use crate::utilities::Instantiate;
use derive_builder::Builder;
Expand Down Expand Up @@ -65,12 +64,8 @@ impl SourceCodeUnit {
changed_node.kind()
);
for m in &scope_enclosing_nodes {
if let Some(p_match) = get_match_for_query(
&changed_node,
self.code(),
rules_store.query(m.enclosing_node()),
false,
) {
let pattern = rules_store.query(m.enclosing_node());
if let Some(p_match) = pattern.get_match(&changed_node, self.code(), false) {
// Generate the scope query for the specific context by substituting the
// the tags with code snippets appropriately in the `generator` query.
return m.scope().instantiate(p_match.matches());
Expand Down
12 changes: 3 additions & 9 deletions src/models/source_code_unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ use crate::{
models::capture_group_patterns::CGPattern,
models::rule_graph::{GLOBAL, PARENT},
utilities::tree_sitter_utilities::{
get_match_for_query, get_node_for_range, get_replace_range, get_tree_sitter_edit,
number_of_errors,
get_node_for_range, get_replace_range, get_tree_sitter_edit, number_of_errors,
},
};

Expand Down Expand Up @@ -299,13 +298,8 @@ impl SourceCodeUnit {
// let mut scope_node = self.root_node();
if let Some(query_str) = scope_query {
// Apply the scope query in the source code and get the appropriate node
let tree_sitter_scope_query = rules_store.query(query_str);
if let Some(p_match) = get_match_for_query(
&self.root_node(),
self.code(),
tree_sitter_scope_query,
true,
) {
let scope_pattern = rules_store.query(query_str);
if let Some(p_match) = scope_pattern.get_match(&self.root_node(), self.code(), true) {
return get_node_for_range(
self.root_node(),
p_match.range().start_byte,
Expand Down
10 changes: 10 additions & 0 deletions src/models/unit_tests/rule_graph_validation_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,13 @@ fn test_filter_bad_arg_contains_n_sibling() {
.sibling_count(2)
.build();
}

#[test]
#[should_panic(expected = "Regex not supported")]
fn test_unsupported_regex() {
RuleGraphBuilder::default()
.rules(vec![
piranha_rule! {name = "Test rule", query = "rgx (\\w+) (\\w)+"},
])
.build();
}
22 changes: 0 additions & 22 deletions src/utilities/tree_sitter_utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,6 @@ use std::collections::HashMap;
use tree_sitter::{InputEdit, Node, Parser, Point, Query, QueryCapture, QueryCursor, Range};
use tree_sitter_traversal::{traverse, Order};

/// Applies the query upon the given node, and gets all the matches
/// # Arguments
/// * `node` - the root node to apply the query upon
/// * `source_code` - the corresponding source code string for the node.
/// * `query` - the query to be applied
/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`.
///
/// # Returns
/// A vector of `tuples` containing the range of the matches in the source code and the corresponding mapping for the tags (to code snippets).
/// By default it returns the range of the outermost node for each query match.
/// If `replace_node` is provided in the rule, it returns the range of the node corresponding to that tag.
pub(crate) fn get_match_for_query(
node: &Node, source_code: &str, query: &Query, recursive: bool,
) -> Option<Match> {
if let Some(m) =
get_all_matches_for_query(node, source_code.to_string(), query, recursive, None, None).first()
{
return Some(m.clone());
}
None
}

/// Applies the query upon the given `node`, and gets the first match
/// # Arguments
/// * `node` - the root node to apply the query upon
Expand Down