Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 51 additions & 58 deletions shared/yeast/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,44 +34,48 @@ pub const CHILD_FIELD: u16 = u16::MAX;
#[derive(Debug)]
pub struct AstCursor<'a> {
ast: &'a Ast,
/// A stack of parents, along with iterators for their children
parents: Vec<(&'a Node, ChildrenIter<'a>)>,
node: &'a Node,
/// A stack of parents, along with iterators for their children.
parents: Vec<(Id, ChildrenIter<'a>)>,
node_id: Id,
}

impl<'a> AstCursor<'a> {
pub fn new(ast: &'a Ast) -> Self {
// TODO: handle non-zero root
let node = ast.get_node(ast.root).unwrap();
Self {
ast,
parents: vec![],
node,
node_id: ast.root,
}
}

/// The Id of the node currently under the cursor.
pub fn node_id(&self) -> Id {
self.node_id
}

fn goto_next_sibling_opt(&mut self) -> Option<()> {
self.node = self.parents.last_mut()?.1.next()?;
self.node_id = self.parents.last_mut()?.1.next()?;
Some(())
}

fn goto_first_child_opt(&mut self) -> Option<()> {
let parent = self.node;
let mut children = ChildrenIter::new(self.ast, parent);
let parent_id = self.node_id;
let parent = self.ast.get_node(parent_id)?;
let mut children = ChildrenIter::new(parent);
let first_child = children.next()?;
self.node = first_child;
self.parents.push((parent, children));
self.node_id = first_child;
self.parents.push((parent_id, children));
Some(())
}

fn goto_parent_opt(&mut self) -> Option<()> {
self.node = self.parents.pop()?.0;
self.node_id = self.parents.pop()?.0;
Some(())
}
}
impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> {
fn node(&self) -> &'a Node {
self.node
&self.ast.nodes[self.node_id]
}

fn field_id(&self) -> Option<FieldId> {
Expand Down Expand Up @@ -101,36 +105,30 @@ impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> {
}
}

/// An iterator over all the child nodes of a node.
/// An iterator over the child Ids of a node.
#[derive(Debug)]
struct ChildrenIter<'a> {
ast: &'a Ast,
current_field: Option<FieldId>,
fields: std::collections::btree_map::Iter<'a, FieldId, Vec<Id>>,
field_children: Option<std::slice::Iter<'a, Id>>,
}

impl<'a> ChildrenIter<'a> {
fn new(ast: &'a Ast, node: &'a Node) -> Self {
fn new(node: &'a Node) -> Self {
Self {
ast,
current_field: None,
fields: node.fields.iter(),
field_children: None,
}
}

fn get_node(&self, id: Id) -> &'a Node {
self.ast.get_node(id).unwrap()
}

fn current_field(&self) -> Option<FieldId> {
self.current_field
}
}

