Skip to main content

ytil_noxi/
tree_sitter.rs

1use std::ops::Deref;
2use std::path::Path;
3
4use rootcause::prelude::ResultExt as _;
5use rootcause::report;
6use tree_sitter::Node;
7use tree_sitter::Parser;
8use tree_sitter::Point;
9
10use crate::buffer::CursorPosition;
11
12/// Wrapper around [`tree_sitter::Point`] converting Nvim's 1-based row to 0-based.
13#[derive(Debug)]
14#[cfg_attr(test, derive(Eq, PartialEq))]
15pub struct PointWrap(Point);
16
17impl Deref for PointWrap {
18    type Target = Point;
19
20    fn deref(&self) -> &Self::Target {
21        &self.0
22    }
23}
24
25impl From<CursorPosition> for PointWrap {
26    /// Converts a Nvim cursor position (1-based row, 0-based column) to a [`PointWrap`].
27    fn from(cursor_position: CursorPosition) -> Self {
28        Self(Point {
29            row: cursor_position.row.saturating_sub(1),
30            column: cursor_position.col,
31        })
32    }
33}
34
35pub fn get_enclosing_fn_name_of_position(file_path: &Path) -> Option<String> {
36    let position = CursorPosition::get_current().map(PointWrap::from)?;
37
38    let enclosing_fn_name = get_enclosing_fn_name_of_position_internal(file_path, *position)
39        .inspect_err(|err| {
40            crate::notify::error(format!(
41                "error getting enclosing fn | position={position:#?} error={err:#?}"
42            ));
43        })
44        .ok()
45        .flatten();
46
47    if enclosing_fn_name.is_none() {
48        crate::notify::error(format!("error missing enclosing fn | position={position:#?}"));
49    }
50
51    enclosing_fn_name
52}
53
54/// Gets the name of the function enclosing the given [Point] in a Rust file.
55///
56/// # Errors
57/// - A filesystem operation (open/read/write/remove) fails.
58fn get_enclosing_fn_name_of_position_internal(file_path: &Path, position: Point) -> rootcause::Result<Option<String>> {
59    if file_path.extension().is_none_or(|ext| ext != "rs") {
60        Err(report!("invalid file extension"))
61            .attach_with(|| format!("path={} expected_ext=\"rs\"", file_path.display()))?;
62    }
63    let src = std::fs::read(file_path)
64        .context("error reading file")
65        .attach_with(|| format!("path={}", file_path.display()))?;
66
67    let mut parser = Parser::new();
68    parser
69        .set_language(&tree_sitter_rust::LANGUAGE.into())
70        .context("error setting parser language")?;
71
72    let src_tree = parser
73        .parse(&src, None)
74        .ok_or_else(|| report!("error parsing Rust code"))
75        .attach_with(|| format!("path={}", file_path.display()))?;
76
77    let node_at_position = src_tree.root_node().descendant_for_point_range(position, position);
78
79    Ok(get_enclosing_fn_name_of_node(&src, node_at_position))
80}
81
82/// Gets the name of the function enclosing the given [Node].
83fn get_enclosing_fn_name_of_node(src: &[u8], node: Option<Node>) -> Option<String> {
84    const FN_NODE_KINDS: &[&str] = &[
85        "function",
86        "function_declaration",
87        "function_definition",
88        "function_item",
89        "method",
90        "method_declaration",
91        "method_definition",
92        "method_item",
93    ];
94    let mut current_node = node;
95    while let Some(node) = current_node {
96        if FN_NODE_KINDS.contains(&node.kind())
97            && let Some(fn_node_name) = node
98                .child_by_field_name("name")
99                .or_else(|| node.child_by_field_name("identifier"))
100            && let Ok(fn_name) = fn_node_name.utf8_text(src)
101            && !fn_name.is_empty()
102        {
103            return Some(fn_name.to_string());
104        }
105        current_node = node.parent();
106    }
107    None
108}
109
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Parses `src` as Rust and passes the node found at `position` to `f`.
    ///
    /// Works around the [`tree_sitter::Tree`] and [`tree_sitter::Node`] lifetime issues: the node
    /// is consumed while the tree is still alive.
    fn with_node<F, R>(src: &[u8], position: Point, f: F) -> R
    where
        F: FnOnce(&[u8], Option<Node>) -> R,
    {
        let mut parser = Parser::new();
        parser.set_language(&tree_sitter_rust::LANGUAGE.into()).unwrap();
        let tree = parser.parse(src, None).unwrap();
        f(src, tree.root_node().descendant_for_point_range(position, position))
    }

    /// Looks up the enclosing function name for the node at `column` on the first row of `src`.
    fn enclosing_fn_name_at(src: &[u8], column: usize) -> Option<String> {
        with_node(src, Point { row: 0, column }, get_enclosing_fn_name_of_node)
    }

    #[rstest]
    #[case(CursorPosition { row: 1, col: 5 }, Point { row: 0, column: 5 })]
    #[case(CursorPosition { row: 10, col: 20 }, Point { row: 9, column: 20 })]
    #[case(CursorPosition { row: 0, col: 0 }, Point { row: 0, column: 0 })]
    fn point_wrap_from_converts_neovim_cursor_to_tree_sitter_point(
        #[case] input: CursorPosition,
        #[case] expected: Point,
    ) {
        pretty_assertions::assert_eq!(PointWrap::from(input), PointWrap(expected));
    }

    #[test]
    fn point_wrap_deref_allows_direct_access_to_point() {
        let wrapped = PointWrap::from(CursorPosition { row: 5, col: 10 });
        pretty_assertions::assert_eq!(*wrapped, Point { row: 4, column: 10 });
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_fn_name_when_inside_function() {
        let fn_name = enclosing_fn_name_at(b"fn test_function() { let x = 1; }", 20);
        pretty_assertions::assert_eq!(fn_name.as_deref(), Some("test_function"));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_not_inside_function() {
        let fn_name = enclosing_fn_name_at(b"let x = 1;", 5);
        pretty_assertions::assert_eq!(fn_name, None);
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_method_name_when_inside_method() {
        let fn_name = enclosing_fn_name_at(b"impl Test { fn method(&self) { let x = 1; } }", 25);
        pretty_assertions::assert_eq!(fn_name.as_deref(), Some("method"));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_node_is_none() {
        let fn_name = get_enclosing_fn_name_of_node(b"fn test() {}", None);
        pretty_assertions::assert_eq!(fn_name, None);
    }
}