diff --git a/src/main.rs b/src/main.rs index 76bfae3..1b654f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,7 @@ -use std::fs; +use std::{ + fs::{self, File}, + io::{self, BufRead}, +}; use clap::{Args, Parser, Subcommand}; use color_eyre::eyre::Result; @@ -16,6 +19,17 @@ struct Cli { enum CliCommands { Track(Track), GetDiff(GetDiff), + UpdatePrs(UpdatePRs), +} + +#[derive(Args, Debug, Default)] +struct UpdatePRs { + #[arg(long, default_value_t = ("./pr.txt".to_string()))] + pr_file: String, + #[arg(long, default_value_t = ("./patches/PR".to_string()))] + path: String, + #[arg(long, default_value_t = ("nixos-unstable".to_string()))] + branch: String, } #[derive(Args, Debug, Default)] @@ -41,9 +55,10 @@ const BRANCHES: &[&str] = &[ struct PR { title: String, state: String, - merged: bool, + merged: Option, merged_at: Option, - merge_commit_sha: String, + // merge conflicts don't have a merge_commit_sha + merge_commit_sha: Option, mergeable: bool, } @@ -64,16 +79,18 @@ async fn get_pr(pr: &str, client: &Client) -> Result { Ok(v) } async fn contains(branch: &str, pr: &PR, client: &Client) -> Result { - let req = client - .get(format!( - "https://api.github.com/repos/nixos/nixpkgs/compare/{}...{}", - branch, pr.merge_commit_sha - )) - .send(); - let text = &req.await?.text().await?; - let v: Compare = serde_json::from_str(text)?; - if v.status == "identical" || v.status == "behind" { - return Ok(true); + if let Some(sha) = &pr.merge_commit_sha { + let req = client + .get(format!( + "https://api.github.com/repos/nixos/nixpkgs/compare/{}...{}", + branch, sha + )) + .send(); + let text = &req.await?.text().await?; + let v: Compare = serde_json::from_str(text)?; + if v.status == "identical" || v.status == "behind" { + return Ok(true); + } } Ok(false) } @@ -101,6 +118,25 @@ async fn main() -> Result<()> { CliCommands::GetDiff(opts) => { get_diff(&opts.pr, &opts.path, &client).await?; } + CliCommands::UpdatePrs(opts) => { + let mut trackable: Vec = Vec::new(); + { + let file = File::open(opts.pr_file.clone())?; + for l in io::BufReader::new(file).lines() { + let l = l?; + let pr = get_pr(&l, &client).await?; + println!("Fetching diff for PR #{}: {}", l, pr.title); + if contains(&opts.branch, &pr, &client).await? { + println!("PR has reached {}, removing diff", opts.branch); + fs::remove_file(format!("{}/{}.diff", opts.path, l))?; + } else { + get_diff(&l, &opts.path, &client).await?; + trackable.push(l.clone()); + } + } + } + fs::write(&opts.pr_file, trackable.join("\n"))?; + } }; Ok(()) }