ytil_noxi/
tree_sitter.rs

1use std::ops::Deref;
2use std::path::Path;
3
4use color_eyre::eyre::Context as _;
5use color_eyre::eyre::ensure;
6use color_eyre::eyre::eyre;
7use tree_sitter::Node;
8use tree_sitter::Parser;
9use tree_sitter::Point;
10
11use crate::buffer::CursorPosition;
12
13/// Wrapper around [`tree_sitter::Point`] that converts Nvim's 1-based row indexing
14/// to tree-sitter's 0-based indexing.
15///
16/// # Rationale
17///
18/// Nvim uses 1-based row indices for cursor positions, while tree-sitter expects 0-based rows.
19/// This wrapper simplifies the conversion in the codebase.
20#[derive(Debug)]
21#[cfg_attr(test, derive(Eq, PartialEq))]
22pub struct PointWrap(Point);
23
24impl Deref for PointWrap {
25    type Target = Point;
26
27    fn deref(&self) -> &Self::Target {
28        &self.0
29    }
30}
31
32impl From<CursorPosition> for PointWrap {
33    /// Converts a Nvim cursor position (1-based row, 0-based column) to a [`PointWrap`].
34    fn from(cursor_position: CursorPosition) -> Self {
35        Self(Point {
36            row: cursor_position.row.saturating_sub(1),
37            column: cursor_position.col,
38        })
39    }
40}
41
42pub fn get_enclosing_fn_name_of_position(file_path: &Path) -> Option<String> {
43    let position = CursorPosition::get_current().map(PointWrap::from)?;
44
45    let enclosing_fn_name = get_enclosing_fn_name_of_position_internal(file_path, *position)
46        .inspect_err(|err| {
47            crate::notify::error(format!(
48                "error getting enclosing fn | position={position:#?} error={err:#?}"
49            ));
50        })
51        .ok()
52        .flatten();
53
54    if enclosing_fn_name.is_none() {
55        crate::notify::error(format!("error missing enclosing fn | position={position:#?}"));
56    }
57
58    enclosing_fn_name
59}
60
61/// Gets the name of the function enclosing the given [Point] in a Rust file.
62///
63/// # Errors
64/// - A filesystem operation (open/read/write/remove) fails.
65fn get_enclosing_fn_name_of_position_internal(file_path: &Path, position: Point) -> color_eyre::Result<Option<String>> {
66    ensure!(
67        file_path.extension().is_some_and(|ext| ext == "rs"),
68        "invalid file extension | path={} expected_ext=\"rs\"",
69        file_path.display(),
70    );
71    let src = std::fs::read(file_path).with_context(|| format!("Error reading {}", file_path.display()))?;
72
73    let mut parser = Parser::new();
74    parser
75        .set_language(&tree_sitter_rust::LANGUAGE.into())
76        .with_context(|| "error setting parser language")?;
77
78    let src_tree = parser
79        .parse(&src, None)
80        .ok_or_else(|| eyre!("error parsing Rust code | path={}", file_path.display()))?;
81
82    let node_at_position = src_tree.root_node().descendant_for_point_range(position, position);
83
84    Ok(get_enclosing_fn_name_of_node(&src, node_at_position))
85}
86
87/// Gets the name of the function enclosing the given [Node].
88fn get_enclosing_fn_name_of_node(src: &[u8], node: Option<Node>) -> Option<String> {
89    const FN_NODE_KINDS: &[&str] = &[
90        "function",
91        "function_declaration",
92        "function_definition",
93        "function_item",
94        "method",
95        "method_declaration",
96        "method_definition",
97        "method_item",
98    ];
99    let mut current_node = node;
100    while let Some(node) = current_node {
101        if FN_NODE_KINDS.contains(&node.kind())
102            && let Some(fn_node_name) = node
103                .child_by_field_name("name")
104                .or_else(|| node.child_by_field_name("identifier"))
105            && let Ok(fn_name) = fn_node_name.utf8_text(src)
106            && !fn_name.is_empty()
107        {
108            return Some(fn_name.to_string());
109        }
110        current_node = node.parent();
111    }
112    None
113}
114
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    // Parses `src` as Rust, looks up the node at `position` and feeds it to `f`.
    // Works around the [`tree_sitter::Tree`] and [`tree_sitter::Node`] lifetime issues by keeping
    // the tree alive for the whole callback.
    fn with_node<F, R>(src: &[u8], position: Point, f: F) -> R
    where
        F: FnOnce(&[u8], Option<Node>) -> R,
    {
        let mut rust_parser = Parser::new();
        rust_parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .unwrap();
        let parsed = rust_parser.parse(src, None).unwrap();
        let root = parsed.root_node();
        f(src, root.descendant_for_point_range(position, position))
    }

    #[rstest]
    #[case(CursorPosition { row: 1, col: 5 }, Point { row: 0, column: 5 })]
    #[case(CursorPosition { row: 10, col: 20 }, Point { row: 9, column: 20 })]
    #[case(CursorPosition { row: 0, col: 0 }, Point { row: 0, column: 0 })]
    fn point_wrap_from_converts_neovim_cursor_to_tree_sitter_point(
        #[case] input: CursorPosition,
        #[case] expected: Point,
    ) {
        let actual = PointWrap::from(input);
        pretty_assertions::assert_eq!(actual, PointWrap(expected));
    }

    #[test]
    fn point_wrap_deref_allows_direct_access_to_point() {
        let wrapped = PointWrap::from(CursorPosition { row: 5, col: 10 });
        pretty_assertions::assert_eq!(*wrapped, Point { row: 4, column: 10 });
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_fn_name_when_inside_function() {
        let src = b"fn test_function() { let x = 1; }";
        let inside_body = Point { row: 0, column: 20 };

        let enclosing_fn = with_node(src, inside_body, get_enclosing_fn_name_of_node);

        pretty_assertions::assert_eq!(enclosing_fn, Some("test_function".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_not_inside_function() {
        let src = b"let x = 1;";
        let outside_any_fn = Point { row: 0, column: 5 };

        let enclosing_fn = with_node(src, outside_any_fn, get_enclosing_fn_name_of_node);

        pretty_assertions::assert_eq!(enclosing_fn, None);
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_method_name_when_inside_method() {
        let src = b"impl Test { fn method(&self) { let x = 1; } }";
        let inside_method = Point { row: 0, column: 25 };

        let enclosing_fn = with_node(src, inside_method, get_enclosing_fn_name_of_node);

        pretty_assertions::assert_eq!(enclosing_fn, Some("method".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_node_is_none() {
        let enclosing_fn = get_enclosing_fn_name_of_node(b"fn test() {}", None);
        pretty_assertions::assert_eq!(enclosing_fn, None);
    }
}