1use std::collections::HashSet;
7use std::path::Path;
8use std::process::Command;
9
10use chrono::DateTime;
11use chrono::Utc;
12use git2::Cred;
13use git2::RemoteCallbacks;
14use git2::Repository;
15use rootcause::bail;
16use rootcause::prelude::ResultExt;
17use rootcause::report;
18use ytil_cmd::CmdError;
19use ytil_cmd::CmdExt as _;
20
21pub fn get_default() -> rootcause::Result<String> {
31 let repo_path = Path::new(".");
32 let repo = crate::repo::discover(repo_path)
33 .context("error getting repo for getting default branch")
34 .attach_with(|| format!("path={}", repo_path.display()))?;
35
36 let default_remote_ref = crate::remote::get_default(&repo)?;
37
38 let Some(target) = default_remote_ref.symbolic_target() else {
39 bail!("error missing default branch");
40 };
41
42 Ok(target
43 .split('/')
44 .next_back()
45 .ok_or_else(|| report!("error extracting default branch_name from target"))
46 .attach_with(|| format!("target={target:?}"))?
47 .to_string())
48}
49
50pub fn get_current() -> rootcause::Result<String> {
55 let repo_path = Path::new(".");
56 let repo = crate::repo::discover(repo_path)
57 .context("error getting repo for getting current branch")
58 .attach_with(|| format!("path={}", repo_path.display()))?;
59
60 if repo
61 .head_detached()
62 .context("error checking if head is detached")
63 .attach_with(|| format!("path={}", repo_path.display()))?
64 {
65 Err(report!("error head is detached")).attach_with(|| format!("path={}", repo_path.display()))?;
66 }
67
68 repo.head()
69 .context("error getting head")
70 .attach_with(|| format!("path={}", repo_path.display()))?
71 .shorthand()
72 .map(str::to_string)
73 .ok_or_else(|| report!("error invalid branch shorthand UTF-8"))
74 .attach_with(|| format!("path={}", repo_path.display()))
75}
76
77pub fn create_from_default_branch(branch_name: &str, repo: Option<&Repository>) -> rootcause::Result<()> {
82 let repo = if let Some(repo) = repo {
83 repo
84 } else {
85 let path = Path::new(".");
86 &crate::repo::discover(path)
87 .context("error getting repo for creating new branch")
88 .attach_with(|| format!("path={} branch={branch_name:?}", path.display()))?
89 };
90
91 let commit = repo
92 .head()
93 .context("error getting head")
94 .attach_with(|| format!("branch_name={branch_name:?}"))?
95 .peel_to_commit()
96 .context("error peeling head to commit")
97 .attach_with(|| format!("branch_name={branch_name:?}"))?;
98
99 repo.branch(branch_name, &commit, false)
100 .context("error creating branch")
101 .attach_with(|| format!("branch_name={branch_name:?}"))?;
102
103 Ok(())
104}
105
106pub fn push(branch_name: &str, repo: Option<&Repository>) -> rootcause::Result<()> {
117 let repo = if let Some(repo) = repo {
118 repo
119 } else {
120 let path = Path::new(".");
121 &crate::repo::discover(path)
122 .context("error getting repo for pushing new branch")
123 .attach_with(|| format!("path={} branch={branch_name:?}", path.display()))?
124 };
125
126 let default_remote = crate::remote::get_default(repo)?;
127
128 let default_remote_name = default_remote
129 .name()
130 .ok_or_else(|| report!("error missing name of default remote"))?
131 .trim_start_matches("refs/remotes/")
132 .trim_end_matches("/HEAD");
133
134 let mut remote = repo.find_remote(default_remote_name)?;
135
136 let mut callbacks = RemoteCallbacks::new();
137 callbacks.credentials(|_url, username_from_url, _allowed_types| {
138 Cred::ssh_key_from_agent(username_from_url.unwrap_or("git"))
139 });
140
141 let mut push_opts = git2::PushOptions::new();
142 push_opts.remote_callbacks(callbacks);
143
144 let branch_refspec = format!("refs/heads/{branch_name}");
145 remote
146 .push(&[&branch_refspec], Some(&mut push_opts))
147 .context("error pushing branch to remote")
148 .attach_with(|| format!("branch_refspec={branch_refspec:?} default_remote_name={default_remote_name:?}"))?;
149
150 Ok(())
151}
152
153pub fn get_previous(repo: &Repository) -> Option<String> {
161 let reflog = repo.reflog("HEAD").ok()?;
162 reflog.iter().find_map(|entry| {
163 let msg = entry.message()?;
164 let rest = msg
165 .strip_prefix("checkout: moving from ")
166 .or_else(|| msg.strip_prefix("switch: moving from "))?;
167 Some(rest.rsplit_once(" to ")?.0.to_string())
168 })
169}
170
171pub fn switch(branch_name: &str) -> Result<(), Box<CmdError>> {
176 Command::new("git")
177 .args(["switch", branch_name, "--guess"])
178 .exec()
179 .map_err(Box::new)?;
180 Ok(())
181}
182
183pub fn get_all(repo: &Repository) -> rootcause::Result<Vec<Branch>> {
194 fetch_with_repo(repo, &[]).context("error fetching branches")?;
195
196 let mut out = vec![];
197 for branch_res in repo.branches(None).context("error enumerating branches")? {
198 let branch = branch_res.context("error getting branch result")?;
199 out.push(Branch::try_from(branch).context("error creating branch from result")?);
200 }
201
202 out.sort_unstable_by(|a, b| b.committer_date_time().cmp(a.committer_date_time()));
203
204 Ok(out)
205}
206
207pub fn get_all_no_redundant(repo: &Repository) -> rootcause::Result<Vec<Branch>> {
217 let mut branches = get_all(repo)?;
218 remove_redundant_remotes(&mut branches);
219 Ok(branches)
220}
221
222pub fn fetch(branches: &[&str]) -> rootcause::Result<()> {
232 let repo_path = Path::new(".");
233 let repo = crate::repo::discover(repo_path)
234 .context("error getting repo for fetching branches")
235 .attach_with(|| format!("path={} branches={branches:?}", repo_path.display()))?;
236 fetch_with_repo(&repo, branches)
237}
238
239fn fetch_with_repo(repo: &Repository, branches: &[&str]) -> rootcause::Result<()> {
241 let mut callbacks = RemoteCallbacks::new();
242 callbacks.credentials(|_url, username_from_url, _allowed_types| {
243 Cred::ssh_key_from_agent(username_from_url.unwrap_or("git"))
244 });
245
246 let mut fetch_opts = git2::FetchOptions::new();
247 fetch_opts.remote_callbacks(callbacks);
248
249 repo.find_remote("origin")
250 .context("error finding origin remote")?
251 .fetch(branches, Some(&mut fetch_opts), None)
252 .context("error performing fetch from origin remote")
253 .attach_with(|| format!("branches={branches:?}"))?;
254
255 Ok(())
256}
257
258pub fn remove_redundant_remotes(branches: &mut Vec<Branch>) {
267 let local_names: HashSet<String> = branches
270 .iter()
271 .filter_map(|b| {
272 if let Branch::Local { name, .. } = b {
273 Some(name.clone())
274 } else {
275 None
276 }
277 })
278 .collect();
279
280 branches.retain(|b| match b {
281 Branch::Local { .. } => true,
282 Branch::Remote { name, .. } => {
283 let short = name.split_once('/').map_or(name.as_str(), |(_, rest)| rest);
284 !local_names.contains(short)
285 }
286 });
287}
288
/// A git branch, either local or remote-tracking, together with metadata about the committer of
/// the commit it points at.
///
/// `Eq`/`PartialEq` are only derived in test builds.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
pub enum Branch {
    /// A branch that lives in the local repository.
    Local {
        /// Branch name (e.g. `main`).
        name: String,
        /// Email of the committer of the commit the branch points at.
        committer_email: String,
        /// Committer timestamp of the commit the branch points at, in UTC.
        committer_date_time: DateTime<Utc>,
    },
    /// A remote-tracking branch.
    Remote {
        /// Branch name including the remote prefix (e.g. `origin/main`).
        name: String,
        /// Email of the committer of the commit the branch points at.
        committer_email: String,
        /// Committer timestamp of the commit the branch points at, in UTC.
        committer_date_time: DateTime<Utc>,
    },
}
312
313impl Branch {
314 pub fn name(&self) -> &str {
316 match self {
317 Self::Local { name, .. } | Self::Remote { name, .. } => name,
318 }
319 }
320
321 pub fn name_no_origin(&self) -> &str {
323 self.name().trim_start_matches("origin/")
324 }
325
326 pub fn committer_email(&self) -> &str {
328 match self {
329 Self::Local { committer_email, .. } | Self::Remote { committer_email, .. } => committer_email,
330 }
331 }
332
333 pub const fn committer_date_time(&self) -> &DateTime<Utc> {
335 match self {
336 Self::Local {
337 committer_date_time, ..
338 }
339 | Self::Remote {
340 committer_date_time, ..
341 } => committer_date_time,
342 }
343 }
344}
345
346impl<'a> TryFrom<(git2::Branch<'a>, git2::BranchType)> for Branch {
356 type Error = rootcause::Report;
357
358 fn try_from((raw_branch, branch_type): (git2::Branch<'a>, git2::BranchType)) -> Result<Self, Self::Error> {
359 let branch_name = raw_branch
360 .name()?
361 .ok_or_else(|| report!("error invalid branch name UTF-8"))
362 .attach_with(|| format!("branch_name={:?}", raw_branch.name()))?;
363 let committer = raw_branch.get().peel_to_commit()?.committer().to_owned();
364 let committer_email = committer
365 .email()
366 .ok_or_else(|| report!("error invalid committer email UTF-8"))
367 .attach_with(|| format!("branch_name={branch_name:?}"))?
368 .to_string();
369 let committer_date_time = DateTime::from_timestamp(committer.when().seconds(), 0)
370 .ok_or_else(|| report!("error invalid commit timestamp"))
371 .attach_with(|| format!("seconds={}", committer.when().seconds()))?;
372
373 Ok(match branch_type {
374 git2::BranchType::Local => Self::Local {
375 name: branch_name.to_string(),
376 committer_email,
377 committer_date_time,
378 },
379 git2::BranchType::Remote => Self::Remote {
380 name: branch_name.to_string(),
381 committer_email,
382 committer_date_time,
383 },
384 })
385 }
386}
387
// Unit tests for the branch helpers: redundancy filtering, `Branch` conversion and accessors.
#[cfg(test)]
mod tests {
    use git2::Time;
    use rstest::rstest;

    use super::*;

    // Each case pairs an input list with the list expected after filtering: a remote branch is
    // dropped whenever a local branch has the same short name, whatever the remote prefix is.
    #[rstest]
    #[case::remote_same_short_name(
        vec![local("feature-x"), remote("origin/feature-x")],
        vec![local("feature-x")]
    )]
    #[case::no_redundant(
        vec![local("feature-x"), remote("origin/feature-y")],
        vec![local("feature-x"), remote("origin/feature-y")]
    )]
    #[case::multiple_mixed(
        vec![
            local("feature-x"),
            remote("origin/feature-x"),
            remote("origin/feature-y"),
            local("main"),
            remote("upstream/main")
        ],
        vec![local("feature-x"), remote("origin/feature-y"), local("main")]
    )]
    #[case::different_remote_prefix(
        vec![local("feature-x"), remote("upstream/feature-x")],
        vec![local("feature-x")]
    )]
    fn remove_redundant_remotes_cases(#[case] mut input: Vec<Branch>, #[case] expected: Vec<Branch>) {
        remove_redundant_remotes(&mut input);
        assert_eq!(input, expected);
    }

    // Converts a real `git2` branch created on top of the test repo's HEAD commit.
    #[test]
    fn branch_try_from_converts_local_branch_successfully() {
        // NOTE(review): presumably `init_test_repo` creates a repo with an initial commit made at
        // the given time by `test@example.com` -- confirm in `crate::tests`.
        let (_temp_dir, repo) = crate::tests::init_test_repo(Some(Time::new(42, 3)));

        let head_commit = repo.head().unwrap().peel_to_commit().unwrap();
        let branch = repo.branch("test-branch", &head_commit, false).unwrap();

        assert2::assert!(let Ok(result) = Branch::try_from((branch, git2::BranchType::Local)));

        // Only the seconds (42) of the commit time show up in the expected timestamp; the second
        // `Time::new` argument (3) does not.
        pretty_assertions::assert_eq!(
            result,
            Branch::Local {
                name: "test-branch".to_string(),
                committer_email: "test@example.com".to_string(),
                committer_date_time: DateTime::from_timestamp(42, 0).unwrap(),
            }
        );
    }

    // `name()` returns the stored name unchanged for both variants.
    #[rstest]
    #[case::local_variant(local("main"), "main")]
    #[case::remote_variant(remote("origin/feature"), "origin/feature")]
    fn branch_name_when_variant_returns_name(#[case] branch: Branch, #[case] expected: &str) {
        pretty_assertions::assert_eq!(branch.name(), expected);
    }

    // `name_no_origin()` drops a leading `origin/` and leaves other remote prefixes alone.
    #[rstest]
    #[case::local_no_origin(local("main"), "main")]
    #[case::remote_origin_prefix(remote("origin/main"), "main")]
    #[case::remote_other_prefix(remote("upstream/feature"), "upstream/feature")]
    fn branch_name_no_origin_when_name_returns_trimmed(#[case] branch: Branch, #[case] expected: &str) {
        pretty_assertions::assert_eq!(branch.name_no_origin(), expected);
    }

    // `committer_date_time()` returns the stored timestamp for both variants.
    #[rstest]
    #[case::local_variant(
        Branch::Local {
            name: "test".to_string(),
            committer_email: "a@b.com".to_string(),
            committer_date_time: DateTime::from_timestamp(123_456, 0).unwrap(),
        },
        DateTime::from_timestamp(123_456, 0).unwrap()
    )]
    #[case::remote_variant(
        Branch::Remote {
            name: "origin/test".to_string(),
            committer_email: "a@b.com".to_string(),
            committer_date_time: DateTime::from_timestamp(654_321, 0).unwrap(),
        },
        DateTime::from_timestamp(654_321, 0).unwrap()
    )]
    fn branch_committer_date_time_when_variant_returns_date_time(
        #[case] branch: Branch,
        #[case] expected: DateTime<Utc>,
    ) {
        pretty_assertions::assert_eq!(branch.committer_date_time(), &expected);
    }

    // Fixture: local branch with the given name, an empty committer email and the Unix epoch as
    // committer date.
    fn local(name: &str) -> Branch {
        Branch::Local {
            name: name.into(),
            committer_email: String::new(),
            committer_date_time: DateTime::from_timestamp(0, 0).unwrap(),
        }
    }

    // Fixture: remote branch with the given name, an empty committer email and the Unix epoch as
    // committer date.
    fn remote(name: &str) -> Branch {
        Branch::Remote {
            name: name.into(),
            committer_email: String::new(),
            committer_date_time: DateTime::from_timestamp(0, 0).unwrap(),
        }
    }
}