1use std::ops::Range;
4
5use color_eyre::eyre::Context as _;
6use color_eyre::eyre::bail;
7use nvim_oxi::Array;
8use nvim_oxi::Object;
9use nvim_oxi::api::Buffer;
10use nvim_oxi::api::SuperIterator;
11use nvim_oxi::api::opts::GetTextOpts;
12use nvim_oxi::conversion::FromObject;
13use nvim_oxi::lua::Poppable;
14use nvim_oxi::lua::ffi::State;
15use serde::Deserialize;
16use serde::Deserializer;
17
18use crate::buffer::BufferExt;
19
20pub fn get_lines(_: ()) -> Vec<String> {
43 get(()).map_or_else(Vec::new, |f| f.lines)
44}
45
46pub fn get(_: ()) -> Option<Selection> {
55 let Ok(mut bounds) = SelectionBounds::new().inspect_err(|err| {
56 crate::notify::error(format!("error creating selection bounds | error={err:#?}"));
57 }) else {
58 return None;
59 };
60
61 let current_buffer = Buffer::from(bounds.buf_id());
62
63 if nvim_oxi::api::get_mode().mode == "V" {
65 let end_lnum = bounds.end().lnum;
66 let Ok(last_line) = current_buffer.get_line(end_lnum).inspect_err(|err| {
67 crate::notify::error(format!(
68 "error getting selection last line | end_lnum={end_lnum} buffer={current_buffer:#?} error={err:#?}",
69 ));
70 }) else {
71 return None;
72 };
73 bounds.start.col = 0;
75 bounds.end.col = last_line.len();
76 let Ok(lines) = current_buffer
78 .get_lines(bounds.start().lnum..=bounds.end().lnum, false)
79 .inspect_err(|err| {
80 crate::notify::error(format!(
81 "error getting lines | buffer={current_buffer:#?} error={err:#?}"
82 ));
83 })
84 else {
85 return None;
86 };
87 return Some(Selection::new(bounds, lines));
88 }
89
90 if let Ok(line) = current_buffer.get_line(bounds.end().lnum)
93 && bounds.end().col < line.len()
94 {
95 bounds.incr_end_col(); }
97
98 let Ok(lines) = current_buffer
100 .get_text(
101 bounds.line_range(),
102 bounds.start().col,
103 bounds.end().col,
104 &GetTextOpts::default(),
105 )
106 .inspect_err(|err| {
107 crate::notify::error(format!(
108 "error getting text | buffer={current_buffer:#?} bounds={bounds:#?} error={err:#?}"
109 ));
110 })
111 else {
112 return None;
113 };
114
115 Some(Selection::new(bounds, lines))
116}
117
118#[derive(Debug)]
120pub struct Selection {
121 bounds: SelectionBounds,
122 lines: Vec<String>,
123}
124
125impl Selection {
126 pub fn new(bounds: SelectionBounds, lines: impl SuperIterator<nvim_oxi::String>) -> Self {
128 Self {
129 bounds,
130 lines: lines.into_iter().map(|line| line.to_string()).collect(),
131 }
132 }
133}
134
135#[derive(Clone, Debug)]
137pub struct SelectionBounds {
138 #[cfg(feature = "testing")]
139 pub buf_id: i32,
140 #[cfg(feature = "testing")]
141 pub start: Bound,
142 #[cfg(feature = "testing")]
143 pub end: Bound,
144 #[cfg(not(feature = "testing"))]
145 buf_id: i32,
146 #[cfg(not(feature = "testing"))]
147 start: Bound,
148 #[cfg(not(feature = "testing"))]
149 end: Bound,
150}
151
152impl SelectionBounds {
153 pub fn new() -> color_eyre::Result<Self> {
162 let cursor_pos = get_pos(".")?;
163 let visual_pos = get_pos("v")?;
164
165 let (start, end) = cursor_pos.sort(visual_pos);
166
167 if start.buf_id != end.buf_id {
168 bail!("mismatched buffer ids | start={start:#?} end={end:#?}")
169 }
170
171 Ok(Self {
172 buf_id: start.buf_id,
173 start: Bound::from(start),
174 end: Bound::from(end),
175 })
176 }
177
178 pub const fn line_range(&self) -> Range<usize> {
180 self.start.lnum..self.end.lnum
181 }
182
183 pub const fn buf_id(&self) -> i32 {
185 self.buf_id
186 }
187
188 pub const fn start(&self) -> &Bound {
190 &self.start
191 }
192
193 pub const fn end(&self) -> &Bound {
195 &self.end
196 }
197
198 const fn incr_end_col(&mut self) {
200 self.end.col = self.end.col.saturating_add(1);
201 }
202}
203
/// One edge of a selection, expressed as a line number and a column.
#[derive(Clone, Copy, Debug)]
pub struct Bound {
    /// Line number of this edge.
    pub lnum: usize,
    /// Column of this edge within the line.
    pub col: usize,
}
212
213impl From<Pos> for Bound {
214 fn from(value: Pos) -> Self {
215 Self {
216 lnum: value.lnum,
217 col: value.col,
218 }
219 }
220}
221
222impl Selection {
223 pub const fn buf_id(&self) -> i32 {
225 self.bounds.buf_id()
226 }
227
228 pub const fn start(&self) -> &Bound {
230 self.bounds.start()
231 }
232
233 pub const fn end(&self) -> &Bound {
235 self.bounds.end()
236 }
237
238 pub fn lines(&self) -> &[String] {
240 &self.lines
241 }
242
243 pub const fn line_range(&self) -> Range<usize> {
245 self.bounds.line_range()
246 }
247}
248
/// A position inside a buffer: buffer id, line number and column.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Pos {
    /// Id of the buffer the position refers to.
    buf_id: i32,
    /// Line number of the position.
    lnum: usize,
    /// Column of the position within the line.
    col: usize,
}

