Fix issues processing captured edit prediction examples (#46773)

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
Max Brunsfeld 2026-01-14 14:32:42 -08:00 committed by GitHub
parent 9c5fc6ecbd
commit 445c95aa3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 289 additions and 17 deletions

5
Cargo.lock generated
View file

@ -5292,6 +5292,7 @@ dependencies = [
"dirs 4.0.0",
"edit_prediction",
"extension",
"flate2",
"fs",
"futures 0.3.31",
"gpui",
@ -6252,9 +6253,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.1.4"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369"
dependencies = [
"crc32fast",
"miniz_oxide",

View file

@ -13,6 +13,7 @@ use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::
use text::{BufferSnapshot as TextBufferSnapshot, Point};
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
pub fn capture_example(
project: Entity<Project>,
@ -232,10 +233,15 @@ fn generate_timestamp_name() -> String {
}
pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
let default_rate = if cx.is_staff() {
DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
} else {
DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
};
let capture_rate = language::language_settings::all_language_settings(None, cx)
.edit_predictions
.example_capture_rate
.unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
.unwrap_or(default_rate);
cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
&& rand::random::<u16>() % 10_000 < capture_rate
}

View file

@ -214,6 +214,54 @@ pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
Ok(result)
}
/// Removes a leading `prefix/` component from the `---` / `+++` path headers
/// of a unified diff.
///
/// Returns `diff` borrowed and untouched when `prefix` is empty or when no
/// path header starts with `prefix/`. Otherwise returns an owned copy in which
/// every path header is re-emitted as `--- a/<path>` or `+++ b/<path>` (with
/// the prefix stripped where it was present) and every other line is copied
/// through. Each output line is terminated with `\n`.
pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
    const DEV_NULL: &str = "/dev/null";

    if prefix.is_empty() {
        return Cow::Borrowed(diff);
    }

    let prefix_with_slash = format!("{}/", prefix);

    // Cheap pre-scan so the common case (nothing to strip) never allocates
    // the output buffer.
    let needs_rewrite = diff.lines().any(|line| match DiffLine::parse(line) {
        DiffLine::OldPath { path } | DiffLine::NewPath { path } => {
            path.starts_with(&prefix_with_slash)
        }
        _ => false,
    });
    if !needs_rewrite {
        return Cow::Borrowed(diff);
    }

    let mut result = String::with_capacity(diff.len());
    for line in diff.lines() {
        let (header, path) = match DiffLine::parse(line) {
            DiffLine::OldPath { path } => ("--- a/", path),
            DiffLine::NewPath { path } => ("+++ b/", path),
            _ => {
                result.push_str(line);
                result.push('\n');
                continue;
            }
        };

        let path: &str = path.as_ref();
        if path == DEV_NULL {
            // NOTE(review): assumes `DiffLine::parse` reports the path of a
            // `--- /dev/null` / `+++ /dev/null` header verbatim - confirm.
            // Such a header marks an added or deleted file and must not be
            // rewritten to `a//dev/null`, so the line is kept unchanged.
            result.push_str(line);
        } else {
            // Append the pieces directly instead of building a temporary
            // `String` with `format!` for every header line.
            result.push_str(header);
            result.push_str(path.strip_prefix(prefix_with_slash.as_str()).unwrap_or(path));
        }
        result.push('\n');
    }

    Cow::Owned(result)
}
/// Strip unnecessary git metadata lines from a diff, keeping only the lines
/// needed for patch application: path headers (--- and +++), hunk headers (@@),
/// and content lines (+, -, space).

View file

@ -57,6 +57,7 @@ wasmtime.workspace = true
zeta_prompt.workspace = true
rand.workspace = true
similar = "2.7.0"
flate2 = "1.1.8"
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.

View file

@ -5,7 +5,7 @@ use crate::{
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::{Context as _, Result};
use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries};
use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix};
use edit_prediction::{
EditPredictionStore, cursor_excerpt::editable_and_context_ranges_for_cursor_position, zeta2,
};
@ -111,8 +111,16 @@ async fn cursor_position(
}
let cursor_path_str = example.spec.cursor_path.to_string_lossy();
// Also try cursor path with first component stripped - old examples may have
// paths like "zed/crates/foo.rs" instead of "crates/foo.rs".
let cursor_path_without_prefix: PathBuf =
example.spec.cursor_path.components().skip(1).collect();
let cursor_path_without_prefix_str = cursor_path_without_prefix.to_string_lossy();
// We try open_buffers first because the file might be new and not saved to disk
let cursor_buffer = if let Some(buffer) = open_buffers.get(&cursor_path_str) {
let cursor_buffer = if let Some(buffer) = open_buffers.get(cursor_path_str.as_ref()) {
buffer.clone()
} else if let Some(buffer) = open_buffers.get(cursor_path_without_prefix_str.as_ref()) {
buffer.clone()
} else {
// Since the worktree scanner is disabled, manually refresh entries for the cursor path.
@ -122,7 +130,9 @@ async fn cursor_position(
let cursor_path = project
.read_with(cx, |project, cx| {
project.find_project_path(&example.spec.cursor_path, cx)
project
.find_project_path(&example.spec.cursor_path, cx)
.or_else(|| project.find_project_path(&cursor_path_without_prefix, cx))
})
.with_context(|| {
format!(
@ -282,9 +292,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !example.spec.uncommitted_diff.is_empty() {
step_progress.set_substatus("applying diff");
// Old examples prefixed the paths in the uncommitted diff with the repository name, so strip that prefix before applying.
let uncommitted_diff =
strip_diff_path_prefix(&example.spec.uncommitted_diff, &repo_name.name);
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
@ -292,9 +306,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
.spawn()?;
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
stdin
.write_all(example.spec.uncommitted_diff.as_bytes())
.await?;
stdin.write_all(uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);

View file

@ -21,7 +21,7 @@ use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;
use reqwest_client::ReqwestClient;
@ -279,6 +279,7 @@ async fn load_examples(
http_client: Arc<dyn http_client::HttpClient>,
args: &EpArgs,
output_path: Option<&PathBuf>,
background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
let mut captured_after_timestamps = Vec::new();
let mut file_inputs = Vec::new();
@ -312,6 +313,7 @@ async fn load_examples(
http_client,
&captured_after_timestamps,
max_rows_per_timestamp,
background_executor,
)
.await?;
examples.append(&mut captured_examples);
@ -465,8 +467,13 @@ fn main() {
cx.spawn(async move |cx| {
let result = async {
let mut examples =
load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
let mut examples = load_examples(
app_state.client.http_client(),
&args,
output.as_ref(),
cx.background_executor().clone(),
)
.await?;
match &command {
Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {

View file

@ -1,9 +1,13 @@
use anyhow::{Context as _, Result};
use flate2::read::GzDecoder;
use gpui::BackgroundExecutor;
use http_client::{AsyncBody, HttpClient, Method, Request};
use indoc::indoc;
use serde::Deserialize;
use serde_json::{Value as JsonValue, json};
use std::io::Read;
use std::sync::Arc;
use std::time::Duration;
use crate::{
example::Example,
@ -12,9 +16,12 @@ use crate::{
use edit_prediction::example_spec::ExampleSpec;
/// Value of the response `code` field that `examples_from_response` accepts as success.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Value of the response `code` field that `run_sql_with_polling` treats as
/// "statement still executing" and keeps polling on.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
// NOTE(review): presumably the name of the telemetry event the query filters on - usage is outside this view.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
// NOTE(review): presumably the statement timeout sent with the SQL request, in seconds - confirm at the use site.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay between two consecutive status polls in `run_sql_with_polling`.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of status polls before `run_sql_with_polling` gives up.
const MAX_POLL_ATTEMPTS: usize = 120;
/// Parse an input token of the form `captured-after:{timestamp}`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@ -25,6 +32,7 @@ pub async fn fetch_captured_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
max_rows_per_timestamp: usize,
background_executor: BackgroundExecutor,
) -> Result<Vec<Example>> {
if after_timestamps.is_empty() {
return Ok(Vec::new());
@ -70,13 +78,60 @@ pub async fn fetch_captured_examples_after(
}
});
let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
let response = run_sql_with_polling(
http_client.clone(),
&base_url,
&token,
&request,
&step_progress,
background_executor.clone(),
)
.await?;
step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
let total_rows = response
.result_set_meta_data
.as_ref()
.and_then(|m| m.num_rows)
.unwrap_or(response.data.len() as i64);
let num_partitions = response
.result_set_meta_data
.as_ref()
.map(|m| m.partition_info.len())
.unwrap_or(1)
.max(1);
step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
step_progress.set_substatus("parsing");
all_examples.extend(examples_from_response(&response)?);
if num_partitions > 1 {
let statement_handle = response
.statement_handle
.as_ref()
.context("response has multiple partitions but no statementHandle")?;
for partition in 1..num_partitions {
step_progress.set_substatus(format!(
"fetching partition {}/{}",
partition + 1,
num_partitions
));
let partition_response = fetch_partition(
http_client.clone(),
&base_url,
&token,
statement_handle,
partition,
)
.await?;
all_examples.extend(examples_from_response(&partition_response)?);
}
}
step_progress.set_substatus("done");
}
@ -84,6 +139,7 @@ pub async fn fetch_captured_examples_after(
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
#[serde(default)]
data: Vec<Vec<JsonValue>>,
@ -93,14 +149,25 @@ struct SnowflakeStatementResponse {
code: Option<String>,
#[serde(default)]
message: Option<String>,
#[serde(default)]
statement_handle: Option<String>,
}
/// Metadata object of a statement response, deserialized from camelCase JSON keys.
///
/// Every field falls back to its default when the key is absent.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeResultSetMetaData {
    /// Column descriptions of the result set. The explicit `rename` is the
    /// same key that `rename_all = "camelCase"` would produce.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Total row count, when reported. `fetch_captured_examples_after` uses it
    /// for the progress display.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition. `fetch_captured_examples_after` uses
    /// the length to decide how many additional partitions to fetch.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
/// Placeholder for one entry of `partitionInfo`.
///
/// No fields are declared: the visible code only counts these entries, it
/// never reads their contents.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
#[serde(default)]
@ -109,7 +176,7 @@ struct SnowflakeColumnMeta {
fn examples_from_response(
response: &SnowflakeStatementResponse,
) -> Result<impl Iterator<Item = Example>> {
) -> Result<impl Iterator<Item = Example> + '_> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
@ -169,6 +236,136 @@ fn examples_from_response(
Ok(iter)
}
/// Runs `request` through `run_sql` and, if the statement is reported as still
/// executing, polls partition 0 of the statement until it finishes.
///
/// A poll is issued every `POLL_INTERVAL`, at most `MAX_POLL_ATTEMPTS` times,
/// and the attempt number is reported through `step_progress`.
///
/// # Errors
///
/// Fails if any request fails, if an in-progress response carries no
/// statement handle, or if the statement is still in progress after the last
/// poll attempt.
async fn run_sql_with_polling(
    http_client: Arc<dyn HttpClient>,
    base_url: &str,
    token: &str,
    request: &serde_json::Value,
    step_progress: &crate::progress::StepProgress,
    background_executor: BackgroundExecutor,
) -> Result<SnowflakeStatementResponse> {
    let still_running = |candidate: &SnowflakeStatementResponse| {
        candidate.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE)
    };

    let initial = run_sql(http_client.clone(), base_url, token, request).await?;
    if !still_running(&initial) {
        return Ok(initial);
    }

    let handle = initial
        .statement_handle
        .as_deref()
        .context("async query response missing statementHandle")?;

    for attempt in 1..=MAX_POLL_ATTEMPTS {
        step_progress.set_substatus(format!("polling ({attempt})"));
        background_executor.timer(POLL_INTERVAL).await;

        let polled = fetch_partition(http_client.clone(), base_url, token, handle, 0).await?;
        if !still_running(&polled) {
            return Ok(polled);
        }
    }

    // Every poll attempt came back "in progress".
    anyhow::bail!(
        "query still running after {} poll attempts ({} seconds)",
        MAX_POLL_ATTEMPTS,
        MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
    );
}
/// Fetches one partition of a previously submitted statement.
///
/// Sends `GET {base_url}/api/v2/statements/{statement_handle}?partition={partition}`
/// with the bearer `token`, gunzips the body when the response declares
/// `content-encoding: gzip`, and parses the result as a
/// `SnowflakeStatementResponse`.
///
/// # Errors
///
/// Fails if the request cannot be sent, the body cannot be read or
/// decompressed, the HTTP status is not 2xx, the body is empty, or the body
/// is not valid JSON for `SnowflakeStatementResponse`.
async fn fetch_partition(
    http_client: Arc<dyn HttpClient>,
    base_url: &str,
    token: &str,
    statement_handle: &str,
    partition: usize,
) -> Result<SnowflakeStatementResponse> {
    let url = format!(
        "{}/api/v2/statements/{}?partition={}",
        base_url.trim_end_matches('/'),
        statement_handle,
        partition
    );

    let http_request = Request::builder()
        .method(Method::GET)
        .uri(url.as_str())
        .header("Authorization", format!("Bearer {token}"))
        .header(
            "X-Snowflake-Authorization-Token-Type",
            "PROGRAMMATIC_ACCESS_TOKEN",
        )
        .header("Accept", "application/json")
        .header("Accept-Encoding", "gzip")
        .body(AsyncBody::empty())?;

    let response = http_client
        .send(http_request)
        .await
        .context("failed to send partition request to Snowflake SQL API")?;

    let status = response.status();
    // Decide this before the response is consumed below. Header values that
    // convert via `to_str` are ASCII, so an ASCII case-insensitive compare is
    // equivalent to lowercasing first and avoids the allocation.
    let is_gzip = response
        .headers()
        .get("content-encoding")
        .and_then(|value| value.to_str().ok())
        .is_some_and(|encoding| encoding.eq_ignore_ascii_case("gzip"));

    let raw_bytes = {
        use futures::AsyncReadExt as _;

        let mut body = response.into_body();
        let mut bytes = Vec::new();
        body.read_to_end(&mut bytes)
            .await
            .context("failed to read Snowflake SQL API partition response body")?;
        bytes
    };

    let body_bytes = if is_gzip {
        let mut decoder = GzDecoder::new(&raw_bytes[..]);
        let mut decompressed = Vec::new();
        decoder
            .read_to_end(&mut decompressed)
            .context("failed to decompress gzip response")?;
        decompressed
    } else {
        raw_bytes
    };

    // 202 Accepted is itself a 2xx status, so it passes this check without a
    // separate clause. `run_sql_with_polling` tells a still-running statement
    // apart by the `code` field of the parsed response.
    if !status.is_success() {
        let body_text = String::from_utf8_lossy(&body_bytes);
        anyhow::bail!(
            "snowflake sql api partition request http {}: {}",
            status.as_u16(),
            body_text
        );
    }

    if body_bytes.is_empty() {
        anyhow::bail!(
            "snowflake sql api partition {} returned empty response body (http {})",
            partition,
            status.as_u16()
        );
    }

    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
        // Cap the preview so a large payload does not flood the error message.
        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
        format!(
            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
            partition,
            status.as_u16(),
            body_preview
        )
    })
}
async fn run_sql(
http_client: Arc<dyn HttpClient>,
base_url: &str,
@ -209,7 +406,7 @@ async fn run_sql(
bytes
};
if !status.is_success() {
if !status.is_success() && status.as_u16() != 202 {
let body_text = String::from_utf8_lossy(&body_bytes);
anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
}