1use std::ops::Deref;
2use std::path::Path;
3
4use color_eyre::eyre::Context as _;
5use color_eyre::eyre::ensure;
6use color_eyre::eyre::eyre;
7use tree_sitter::Node;
8use tree_sitter::Parser;
9use tree_sitter::Point;
10
11use crate::buffer::CursorPosition;
12
13#[derive(Debug)]
21#[cfg_attr(test, derive(Eq, PartialEq))]
22pub struct PointWrap(Point);
23
24impl Deref for PointWrap {
25 type Target = Point;
26
27 fn deref(&self) -> &Self::Target {
28 &self.0
29 }
30}
31
32impl From<CursorPosition> for PointWrap {
33 fn from(cursor_position: CursorPosition) -> Self {
35 Self(Point {
36 row: cursor_position.row.saturating_sub(1),
37 column: cursor_position.col,
38 })
39 }
40}
41
42pub fn get_enclosing_fn_name_of_position(file_path: &Path) -> Option<String> {
43 let position = CursorPosition::get_current().map(PointWrap::from)?;
44
45 let enclosing_fn_name = get_enclosing_fn_name_of_position_internal(file_path, *position)
46 .inspect_err(|err| {
47 crate::notify::error(format!(
48 "error getting enclosing fn | position={position:#?} error={err:#?}"
49 ));
50 })
51 .ok()
52 .flatten();
53
54 if enclosing_fn_name.is_none() {
55 crate::notify::error(format!("error missing enclosing fn | position={position:#?}"));
56 }
57
58 enclosing_fn_name
59}
60
61fn get_enclosing_fn_name_of_position_internal(file_path: &Path, position: Point) -> color_eyre::Result<Option<String>> {
66 ensure!(
67 file_path.extension().is_some_and(|ext| ext == "rs"),
68 "invalid file extension | path={} expected_ext=\"rs\"",
69 file_path.display(),
70 );
71 let src = std::fs::read(file_path).with_context(|| format!("Error reading {}", file_path.display()))?;
72
73 let mut parser = Parser::new();
74 parser
75 .set_language(&tree_sitter_rust::LANGUAGE.into())
76 .with_context(|| "error setting parser language")?;
77
78 let src_tree = parser
79 .parse(&src, None)
80 .ok_or_else(|| eyre!("error parsing Rust code | path={}", file_path.display()))?;
81
82 let node_at_position = src_tree.root_node().descendant_for_point_range(position, position);
83
84 Ok(get_enclosing_fn_name_of_node(&src, node_at_position))
85}
86
87fn get_enclosing_fn_name_of_node(src: &[u8], node: Option<Node>) -> Option<String> {
89 const FN_NODE_KINDS: &[&str] = &[
90 "function",
91 "function_declaration",
92 "function_definition",
93 "function_item",
94 "method",
95 "method_declaration",
96 "method_definition",
97 "method_item",
98 ];
99 let mut current_node = node;
100 while let Some(node) = current_node {
101 if FN_NODE_KINDS.contains(&node.kind())
102 && let Some(fn_node_name) = node
103 .child_by_field_name("name")
104 .or_else(|| node.child_by_field_name("identifier"))
105 && let Ok(fn_name) = fn_node_name.utf8_text(src)
106 && !fn_name.is_empty()
107 {
108 return Some(fn_name.to_string());
109 }
110 current_node = node.parent();
111 }
112 None
113}
114
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case(CursorPosition { row: 1, col: 5 }, (0, 5))]
    #[case(CursorPosition { row: 10, col: 20 }, (9, 20))]
    #[case(CursorPosition { row: 0, col: 0 }, (0, 0))]
    fn point_wrap_from_converts_neovim_cursor_to_tree_sitter_point(
        #[case] input: CursorPosition,
        #[case] expected: (usize, usize),
    ) {
        pretty_assertions::assert_eq!(
            PointWrap::from(input),
            PointWrap(Point {
                row: expected.0,
                column: expected.1
            })
        );
    }

    #[test]
    fn point_wrap_deref_allows_direct_access_to_point() {
        pretty_assertions::assert_eq!(
            *PointWrap::from(CursorPosition { row: 5, col: 10 }),
            Point { row: 4, column: 10 }
        );
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_fn_name_when_inside_function() {
        let result = with_node(
            b"fn test_function() { let x = 1; }",
            Point { row: 0, column: 20 },
            get_enclosing_fn_name_of_node,
        );
        pretty_assertions::assert_eq!(result, Some("test_function".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_not_inside_function() {
        let result = with_node(
            b"let x = 1;",
            Point { row: 0, column: 5 },
            get_enclosing_fn_name_of_node,
        );
        pretty_assertions::assert_eq!(result, None);
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_method_name_when_inside_method() {
        // Column 25 falls inside the parameter list (`self`).
        let result = with_node(
            b"impl Test { fn method(&self) { let x = 1; } }",
            Point { row: 0, column: 25 },
            get_enclosing_fn_name_of_node,
        );
        pretty_assertions::assert_eq!(result, Some("method".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_method_name_when_inside_method_body() {
        // Column 35 is the `x` binding inside the method body.
        let result = with_node(
            b"impl Test { fn method(&self) { let x = 1; } }",
            Point { row: 0, column: 35 },
            get_enclosing_fn_name_of_node,
        );
        pretty_assertions::assert_eq!(result, Some("method".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_innermost_fn_name_when_nested() {
        // Column 30 is the `y` binding inside `inner`, which is itself inside `outer`.
        let result = with_node(
            b"fn outer() { fn inner() { let y = 2; } }",
            Point { row: 0, column: 30 },
            get_enclosing_fn_name_of_node,
        );
        pretty_assertions::assert_eq!(result, Some("inner".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_outer_fn_name_outside_nested_fn() {
        // Column 12 is the whitespace in `outer`'s body just before `fn inner`.
        let result = with_node(
            b"fn outer() { fn inner() { let y = 2; } }",
            Point { row: 0, column: 12 },
            get_enclosing_fn_name_of_node,
        );
        pretty_assertions::assert_eq!(result, Some("outer".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_skips_closures() {
        // Column 21 is the `x` in the closure body; closures are not named fns, so the walk
        // continues up to `f`.
        let result = with_node(
            b"fn f() { let c = |x| x + 1; }",
            Point { row: 0, column: 21 },
            get_enclosing_fn_name_of_node,
        );
        pretty_assertions::assert_eq!(result, Some("f".to_string()));
    }

    #[test]
    fn get_enclosing_fn_name_of_node_returns_none_when_node_is_none() {
        let result = get_enclosing_fn_name_of_node(b"fn test() {}", None);
        pretty_assertions::assert_eq!(result, None);
    }

    /// Parses `src` as Rust and invokes `f` with the smallest node covering `position`.
    fn with_node<F, R>(src: &[u8], position: Point, f: F) -> R
    where
        F: FnOnce(&[u8], Option<Node>) -> R,
    {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("tree-sitter Rust grammar must load");
        let tree = parser.parse(src, None).expect("parser has a language set");
        let node = tree.root_node().descendant_for_point_range(position, position);
        f(src, node)
    }
}