impl Pos {
    /// Orders two positions by line number, then by column.
    ///
    /// Returns `(earlier, later)`. When both positions share line and column, `self` is
    /// returned first. The buffer id takes no part in the comparison.
    pub const fn sort(self, other: Self) -> (Self, Self) {
        // Written with plain integer comparisons so the function can stay `const`.
        let self_comes_after = if self.lnum == other.lnum {
            self.col > other.col
        } else {
            self.lnum > other.lnum
        };

        if self_comes_after {
            return (other, self);
        }
        (self, other)
    }
}
273
274impl<'de> Deserialize<'de> for Pos {
276 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
277 where
278 D: Deserializer<'de>,
279 {
280 let t = RawPos::deserialize(deserializer)?;
281 Ok(Self::from(t))
282 }
283}
284
285impl From<RawPos> for Pos {
287 fn from(raw: RawPos) -> Self {
288 fn to_0_based_usize(v: i64) -> usize {
289 usize::try_from(v.saturating_sub(1)).unwrap_or_default()
290 }
291
292 Self {
293 buf_id: raw.0,
294 lnum: to_0_based_usize(raw.1),
295 col: to_0_based_usize(raw.2),
296 }
297 }
298}
299
300#[derive(Clone, Copy, Debug, Deserialize)]
302#[expect(dead_code, reason = "Unused fields are kept for completeness")]
303struct RawPos(i32, i64, i64, i64);
304
305impl FromObject for Pos {
307 fn from_object(obj: Object) -> Result<Self, nvim_oxi::conversion::Error> {
308 Self::deserialize(nvim_oxi::serde::Deserializer::new(obj)).map_err(Into::into)
309 }
310}
311
312impl Poppable for Pos {
314 unsafe fn pop(lstate: *mut State) -> Result<Self, nvim_oxi::lua::Error> {
315 unsafe {
316 let obj = Object::pop(lstate)?;
317 Self::from_object(obj).map_err(nvim_oxi::lua::Error::pop_error_from_err::<Self, _>)
318 }
319 }
320}
321
322fn get_pos(mark: &str) -> color_eyre::Result<Pos> {
332 nvim_oxi::api::call_function::<_, Pos>("getpos", Array::from_iter([mark]))
333 .inspect_err(|err| {
334 crate::notify::error(format!("error getting pos | mark={mark:?} error={err:#?}"));
335 })
336 .wrap_err_with(|| format!("error getting position | mark={mark:?}"))
337}
338
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Builds a [`Pos`] in buffer `1` at the given line and column.
    fn pos(lnum: usize, col: usize) -> Pos {
        Pos { buf_id: 1, lnum, col }
    }

    #[rstest]
    #[case::self_has_lower_line(pos(0, 5), pos(1, 0), pos(0, 5), pos(1, 0))]
    #[case::self_has_higher_line(pos(2, 0), pos(1, 5), pos(1, 5), pos(2, 0))]
    #[case::same_line_self_lower_col(pos(1, 0), pos(1, 5), pos(1, 0), pos(1, 5))]
    #[case::same_line_self_higher_col(pos(1, 10), pos(1, 5), pos(1, 5), pos(1, 10))]
    #[case::positions_identical(pos(1, 5), pos(1, 5), pos(1, 5), pos(1, 5))]
    fn pos_sort_returns_expected_order(
        #[case] self_pos: Pos,
        #[case] other_pos: Pos,
        #[case] expected_first: Pos,
        #[case] expected_second: Pos,
    ) {
        // Compare both halves of the sorted pair at once.
        let sorted = self_pos.sort(other_pos);
        pretty_assertions::assert_eq!(sorted, (expected_first, expected_second));
    }
}