impl<'a> Iterator for ChildrenIter<'a> {
type Item = &'a Node;
impl Iterator for ChildrenIter<'_> {
type Item = Id;

fn next(&mut self) -> Option<Self::Item> {
match self.field_children.as_mut() {
Expand All @@ -151,7 +149,7 @@ impl<'a> Iterator for ChildrenIter<'a> {
self.next()
}
},
Some(child_id) => Some(self.get_node(*child_id)),
Some(child_id) => Some(*child_id),
},
}
}
Expand Down Expand Up @@ -236,7 +234,6 @@ impl Ast {
) -> Id {
let id = self.nodes.len();
self.nodes.push(Node {
id,
kind,
kind_name: self.schema.node_kind_for_id(kind).unwrap(),
fields,
Expand Down Expand Up @@ -265,7 +262,6 @@ impl Ast {
});
let id = self.nodes.len();
self.nodes.push(Node {
id,
kind: kind_id,
kind_name: kind,
is_named: true,
Expand Down Expand Up @@ -345,7 +341,6 @@ impl Ast {
/// A node in our AST
#[derive(PartialEq, Eq, Debug, Clone, Serialize)]
pub struct Node {
id: Id,
kind: KindId,
kind_name: &'static str,
pub(crate) fields: BTreeMap<FieldId, Vec<Id>>,
Expand All @@ -361,10 +356,6 @@ pub struct Node {
}

impl Node {
pub fn id(&self) -> Id {
self.id
}

pub fn kind(&self) -> &'static str {
self.kind_name
}
Expand Down Expand Up @@ -600,39 +591,41 @@ fn apply_rules_inner(
}
}

// Collect fields before recursing (avoids borrowing ast immutably during mutation)
let field_entries: Vec<(FieldId, Vec<Id>)> = ast.nodes[id]
.fields
.iter()
.map(|(&fid, children)| (fid, children.clone()))
.collect();

// recursively descend into all the fields
// Take the parent's fields by ownership: the recursion will rewrite
// each child Id, and we'll write the (possibly mutated) field map back
// when we're done. Avoids cloning the whole BTreeMap and its child
// Vecs on entry. Each child Vec is only re-allocated if a rewrite
// actually changes its contents.
//
// Child traversal does not increment rewrite depth and starts fresh
// (no rule is skipped on child subtrees).
let mut changed = false;
let mut new_fields = BTreeMap::new();
for (field_id, children) in field_entries {
let mut new_children = Vec::new();
for child_id in children {
let mut fields = std::mem::take(&mut ast.nodes[id].fields);
for children in fields.values_mut() {
let mut new_children: Option<Vec<Id>> = None;
for (i, &child_id) in children.iter().enumerate() {
let result = apply_rules_inner(index, ast, child_id, fresh, rewrite_depth, None)?;
if result.len() != 1 || result[0] != child_id {
changed = true;
let unchanged = result.len() == 1 && result[0] == child_id;
match (&mut new_children, unchanged) {
(None, true) => {} // unchanged so far, no allocation needed
(None, false) => {
// First divergence — copy already-processed Ids and
// start collecting the rewritten sequence.
let mut new = Vec::with_capacity(children.len());
new.extend_from_slice(&children[..i]);
new.extend(result);
new_children = Some(new);
}
(Some(new), _) => {
new.extend(result);
}
}
new_children.extend(result);
}
new_fields.insert(field_id, new_children);
}

if !changed {
return Ok(vec![id]);
if let Some(new) = new_children {
*children = new;
}
}

let mut node = ast.nodes[id].clone();
node.fields = new_fields;
node.id = ast.nodes.len();
ast.nodes.push(node);
Ok(vec![ast.nodes.len() - 1])
ast.nodes[id].fields = fields;
Ok(vec![id])
}

/// One phase of a desugaring pass: a named bundle of rules that runs to
Expand Down
10 changes: 4 additions & 6 deletions shared/yeast/src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl Visitor {

pub fn build_with_schema(self, schema: crate::schema::Schema) -> Ast {
Ast {
root: self.nodes[0].inner.id,
root: 0,
schema,
nodes: self.nodes.into_iter().map(|n| n.inner).collect(),
}
Expand All @@ -59,7 +59,6 @@ impl Visitor {
let id = self.nodes.len();
self.nodes.push(VisitorNode {
inner: Node {
id,
kind: self.language.id_for_node_kind(n.kind(), is_named),
kind_name: n.kind(),
content,
Expand All @@ -82,11 +81,10 @@ impl Visitor {
}

fn leave_node(&mut self, field_name: Option<&'static str>, _node: tree_sitter::Node<'_>) {
let node = self.current.map(|i| &self.nodes[i]).unwrap();
let node_id = node.inner.id;
let node_parent = node.parent;
let node_id = self.current.unwrap();
let node_parent = self.nodes[node_id].parent;

if let Some(parent_id) = node.parent {
if let Some(parent_id) = node_parent {
let parent = self.nodes.get_mut(parent_id).unwrap();
if let Some(field) = field_name {
let field_id = self.language.field_id_for_name(field).unwrap().get();
Expand Down
14 changes: 7 additions & 7 deletions shared/yeast/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ fn test_query_repeated_capture() {
// Match against the assignment node (first named child of program)
let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child();
let assignment_id = cursor.node().id();
let assignment_id = cursor.node_id();

let mut captures = yeast::captures::Captures::new();
let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap();
Expand All @@ -206,7 +206,7 @@ fn test_capture_unnamed_node_parenthesized() {

let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child();
let assignment_id = cursor.node().id();
let assignment_id = cursor.node_id();

let mut captures = yeast::captures::Captures::new();
let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap();
Expand All @@ -233,7 +233,7 @@ fn test_capture_unnamed_node_bare_literal() {

let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child();
let assignment_id = cursor.node().id();
let assignment_id = cursor.node_id();

let mut captures = yeast::captures::Captures::new();
let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap();
Expand All @@ -254,7 +254,7 @@ fn test_bare_underscore_matches_unnamed() {

let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child();
let assignment_id = cursor.node().id();
let assignment_id = cursor.node_id();

// `(_)` skips unnamed children, so a query containing a single `(_)`
// bare pattern fails to match the assignment (whose only unfielded
Expand Down Expand Up @@ -293,7 +293,7 @@ fn test_bare_forms_in_field_position() {

let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child();
let assignment_id = cursor.node().id();
let assignment_id = cursor.node_id();

// Bare `_` in field position. Captures the named `identifier "x"`
// child of the `left` field — bare `_` admits unnamed too, but the
Expand Down Expand Up @@ -337,7 +337,7 @@ fn test_forward_scan_finds_unnamed_token_late() {
while cursor.node().kind() != "do" || !cursor.node().is_named() {
assert!(cursor.goto_next_sibling(), "expected to find named `do`");
}
let do_id = cursor.node().id();
let do_id = cursor.node_id();

let query = yeast::query!((do ("end") @kw));
let mut captures = yeast::captures::Captures::new();
Expand All @@ -363,7 +363,7 @@ fn test_forward_scan_preserves_order() {
while cursor.node().kind() != "do" || !cursor.node().is_named() {
assert!(cursor.goto_next_sibling(), "expected to find named `do`");
}
let do_id = cursor.node().id();
let do_id = cursor.node_id();

let query = yeast::query!((do ("end") @first ("do") @second));
let mut captures = yeast::captures::Captures::new();
Expand Down
Loading