1use std::ops::Deref;
2use std::path::Path;
3
4use rootcause::prelude::ResultExt as _;
5use rootcause::report;
6use tree_sitter::Node;
7use tree_sitter::Parser;
8use tree_sitter::Point;
9
10use crate::buffer::CursorPosition;
11
12#[derive(Debug)]
14#[cfg_attr(test, derive(Eq, PartialEq))]
15pub struct PointWrap(Point);
16
17impl Deref for PointWrap {
18 type Target = Point;
19
20 fn deref(&self) -> &Self::Target {
21 &self.0
22 }
23}
24
25impl From<CursorPosition> for PointWrap {
26 fn from(cursor_position: CursorPosition) -> Self {
28 Self(Point {
29 row: cursor_position.row.saturating_sub(1),
30 column: cursor_position.col,
31 })
32 }
33}
34
35pub fn get_enclosing_fn_name_of_position(file_path: &Path) -> Option<String> {
36 let position = CursorPosition::get_current().map(PointWrap::from)?;
37
38 let enclosing_fn_name = get_enclosing_fn_name_of_position_internal(file_path, *position)
39 .inspect_err(|err| {
40 crate::notify::error(format!(
41 "error getting enclosing fn | position={position:#?} error={err:#?}"
42 ));
43 })
44 .ok()
45 .flatten();
46
47 if enclosing_fn_name.is_none() {
48 crate::notify::error(format!("error missing enclosing fn | position={position:#?}"));
49 }
50
51 enclosing_fn_name
52}
53
54fn get_enclosing_fn_name_of_position_internal(file_path: &Path, position: Point) -> rootcause::Result<Option<String>> {
59 if file_path.extension().is_none_or(|ext| ext != "rs") {
60 Err(report!("invalid file extension"))
61 .attach_with(|| format!("path={} expected_ext=\"rs\"", file_path.display()))?;
62 }
63 let src = std::fs::read(file_path)
64 .context("error reading file")
65 .attach_with(|| format!("path={}", file_path.display()))?;
66
67 let mut parser = Parser::new();
68 parser
69 .set_language(&tree_sitter_rust::LANGUAGE.into())
70 .context("error setting parser language")?;
71
72 let src_tree = parser
73 .parse(&src, None)
74 .ok_or_else(|| report!("error parsing Rust code"))
75 .attach_with(|| format!("path={}", file_path.display()))?;
76
77 let node_at_position = src_tree.root_node().descendant_for_point_range(position, position);
78
79 Ok(get_enclosing_fn_name_of_node(&src, node_at_position))
80}
81
82fn get_enclosing_fn_name_of_node(src: &[u8], node: Option<Node>) -> Option<String> {
84 const FN_NODE_KINDS: &[&str] = &[
85 "function",
86 "function_declaration",
87 "function_definition",
88 "function_item",
89 "method",
90 "method_declaration",
91 "method_definition",
92 "method_item",
93 ];
94 let mut current_node = node;
95 while let Some(node) = current_node {
96 if FN_NODE_KINDS.contains(&node.kind())
97 && let Some(fn_node_name) = node
98 .child_by_field_name("name")
99 .or_else(|| node.child_by_field_name("identifier"))
100 && let Ok(fn_name) = fn_node_name.utf8_text(src)
101 && !fn_name.is_empty()
102 {
103 return Some(fn_name.to_string());
104 }
105 current_node = node.parent();
106 }
107 None
108}
109
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Parses `src` as Rust, looks up the node spanning `position`, and passes both the source
    /// and that node to `f`, returning whatever `f` returns.
    fn with_node<F, R>(src: &[u8], position: Point, f: F) -> R
    where
        F: FnOnce(&[u8], Option<Node>) -> R,
    {
        let mut rust_parser = Parser::new();
        rust_parser.set_language(&tree_sitter_rust::LANGUAGE.into()).unwrap();
        let syntax_tree = rust_parser.parse(src, None).unwrap();
        let root = syntax_tree.root_node();
        f(src, root.descendant_for_point_range(position, position))
    }

    #[rstest]
    #[case(CursorPosition { row: 1, col: 5 }, (0, 5))]
    #[case(CursorPosition { row: 10, col: 20 }, (9, 20))]
    #[case(CursorPosition { row: 0, col: 0 }, (0, 0))]
    fn point_wrap_from_converts_neovim_cursor_to_tree_sitter_point(
        #[case] input: CursorPosition,
        #[case] expected: (usize, usize),
    ) {
        let (row, column) = expected;
        pretty_assertions::assert_eq!(PointWrap::from(input), PointWrap(Point { row, column }));
    }

    #[test]
    fn point_wrap_deref_allows_direct_access_to_point() {
        let wrapped = PointWrap::from(CursorPosition { row: 5, col: 10 });
        pretty_assertions::assert_eq!(*wrapped, Point { row: 4, column: 10 });
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_fn_name_when_inside_function() {
        let src: &[u8] = b"fn test_function() { let x = 1; }";
        let cursor = Point { row: 0, column: 20 };

        let fn_name = with_node(src, cursor, get_enclosing_fn_name_of_node);

        pretty_assertions::assert_eq!(fn_name, Some("test_function".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_not_inside_function() {
        let src: &[u8] = b"let x = 1;";
        let cursor = Point { row: 0, column: 5 };

        let fn_name = with_node(src, cursor, get_enclosing_fn_name_of_node);

        pretty_assertions::assert_eq!(fn_name, None);
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_method_name_when_inside_method() {
        let src: &[u8] = b"impl Test { fn method(&self) { let x = 1; } }";
        let cursor = Point { row: 0, column: 25 };

        let fn_name = with_node(src, cursor, get_enclosing_fn_name_of_node);

        pretty_assertions::assert_eq!(fn_name, Some("method".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_node_is_none() {
        let fn_name = get_enclosing_fn_name_of_node(b"fn test() {}", None);
        pretty_assertions::assert_eq!(fn_name, None);
    }
}