1use std::collections::HashSet;
7use std::path::Path;
8use std::process::Command;
9
10use chrono::DateTime;
11use chrono::Utc;
12use color_eyre::eyre::Context;
13use color_eyre::eyre::bail;
14use color_eyre::eyre::eyre;
15use git2::Cred;
16use git2::RemoteCallbacks;
17use git2::Repository;
18use ytil_cmd::CmdError;
19use ytil_cmd::CmdExt as _;
20
21pub fn get_default() -> color_eyre::Result<String> {
31 let repo_path = Path::new(".");
32 let repo = crate::repo::discover(repo_path).wrap_err_with(|| {
33 eyre!(
34 "error getting repo for getting default branch | path={}",
35 repo_path.display()
36 )
37 })?;
38
39 let default_remote_ref = crate::remote::get_default(&repo)?;
40
41 let Some(target) = default_remote_ref.symbolic_target() else {
42 bail!("error missing default branch");
43 };
44
45 Ok(target
46 .split('/')
47 .next_back()
48 .ok_or_else(|| eyre!("error extracting default branch_name from target | target={target:?}"))?
49 .to_string())
50}
51
52pub fn get_current() -> color_eyre::Result<String> {
62 let repo_path = Path::new(".");
63 let repo = crate::repo::discover(repo_path).wrap_err_with(|| {
64 eyre!(
65 "error getting repo for getting current branch | path={}",
66 repo_path.display()
67 )
68 })?;
69
70 if repo
71 .head_detached()
72 .wrap_err_with(|| eyre!("error checking if head is detached | path={}", repo_path.display()))?
73 {
74 bail!("error head is detached | path={}", repo_path.display())
75 }
76
77 repo.head()
78 .wrap_err_with(|| eyre!("error getting head | path={}", repo_path.display()))?
79 .shorthand()
80 .map(str::to_string)
81 .ok_or_else(|| eyre!("error invalid branch shorthand UTF-8 | path={}", repo_path.display()))
82}
83
84pub fn create_from_default_branch(branch_name: &str, repo: Option<&Repository>) -> color_eyre::Result<()> {
97 let repo = if let Some(repo) = repo {
98 repo
99 } else {
100 let path = Path::new(".");
101 &crate::repo::discover(path).wrap_err_with(|| {
102 eyre!(
103 "error getting repo for creating new branch | path={} branch={branch_name:?}",
104 path.display()
105 )
106 })?
107 };
108
109 let commit = repo
110 .head()
111 .wrap_err_with(|| eyre!("error getting head | branch_name={branch_name:?}"))?
112 .peel_to_commit()
113 .wrap_err_with(|| eyre!("error peeling head to commit | branch_name={branch_name:?}"))?;
114
115 repo.branch(branch_name, &commit, false)
116 .wrap_err_with(|| eyre!("error creating branch | branch_name={branch_name:?}"))?;
117
118 Ok(())
119}
120
121pub fn push(branch_name: &str, repo: Option<&Repository>) -> color_eyre::Result<()> {
132 let repo = if let Some(repo) = repo {
133 repo
134 } else {
135 let path = Path::new(".");
136 &crate::repo::discover(path).wrap_err_with(|| {
137 eyre!(
138 "error getting repo for pushing new branch | path={} branch={branch_name:?}",
139 path.display()
140 )
141 })?
142 };
143
144 let default_remote = crate::remote::get_default(repo)?;
145
146 let default_remote_name = default_remote
147 .name()
148 .ok_or_else(|| eyre!("error missing name of default remote"))?
149 .trim_start_matches("refs/remotes/")
150 .trim_end_matches("/HEAD");
151
152 let mut remote = repo.find_remote(default_remote_name)?;
153
154 let mut callbacks = RemoteCallbacks::new();
155 callbacks.credentials(|_url, username_from_url, _allowed_types| {
156 Cred::ssh_key_from_agent(username_from_url.unwrap_or("git"))
157 });
158
159 let mut push_opts = git2::PushOptions::new();
160 push_opts.remote_callbacks(callbacks);
161
162 let branch_refspec = format!("refs/heads/{branch_name}");
163 remote.push(&[&branch_refspec], Some(&mut push_opts)).wrap_err_with(|| {
164 eyre!("error pushing branch to remote | branch_refspec={branch_refspec:?} default_remote_name={default_remote_name:?}")
165 })?;
166
167 Ok(())
168}
169
170pub fn switch(branch_name: &str) -> Result<(), Box<CmdError>> {
181 Command::new("git")
182 .args(["switch", branch_name, "--guess"])
183 .exec()
184 .map_err(Box::new)?;
185 Ok(())
186}
187
188pub fn get_all() -> color_eyre::Result<Vec<Branch>> {
200 let repo_path = Path::new(".");
201 let repo = crate::repo::discover(repo_path)
202 .wrap_err_with(|| eyre!("error getting repo for getting branches | path={}", repo_path.display()))?;
203
204 fetch(&[]).wrap_err_with(|| eyre!("error fetching branches"))?;
205
206 let mut out = vec![];
207 for branch_res in repo
208 .branches(None)
209 .wrap_err_with(|| eyre!("error enumerating branches"))?
210 {
211 let branch = branch_res.wrap_err_with(|| eyre!("error getting branch result"))?;
212 out.push(Branch::try_from(branch).wrap_err_with(|| eyre!("error creating branch from result"))?);
213 }
214
215 out.sort_by(|a, b| b.committer_date_time().cmp(a.committer_date_time()));
216
217 Ok(out)
218}
219
220pub fn get_all_no_redundant() -> color_eyre::Result<Vec<Branch>> {
231 let mut branches = get_all()?;
232 remove_redundant_remotes(&mut branches);
233 Ok(branches)
234}
235
236pub fn fetch(branches: &[&str]) -> color_eyre::Result<()> {
246 let repo_path = Path::new(".");
247 let repo = crate::repo::discover(repo_path).wrap_err_with(|| {
248 eyre!(
249 "error getting repo for fetching branches | path={} branches={branches:?}",
250 repo_path.display()
251 )
252 })?;
253
254 let mut callbacks = RemoteCallbacks::new();
255 callbacks.credentials(|_url, username_from_url, _allowed_types| {
256 Cred::ssh_key_from_agent(username_from_url.unwrap_or("git"))
257 });
258
259 let mut fetch_opts = git2::FetchOptions::new();
260 fetch_opts.remote_callbacks(callbacks);
261
262 repo.find_remote("origin")
263 .wrap_err_with(|| eyre!("error finding origin remote"))?
264 .fetch(branches, Some(&mut fetch_opts), None)
265 .wrap_err_with(|| eyre!("error fetching branches={branches:?}"))?;
266
267 Ok(())
268}
269
270pub fn remove_redundant_remotes(branches: &mut Vec<Branch>) {
279 let mut local_names = HashSet::with_capacity(branches.len());
280 for branch in branches.iter() {
281 if let Branch::Local { name, .. } = branch {
282 local_names.insert(name.clone());
283 }
284 }
285
286 branches.retain(|b| match b {
287 Branch::Local { .. } => true,
288 Branch::Remote { name, .. } => {
289 let short = name.split_once('/').map_or(name.as_str(), |(_, rest)| rest);
290 !local_names.contains(short)
291 }
292 });
293}
294
295#[derive(Clone, Debug)]
297#[cfg_attr(test, derive(Eq, PartialEq))]
298pub enum Branch {
299 Local {
301 name: String,
303 committer_date_time: DateTime<Utc>,
305 },
306 Remote {
308 name: String,
310 committer_date_time: DateTime<Utc>,
312 },
313}
314
315impl Branch {
316 pub fn name(&self) -> &str {
318 match self {
319 Self::Local { name, .. } | Self::Remote { name, .. } => name,
320 }
321 }
322
323 pub fn name_no_origin(&self) -> &str {
325 self.name().trim_start_matches("origin/")
326 }
327
328 pub const fn committer_date_time(&self) -> &DateTime<Utc> {
330 match self {
331 Self::Local {
332 committer_date_time, ..
333 }
334 | Self::Remote {
335 committer_date_time, ..
336 } => committer_date_time,
337 }
338 }
339}
340
341impl<'a> TryFrom<(git2::Branch<'a>, git2::BranchType)> for Branch {
350 type Error = color_eyre::eyre::Error;
351
352 fn try_from((raw_branch, branch_type): (git2::Branch<'a>, git2::BranchType)) -> Result<Self, Self::Error> {
353 let branch_name = raw_branch
354 .name()?
355 .ok_or_else(|| eyre!("error invalid branch name UTF-8 | branch_name={:?}", raw_branch.name()))?;
356 let commit_time = raw_branch.get().peel_to_commit()?.committer().when();
357 let committer_date_time = DateTime::from_timestamp(commit_time.seconds(), 0)
358 .ok_or_else(|| eyre!("error invalid commit timestamp | seconds={}", commit_time.seconds()))?;
359
360 Ok(match branch_type {
361 git2::BranchType::Local => Self::Local {
362 name: branch_name.to_string(),
363 committer_date_time,
364 },
365 git2::BranchType::Remote => Self::Remote {
366 name: branch_name.to_string(),
367 committer_date_time,
368 },
369 })
370 }
371}
372
#[cfg(test)]
mod tests {
    use git2::Time;
    use rstest::rstest;

    use super::*;

    // Each case is (input branches, branches expected to remain). A remote branch is
    // dropped when its name after the first `/` equals a local branch name, whatever
    // the remote prefix is (`origin/`, `upstream/`, ...); survivor order is preserved.
    #[rstest]
    #[case::remote_same_short_name(
        vec![local("feature-x"), remote("origin/feature-x")],
        vec![local("feature-x")]
    )]
    #[case::no_redundant(
        vec![local("feature-x"), remote("origin/feature-y")],
        vec![local("feature-x"), remote("origin/feature-y")]
    )]
    #[case::multiple_mixed(
        vec![
            local("feature-x"),
            remote("origin/feature-x"),
            remote("origin/feature-y"),
            local("main"),
            remote("upstream/main")
        ],
        vec![local("feature-x"), remote("origin/feature-y"), local("main")]
    )]
    #[case::different_remote_prefix(
        vec![local("feature-x"), remote("upstream/feature-x")],
        vec![local("feature-x")]
    )]
    fn remove_redundant_remotes_cases(#[case] mut input: Vec<Branch>, #[case] expected: Vec<Branch>) {
        remove_redundant_remotes(&mut input);
        assert_eq!(input, expected);
    }

    // Creates a local branch at HEAD in a throwaway repo and converts it.
    // `init_test_repo` receives `Time::new(42, 3)`; presumably that becomes the HEAD
    // commit's committer time (TODO confirm in `crate::tests`). The expected value uses
    // only the seconds (42) because `try_from` ignores the timezone offset.
    #[test]
    fn branch_try_from_converts_local_branch_successfully() {
        let (_temp_dir, repo) = crate::tests::init_test_repo(Some(Time::new(42, 3)));

        let head_commit = repo.head().unwrap().peel_to_commit().unwrap();
        let branch = repo.branch("test-branch", &head_commit, false).unwrap();

        assert2::let_assert!(Ok(result) = Branch::try_from((branch, git2::BranchType::Local)));

        pretty_assertions::assert_eq!(
            result,
            Branch::Local {
                name: "test-branch".to_string(),
                committer_date_time: DateTime::from_timestamp(42, 0).unwrap(),
            }
        );
    }

    // `name()` returns the stored name verbatim for both variants.
    #[rstest]
    #[case::local_variant(local("main"), "main")]
    #[case::remote_variant(remote("origin/feature"), "origin/feature")]
    fn branch_name_when_variant_returns_name(#[case] branch: Branch, #[case] expected: &str) {
        pretty_assertions::assert_eq!(branch.name(), expected);
    }

    // `name_no_origin()` removes only the `origin/` prefix; other remotes are untouched.
    #[rstest]
    #[case::local_no_origin(local("main"), "main")]
    #[case::remote_origin_prefix(remote("origin/main"), "main")]
    #[case::remote_other_prefix(remote("upstream/feature"), "upstream/feature")]
    fn branch_name_no_origin_when_name_returns_trimmed(#[case] branch: Branch, #[case] expected: &str) {
        pretty_assertions::assert_eq!(branch.name_no_origin(), expected);
    }

    // `committer_date_time()` returns the stored timestamp for both variants.
    #[rstest]
    #[case::local_variant(
        Branch::Local {
            name: "test".to_string(),
            committer_date_time: DateTime::from_timestamp(123_456, 0).unwrap(),
        },
        DateTime::from_timestamp(123_456, 0).unwrap()
    )]
    #[case::remote_variant(
        Branch::Remote {
            name: "origin/test".to_string(),
            committer_date_time: DateTime::from_timestamp(654_321, 0).unwrap(),
        },
        DateTime::from_timestamp(654_321, 0).unwrap()
    )]
    fn branch_committer_date_time_when_variant_returns_date_time(
        #[case] branch: Branch,
        #[case] expected: DateTime<Utc>,
    ) {
        pretty_assertions::assert_eq!(branch.committer_date_time(), &expected);
    }

    // Fixture: a local branch with the given name and a committer time of epoch 0.
    fn local(name: &str) -> Branch {
        Branch::Local {
            name: name.into(),
            committer_date_time: DateTime::from_timestamp(0, 0).unwrap(),
        }
    }

    // Fixture: a remote branch with the given (prefixed) name and a committer time of epoch 0.
    fn remote(name: &str) -> Branch {
        Branch::Remote {
            name: name.into(),
            committer_date_time: DateTime::from_timestamp(0, 0).unwrap(),
        }
    }
}