compliance: Make trait more flexible (#53914)

This will make this easier to use with the GitHub worker.

Release Notes:

- N/A
This commit is contained in:
Finn Evers 2026-04-14 23:25:49 +02:00 committed by GitHub
parent 652f1fa3b0
commit d367d3fbbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 165 additions and 60 deletions

View file

@ -6,7 +6,7 @@ use crate::{
git::{CommitDetails, CommitList},
github::{
CommitAuthor, GitHubClient, GitHubUser, GithubLogin, PullRequestComment, PullRequestData,
PullRequestReview, ReviewState,
PullRequestReview, Repository, ReviewState,
},
report::Report,
};
@ -118,7 +118,10 @@ impl<'a> Reporter<'a> {
return Err(ReviewFailure::NoPullRequestFound);
};
let pull_request = self.github_client.get_pull_request(pr_number).await?;
let pull_request = self
.github_client
.get_pull_request(&Repository::ZED, pr_number)
.await?;
if let Some(approval) = self
.check_approving_pull_request_review(&pull_request)
@ -152,7 +155,7 @@ impl<'a> Reporter<'a> {
if commit.co_authors().is_some()
&& let Some(commit_authors) = self
.github_client
.get_commit_authors(&[commit.sha()])
.get_commit_authors(&Repository::ZED, &[commit.sha()])
.await?
.get(commit.sha())
.and_then(|authors| authors.co_authors())
@ -162,7 +165,7 @@ impl<'a> Reporter<'a> {
if let Some(github_login) = co_author.user()
&& self
.github_client
.actor_has_repository_write_permission(github_login)
.check_repo_write_permission(&Repository::ZED, github_login)
.await?
{
org_co_authors.push(co_author.clone());
@ -186,7 +189,7 @@ impl<'a> Reporter<'a> {
if let Some(user) = pull_request.user
&& self
.github_client
.actor_has_repository_write_permission(&GithubLogin::new(user.login))
.check_repo_write_permission(&Repository::ZED, &GithubLogin::new(user.login))
.await?
.not()
{
@ -209,7 +212,7 @@ impl<'a> Reporter<'a> {
) -> Result<Option<ReviewSuccess>, ReviewFailure> {
let pr_reviews = self
.github_client
.get_pull_request_reviews(pull_request.number)
.get_pull_request_reviews(&Repository::ZED, pull_request.number)
.await?;
if !pr_reviews.is_empty() {
@ -229,9 +232,10 @@ impl<'a> Reporter<'a> {
.is_some_and(Self::contains_approving_pattern))
&& self
.github_client
.actor_has_repository_write_permission(&GithubLogin::new(
github_login.login.clone(),
))
.check_repo_write_permission(
&Repository::ZED,
&GithubLogin::new(github_login.login.clone()),
)
.await?
{
org_approving_reviews.push(review);
@ -253,7 +257,7 @@ impl<'a> Reporter<'a> {
) -> Result<Option<ReviewSuccess>, ReviewFailure> {
let other_comments = self
.github_client
.get_pull_request_comments(pull_request.number)
.get_pull_request_comments(&Repository::ZED, pull_request.number)
.await?;
if !other_comments.is_empty() {
@ -270,9 +274,10 @@ impl<'a> Reporter<'a> {
.is_some_and(Self::contains_approving_pattern)
&& self
.github_client
.actor_has_repository_write_permission(&GithubLogin::new(
comment.user.login.clone(),
))
.check_repo_write_permission(
&Repository::ZED,
&GithubLogin::new(comment.user.login.clone()),
)
.await?
{
org_approving_comments.push(comment);
@ -327,7 +332,7 @@ mod tests {
use crate::git::{CommitDetails, CommitList, CommitSha};
use crate::github::{
AuthorsForCommits, GitHubApiClient, GitHubClient, GitHubUser, GithubLogin,
PullRequestComment, PullRequestData, PullRequestReview, ReviewState,
PullRequestComment, PullRequestData, PullRequestReview, Repository, ReviewState,
};
use super::{Reporter, ReviewFailure, ReviewSuccess};
@ -342,12 +347,17 @@ mod tests {
#[async_trait::async_trait(?Send)]
impl GitHubApiClient for MockGitHubApi {
async fn get_pull_request(&self, _pr_number: u64) -> anyhow::Result<PullRequestData> {
async fn get_pull_request(
&self,
_repo: &Repository<'_>,
_pr_number: u64,
) -> anyhow::Result<PullRequestData> {
Ok(self.pull_request.clone())
}
async fn get_pull_request_reviews(
&self,
_repo: &Repository<'_>,
_pr_number: u64,
) -> anyhow::Result<Vec<PullRequestReview>> {
Ok(self.reviews.clone())
@ -355,6 +365,7 @@ mod tests {
async fn get_pull_request_comments(
&self,
_repo: &Repository<'_>,
_pr_number: u64,
) -> anyhow::Result<Vec<PullRequestComment>> {
Ok(self.comments.clone())
@ -362,23 +373,29 @@ mod tests {
async fn get_commit_authors(
&self,
_repo: &Repository<'_>,
_commit_shas: &[&CommitSha],
) -> anyhow::Result<AuthorsForCommits> {
serde_json::from_value(self.commit_authors_json.clone()).map_err(Into::into)
}
async fn check_org_membership(&self, login: &GithubLogin) -> anyhow::Result<bool> {
async fn check_repo_write_permission(
&self,
_repo: &Repository<'_>,
login: &GithubLogin,
) -> anyhow::Result<bool> {
Ok(self
.org_members
.iter()
.any(|member| member == login.as_str()))
}
async fn check_repo_write_permission(&self, _login: &GithubLogin) -> anyhow::Result<bool> {
Ok(false)
}
async fn add_label_to_issue(&self, _label: &str, _pr_number: u64) -> anyhow::Result<()> {
async fn add_label_to_issue(
&self,
_repo: &Repository<'_>,
_label: &str,
_pr_number: u64,
) -> anyhow::Result<()> {
Ok(())
}
}

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, rc::Rc};
use anyhow::Result;
use derive_more::Deref;
@ -141,22 +141,73 @@ impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
}
}
#[derive(Clone)]
pub struct Repository<'a> {
owner: Cow<'a, str>,
name: Cow<'a, str>,
}
impl<'a> Repository<'a> {
pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");
pub fn new(owner: &'a str, name: &'a str) -> Self {
Self {
owner: Cow::Borrowed(owner),
name: Cow::Borrowed(name),
}
}
pub fn owner(&self) -> &str {
&self.owner
}
pub fn name(&self) -> &str {
&self.name
}
}
impl Repository<'static> {
pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
Self {
owner: Cow::Borrowed(owner),
name: Cow::Borrowed(name),
}
}
}
#[async_trait::async_trait(?Send)]
pub trait GitHubApiClient {
async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
async fn actor_has_repository_write_permission(
async fn get_pull_request(
&self,
repo: &Repository<'_>,
pr_number: u64,
) -> Result<PullRequestData>;
async fn get_pull_request_reviews(
&self,
repo: &Repository<'_>,
pr_number: u64,
) -> Result<Vec<PullRequestReview>>;
async fn get_pull_request_comments(
&self,
repo: &Repository<'_>,
pr_number: u64,
) -> Result<Vec<PullRequestComment>>;
async fn get_commit_authors(
&self,
repo: &Repository<'_>,
commit_shas: &[&CommitSha],
) -> Result<AuthorsForCommits>;
async fn check_repo_write_permission(
&self,
repo: &Repository<'_>,
login: &GithubLogin,
) -> anyhow::Result<bool> {
Ok(self.check_org_membership(login).await?
|| self.check_repo_write_permission(login).await?)
}
async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>;
) -> Result<bool>;
async fn add_label_to_issue(
&self,
repo: &Repository<'_>,
label: &str,
issue_number: u64,
) -> Result<()>;
}
#[derive(Deref)]
@ -170,8 +221,8 @@ impl GitHubClient {
}
#[cfg(feature = "octo-client")]
pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
let client = OctocrabClient::new(app_id, app_private_key).await?;
pub async fn for_app_in_repo(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
let client = OctocrabClient::new(app_id, app_private_key, org).await?;
Ok(Self::new(Rc::new(client)))
}
}
@ -276,7 +327,10 @@ mod octo_client {
use serde::de::DeserializeOwned;
use tokio::pin;
use crate::{git::CommitSha, github::graph_ql};
use crate::{
git::CommitSha,
github::{Repository, graph_ql},
};
use super::{
AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
@ -284,15 +338,13 @@ mod octo_client {
};
const PAGE_SIZE: u8 = 100;
const ORG: &str = "zed-industries";
const REPO: &str = "zed";
pub struct OctocrabClient {
client: Octocrab,
}
impl OctocrabClient {
pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
let octocrab = Octocrab::builder()
.cache(InMemoryCache::new())
.app(
@ -311,7 +363,7 @@ mod octo_client {
let installation_id = installations
.into_iter()
.find(|installation| installation.account.login == ORG)
.find(|installation| installation.account.login == org)
.context("Could not find Zed repository in installations")?
.id;
@ -355,8 +407,16 @@ mod octo_client {
#[async_trait::async_trait(?Send)]
impl GitHubApiClient for OctocrabClient {
async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
async fn get_pull_request(
&self,
repo: &Repository<'_>,
pr_number: u64,
) -> Result<PullRequestData> {
let pr = self
.client
.pulls(repo.owner.as_ref(), repo.name.as_ref())
.get(pr_number)
.await?;
Ok(PullRequestData {
number: pr.number,
user: pr.user.map(|user| GitHubUser { login: user.login }),
@ -367,10 +427,14 @@ mod octo_client {
})
}
async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
async fn get_pull_request_reviews(
&self,
repo: &Repository<'_>,
pr_number: u64,
) -> Result<Vec<PullRequestReview>> {
let page = self
.client
.pulls(ORG, REPO)
.pulls(repo.owner.as_ref(), repo.name.as_ref())
.list_reviews(pr_number)
.per_page(PAGE_SIZE)
.send()
@ -393,11 +457,12 @@ mod octo_client {
async fn get_pull_request_comments(
&self,
repo: &Repository<'_>,
pr_number: u64,
) -> Result<Vec<PullRequestComment>> {
let page = self
.client
.issues(ORG, REPO)
.issues(repo.owner.as_ref(), repo.name.as_ref())
.list_comments(pr_number)
.per_page(PAGE_SIZE)
.send()
@ -418,19 +483,29 @@ mod octo_client {
async fn get_commit_authors(
&self,
repo: &Repository<'_>,
commit_shas: &[&CommitSha],
) -> Result<AuthorsForCommits> {
let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied());
let query = graph_ql::build_co_authors_query(
repo.owner.as_ref(),
repo.name.as_ref(),
commit_shas.iter().copied(),
);
let query = serde_json::json!({ "query": query });
self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
.await
.map(|response| response.repository)
}
async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
async fn check_repo_write_permission(
&self,
repo: &Repository<'_>,
login: &GithubLogin,
) -> Result<bool> {
// Check org membership first - we save ourselves a few request that way
let page = self
.client
.orgs(ORG)
.orgs(repo.owner.as_ref())
.list_members()
.per_page(PAGE_SIZE)
.send()
@ -438,12 +513,13 @@ mod octo_client {
let members = self.get_all(page).await?;
Ok(members
if members
.into_iter()
.any(|member| member.login == login.as_str()))
}
.any(|member| member.login == login.as_str())
{
return Ok(true);
}
async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
// TODO: octocrab fails to deserialize the permission response and
// does not adhere to the scheme laid out at
// https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
@ -466,7 +542,9 @@ mod octo_client {
self.client
.get::<RepositoryPermissions, _, _>(
format!(
"/repos/{ORG}/{REPO}/collaborators/{user}/permission",
"/repos/{owner}/{repo}/collaborators/{user}/permission",
owner = repo.owner.as_ref(),
repo = repo.name.as_ref(),
user = login.as_str()
),
None::<&()>,
@ -481,9 +559,14 @@ mod octo_client {
.map_err(Into::into)
}
async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> {
async fn add_label_to_issue(
&self,
repo: &Repository<'_>,
label: &str,
issue_number: u64,
) -> Result<()> {
self.client
.issues(ORG, REPO)
.issues(repo.owner.as_ref(), repo.name.as_ref())
.add_labels(issue_number, &[label.to_owned()])
.await
.map(|_| ())

View file

@ -6,7 +6,7 @@ use clap::Parser;
use compliance::{
checks::Reporter,
git::{CommitsFromVersionToVersion, GetVersionTags, GitCommand, VersionTag},
github::GitHubClient,
github::{GitHubClient, Repository},
report::ReportReviewSummary,
};
@ -69,9 +69,10 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
println!("Checking commit range {range}, {} total", commits.len());
let client = GitHubClient::for_app(
let client = GitHubClient::for_app_in_repo(
app_id.parse().context("Failed to parse app ID as int")?,
key.as_ref(),
Repository::ZED.owner(),
)
.await?;
@ -93,7 +94,7 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
for report in report.errors() {
if let Some(pr_number) = report.commit.pr_number()
&& let Ok(pull_request) = client.get_pull_request(pr_number).await
&& let Ok(pull_request) = client.get_pull_request(&Repository::ZED, pr_number).await
&& pull_request.labels.is_none_or(|labels| {
labels
.iter()
@ -103,7 +104,11 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
println!("Adding review label to PR {}...", pr_number);
client
.add_label_to_issue(compliance::github::PR_REVIEW_LABEL, pr_number)
.add_label_to_issue(
&Repository::ZED,
compliance::github::PR_REVIEW_LABEL,
pr_number,
)
.await?;
}
}