1use std::collections::HashSet;
7use std::path::Path;
8use std::process::Command;
9
10use chrono::DateTime;
11use chrono::Utc;
12use git2::Cred;
13use git2::RemoteCallbacks;
14use git2::Repository;
15use rootcause::bail;
16use rootcause::prelude::ResultExt;
17use rootcause::report;
18use ytil_cmd::CmdError;
19use ytil_cmd::CmdExt as _;
20
21pub fn get_default() -> rootcause::Result<String> {
31 let repo_path = Path::new(".");
32 let repo = crate::repo::discover(repo_path)
33 .context("error getting repo for getting default branch")
34 .attach_with(|| format!("path={}", repo_path.display()))?;
35
36 let default_remote_ref = crate::remote::get_default(&repo)?;
37
38 let Some(target) = default_remote_ref.symbolic_target() else {
39 bail!("error missing default branch");
40 };
41
42 Ok(target
43 .split('/')
44 .next_back()
45 .ok_or_else(|| report!("error extracting default branch_name from target"))
46 .attach_with(|| format!("target={target:?}"))?
47 .to_string())
48}
49
50pub fn get_current() -> rootcause::Result<String> {
55 let repo_path = Path::new(".");
56 let repo = crate::repo::discover(repo_path)
57 .context("error getting repo for getting current branch")
58 .attach_with(|| format!("path={}", repo_path.display()))?;
59
60 if repo
61 .head_detached()
62 .context("error checking if head is detached")
63 .attach_with(|| format!("path={}", repo_path.display()))?
64 {
65 Err(report!("error head is detached")).attach_with(|| format!("path={}", repo_path.display()))?;
66 }
67
68 repo.head()
69 .context("error getting head")
70 .attach_with(|| format!("path={}", repo_path.display()))?
71 .shorthand()
72 .map(str::to_string)
73 .ok_or_else(|| report!("error invalid branch shorthand UTF-8"))
74 .attach_with(|| format!("path={}", repo_path.display()))
75}
76
77pub fn create_from_default_branch(branch_name: &str, repo: Option<&Repository>) -> rootcause::Result<()> {
82 let repo = if let Some(repo) = repo {
83 repo
84 } else {
85 let path = Path::new(".");
86 &crate::repo::discover(path)
87 .context("error getting repo for creating new branch")
88 .attach_with(|| format!("path={} branch={branch_name:?}", path.display()))?
89 };
90
91 let commit = repo
92 .head()
93 .context("error getting head")
94 .attach_with(|| format!("branch_name={branch_name:?}"))?
95 .peel_to_commit()
96 .context("error peeling head to commit")
97 .attach_with(|| format!("branch_name={branch_name:?}"))?;
98
99 repo.branch(branch_name, &commit, false)
100 .context("error creating branch")
101 .attach_with(|| format!("branch_name={branch_name:?}"))?;
102
103 Ok(())
104}
105
106pub fn push(branch_name: &str, repo: Option<&Repository>) -> rootcause::Result<()> {
117 let repo = if let Some(repo) = repo {
118 repo
119 } else {
120 let path = Path::new(".");
121 &crate::repo::discover(path)
122 .context("error getting repo for pushing new branch")
123 .attach_with(|| format!("path={} branch={branch_name:?}", path.display()))?
124 };
125
126 let default_remote = crate::remote::get_default(repo)?;
127
128 let default_remote_name = default_remote
129 .name()
130 .ok_or_else(|| report!("error missing name of default remote"))?
131 .trim_start_matches("refs/remotes/")
132 .trim_end_matches("/HEAD");
133
134 let mut remote = repo.find_remote(default_remote_name)?;
135
136 let mut callbacks = RemoteCallbacks::new();
137 callbacks.credentials(|_url, username_from_url, _allowed_types| {
138 Cred::ssh_key_from_agent(username_from_url.unwrap_or("git"))
139 });
140
141 let mut push_opts = git2::PushOptions::new();
142 push_opts.remote_callbacks(callbacks);
143
144 let branch_refspec = format!("refs/heads/{branch_name}");
145 remote
146 .push(&[&branch_refspec], Some(&mut push_opts))
147 .context("error pushing branch to remote")
148 .attach_with(|| format!("branch_refspec={branch_refspec:?} default_remote_name={default_remote_name:?}"))?;
149
150 Ok(())
151}
152
153pub fn get_previous(repo: &Repository) -> Option<String> {
161 let reflog = repo.reflog("HEAD").ok()?;
162 reflog.iter().find_map(|entry| {
163 let msg = entry.message()?;
164 let rest = msg
165 .strip_prefix("checkout: moving from ")
166 .or_else(|| msg.strip_prefix("switch: moving from "))?;
167 Some(rest.rsplit_once(" to ")?.0.to_string())
168 })
169}
170
171pub fn get_user_email(repo: &Repository) -> rootcause::Result<Option<String>> {
178 let config = repo.config().context("error opening repo config")?;
179 match config.get_string("user.email") {
180 Ok(email) => Ok(Some(email)),
181 Err(err) if err.code() == git2::ErrorCode::NotFound => Ok(None),
182 Err(err) => Err(report!("error reading user.email from repo config").attach(err.to_string())),
183 }
184}
185
186pub fn switch(branch_name: &str) -> Result<(), Box<CmdError>> {
191 Command::new("git")
192 .args(["switch", branch_name, "--guess"])
193 .exec()
194 .map_err(Box::new)?;
195 Ok(())
196}
197
198pub fn get_all(repo: &Repository) -> rootcause::Result<Vec<Branch>> {
209 fetch_with_repo(repo, &[]).context("error fetching branches")?;
210
211 let mut out = vec![];
212 for branch_res in repo.branches(None).context("error enumerating branches")? {
213 let branch = branch_res.context("error getting branch result")?;
214 out.push(Branch::try_from(branch).context("error creating branch from result")?);
215 }
216
217 out.sort_unstable_by(|a, b| b.committer_date_time().cmp(a.committer_date_time()));
218
219 Ok(out)
220}
221
222pub fn get_all_no_redundant(repo: &Repository) -> rootcause::Result<Vec<Branch>> {
232 let mut branches = get_all(repo)?;
233 remove_redundant_remotes(&mut branches);
234 Ok(branches)
235}
236
237pub fn fetch(branches: &[&str]) -> rootcause::Result<()> {
247 let repo_path = Path::new(".");
248 let repo = crate::repo::discover(repo_path)
249 .context("error getting repo for fetching branches")
250 .attach_with(|| format!("path={} branches={branches:?}", repo_path.display()))?;
251 fetch_with_repo(&repo, branches)
252}
253
254fn fetch_with_repo(repo: &Repository, branches: &[&str]) -> rootcause::Result<()> {
256 let mut callbacks = RemoteCallbacks::new();
257 callbacks.credentials(|_url, username_from_url, _allowed_types| {
258 Cred::ssh_key_from_agent(username_from_url.unwrap_or("git"))
259 });
260
261 let mut fetch_opts = git2::FetchOptions::new();
262 fetch_opts.remote_callbacks(callbacks);
263
264 repo.find_remote("origin")
265 .context("error finding origin remote")?
266 .fetch(branches, Some(&mut fetch_opts), None)
267 .context("error performing fetch from origin remote")
268 .attach_with(|| format!("branches={branches:?}"))?;
269
270 Ok(())
271}
272
273pub fn remove_redundant_remotes(branches: &mut Vec<Branch>) {
282 let local_names: HashSet<String> = branches
285 .iter()
286 .filter_map(|b| {
287 if let Branch::Local { name, .. } = b {
288 Some(name.clone())
289 } else {
290 None
291 }
292 })
293 .collect();
294
295 branches.retain(|b| match b {
296 Branch::Local { .. } => true,
297 Branch::Remote { name, .. } => {
298 let short = name.split_once('/').map_or(name.as_str(), |(_, rest)| rest);
299 !local_names.contains(short)
300 }
301 });
302}
303
/// A git branch together with committer metadata of the commit it points to.
///
/// Both variants carry the same fields; the variant records whether the branch
/// is local or remote-tracking.
#[derive(Clone, Debug)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Eq, PartialEq))]
pub enum Branch {
    /// A local branch.
    Local {
        /// Branch name (e.g. `main`).
        name: String,
        /// Email of the committer of the commit the branch points to.
        committer_email: String,
        /// Committer timestamp of the commit the branch points to.
        committer_date_time: DateTime<Utc>,
    },
    /// A remote-tracking branch.
    Remote {
        /// Branch name including the remote prefix (e.g. `origin/feature`).
        name: String,
        /// Email of the committer of the commit the branch points to.
        committer_email: String,
        /// Committer timestamp of the commit the branch points to.
        committer_date_time: DateTime<Utc>,
    },
}
327
328impl Branch {
329 pub fn name(&self) -> &str {
331 match self {
332 Self::Local { name, .. } | Self::Remote { name, .. } => name,
333 }
334 }
335
336 pub fn name_no_origin(&self) -> &str {
338 self.name().trim_start_matches("origin/")
339 }
340
341 pub fn committer_email(&self) -> &str {
343 match self {
344 Self::Local { committer_email, .. } | Self::Remote { committer_email, .. } => committer_email,
345 }
346 }
347
348 pub const fn committer_date_time(&self) -> &DateTime<Utc> {
350 match self {
351 Self::Local {
352 committer_date_time, ..
353 }
354 | Self::Remote {
355 committer_date_time, ..
356 } => committer_date_time,
357 }
358 }
359}
360
361impl<'a> TryFrom<(git2::Branch<'a>, git2::BranchType)> for Branch {
371 type Error = rootcause::Report;
372
373 fn try_from((raw_branch, branch_type): (git2::Branch<'a>, git2::BranchType)) -> Result<Self, Self::Error> {
374 let branch_name = raw_branch
375 .name()?
376 .ok_or_else(|| report!("error invalid branch name UTF-8"))
377 .attach_with(|| format!("branch_name={:?}", raw_branch.name()))?;
378 let committer = raw_branch.get().peel_to_commit()?.committer().to_owned();
379 let committer_email = committer
380 .email()
381 .ok_or_else(|| report!("error invalid committer email UTF-8"))
382 .attach_with(|| format!("branch_name={branch_name:?}"))?
383 .to_string();
384 let committer_date_time = DateTime::from_timestamp(committer.when().seconds(), 0)
385 .ok_or_else(|| report!("error invalid commit timestamp"))
386 .attach_with(|| format!("seconds={}", committer.when().seconds()))?;
387
388 Ok(match branch_type {
389 git2::BranchType::Local => Self::Local {
390 name: branch_name.to_string(),
391 committer_email,
392 committer_date_time,
393 },
394 git2::BranchType::Remote => Self::Remote {
395 name: branch_name.to_string(),
396 committer_email,
397 committer_date_time,
398 },
399 })
400 }
401}
402
#[cfg(test)]
mod tests {
    use git2::Time;
    use rstest::rstest;

    use super::*;

    // A remote branch is dropped when its name, minus the first `/`-separated
    // segment, equals the name of a local branch — whichever remote it belongs to.
    // Local branches and unmatched remote branches keep their relative order.
    #[rstest]
    #[case::remote_same_short_name(
        vec![local("feature-x"), remote("origin/feature-x")],
        vec![local("feature-x")]
    )]
    #[case::no_redundant(
        vec![local("feature-x"), remote("origin/feature-y")],
        vec![local("feature-x"), remote("origin/feature-y")]
    )]
    #[case::multiple_mixed(
        vec![
            local("feature-x"),
            remote("origin/feature-x"),
            remote("origin/feature-y"),
            local("main"),
            remote("upstream/main")
        ],
        vec![local("feature-x"), remote("origin/feature-y"), local("main")]
    )]
    #[case::different_remote_prefix(
        vec![local("feature-x"), remote("upstream/feature-x")],
        vec![local("feature-x")]
    )]
    fn remove_redundant_remotes_cases(#[case] mut input: Vec<Branch>, #[case] expected: Vec<Branch>) {
        remove_redundant_remotes(&mut input);
        assert_eq!(input, expected);
    }

    // Converting a freshly created local branch yields a `Branch::Local` carrying
    // the committer email and the committer time truncated to whole seconds.
    #[test]
    fn test_branch_try_from_converts_local_branch_successfully() {
        // NOTE(review): presumably `init_test_repo` creates an initial commit with the
        // given time and the `test@example.com` committer — confirm in `crate::tests`.
        let (_temp_dir, repo) = crate::tests::init_test_repo(Some(Time::new(42, 3)));

        let head_commit = repo.head().unwrap().peel_to_commit().unwrap();
        let branch = repo.branch("test-branch", &head_commit, false).unwrap();

        assert2::assert!(let Ok(result) = Branch::try_from((branch, git2::BranchType::Local)));

        pretty_assertions::assert_eq!(
            result,
            Branch::Local {
                name: "test-branch".to_string(),
                committer_email: "test@example.com".to_string(),
                committer_date_time: DateTime::from_timestamp(42, 0).unwrap(),
            }
        );
    }

    // `name` returns the stored name unchanged for both variants.
    #[rstest]
    #[case::local_variant(local("main"), "main")]
    #[case::remote_variant(remote("origin/feature"), "origin/feature")]
    fn test_branch_name_when_variant_returns_name(#[case] branch: Branch, #[case] expected: &str) {
        pretty_assertions::assert_eq!(branch.name(), expected);
    }

    // `name_no_origin` removes only an `origin/` prefix; other remote prefixes stay.
    #[rstest]
    #[case::local_no_origin(local("main"), "main")]
    #[case::remote_origin_prefix(remote("origin/main"), "main")]
    #[case::remote_other_prefix(remote("upstream/feature"), "upstream/feature")]
    fn test_branch_name_no_origin_when_name_returns_trimmed(#[case] branch: Branch, #[case] expected: &str) {
        pretty_assertions::assert_eq!(branch.name_no_origin(), expected);
    }

    // `committer_date_time` returns the stored timestamp for both variants.
    #[rstest]
    #[case::local_variant(
        Branch::Local {
            name: "test".to_string(),
            committer_email: "a@b.com".to_string(),
            committer_date_time: DateTime::from_timestamp(123_456, 0).unwrap(),
        },
        DateTime::from_timestamp(123_456, 0).unwrap()
    )]
    #[case::remote_variant(
        Branch::Remote {
            name: "origin/test".to_string(),
            committer_email: "a@b.com".to_string(),
            committer_date_time: DateTime::from_timestamp(654_321, 0).unwrap(),
        },
        DateTime::from_timestamp(654_321, 0).unwrap()
    )]
    fn branch_committer_date_time_when_variant_returns_date_time(
        #[case] branch: Branch,
        #[case] expected: DateTime<Utc>,
    ) {
        pretty_assertions::assert_eq!(branch.committer_date_time(), &expected);
    }

    // Test helper: local branch with the given name, an empty committer email and
    // the Unix epoch as committer time.
    fn local(name: &str) -> Branch {
        Branch::Local {
            name: name.into(),
            committer_email: String::new(),
            committer_date_time: DateTime::from_timestamp(0, 0).unwrap(),
        }
    }

    // Test helper: remote branch with the given name, an empty committer email and
    // the Unix epoch as committer time.
    fn remote(name: &str) -> Branch {
        Branch::Remote {
            name: name.into(),
            committer_email: String::new(),
            committer_date_time: DateTime::from_timestamp(0, 0).unwrap(),
        }
    }
}