mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-30 20:24:08 +00:00
Fix issues processing captured edit prediction examples (#46773)
Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
parent
9c5fc6ecbd
commit
445c95aa3c
7 changed files with 289 additions and 17 deletions
5
Cargo.lock
generated
5
Cargo.lock
generated
|
|
@ -5292,6 +5292,7 @@ dependencies = [
|
|||
"dirs 4.0.0",
|
||||
"edit_prediction",
|
||||
"extension",
|
||||
"flate2",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
|
|
@ -6252,9 +6253,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
|
|||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.1.4"
|
||||
version = "1.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
|
||||
checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::
|
|||
use text::{BufferSnapshot as TextBufferSnapshot, Point};
|
||||
|
||||
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
|
||||
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
|
||||
|
||||
pub fn capture_example(
|
||||
project: Entity<Project>,
|
||||
|
|
@ -232,10 +233,15 @@ fn generate_timestamp_name() -> String {
|
|||
}
|
||||
|
||||
pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
|
||||
let default_rate = if cx.is_staff() {
|
||||
DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
|
||||
} else {
|
||||
DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
|
||||
};
|
||||
let capture_rate = language::language_settings::all_language_settings(None, cx)
|
||||
.edit_predictions
|
||||
.example_capture_rate
|
||||
.unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
|
||||
.unwrap_or(default_rate);
|
||||
cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
|
||||
&& rand::random::<u16>() % 10_000 < capture_rate
|
||||
}
|
||||
|
|
|
|||
|
|
@ -214,6 +214,54 @@ pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
|
||||
if prefix.is_empty() {
|
||||
return Cow::Borrowed(diff);
|
||||
}
|
||||
|
||||
let prefix_with_slash = format!("{}/", prefix);
|
||||
let mut needs_rewrite = false;
|
||||
|
||||
for line in diff.lines() {
|
||||
match DiffLine::parse(line) {
|
||||
DiffLine::OldPath { path } | DiffLine::NewPath { path } => {
|
||||
if path.starts_with(&prefix_with_slash) {
|
||||
needs_rewrite = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if !needs_rewrite {
|
||||
return Cow::Borrowed(diff);
|
||||
}
|
||||
|
||||
let mut result = String::with_capacity(diff.len());
|
||||
for line in diff.lines() {
|
||||
match DiffLine::parse(line) {
|
||||
DiffLine::OldPath { path } => {
|
||||
let stripped = path
|
||||
.strip_prefix(&prefix_with_slash)
|
||||
.unwrap_or(path.as_ref());
|
||||
result.push_str(&format!("--- a/{}\n", stripped));
|
||||
}
|
||||
DiffLine::NewPath { path } => {
|
||||
let stripped = path
|
||||
.strip_prefix(&prefix_with_slash)
|
||||
.unwrap_or(path.as_ref());
|
||||
result.push_str(&format!("+++ b/{}\n", stripped));
|
||||
}
|
||||
_ => {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Cow::Owned(result)
|
||||
}
|
||||
/// Strip unnecessary git metadata lines from a diff, keeping only the lines
|
||||
/// needed for patch application: path headers (--- and +++), hunk headers (@@),
|
||||
/// and content lines (+, -, space).
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ wasmtime.workspace = true
|
|||
zeta_prompt.workspace = true
|
||||
rand.workspace = true
|
||||
similar = "2.7.0"
|
||||
flate2 = "1.1.8"
|
||||
|
||||
# Wasmtime is included as a dependency in order to enable the same
|
||||
# features that are enabled in Zed.
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries};
|
||||
use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix};
|
||||
use edit_prediction::{
|
||||
EditPredictionStore, cursor_excerpt::editable_and_context_ranges_for_cursor_position, zeta2,
|
||||
};
|
||||
|
|
@ -111,8 +111,16 @@ async fn cursor_position(
|
|||
}
|
||||
|
||||
let cursor_path_str = example.spec.cursor_path.to_string_lossy();
|
||||
// Also try cursor path with first component stripped - old examples may have
|
||||
// paths like "zed/crates/foo.rs" instead of "crates/foo.rs".
|
||||
let cursor_path_without_prefix: PathBuf =
|
||||
example.spec.cursor_path.components().skip(1).collect();
|
||||
let cursor_path_without_prefix_str = cursor_path_without_prefix.to_string_lossy();
|
||||
|
||||
// We try open_buffers first because the file might be new and not saved to disk
|
||||
let cursor_buffer = if let Some(buffer) = open_buffers.get(&cursor_path_str) {
|
||||
let cursor_buffer = if let Some(buffer) = open_buffers.get(cursor_path_str.as_ref()) {
|
||||
buffer.clone()
|
||||
} else if let Some(buffer) = open_buffers.get(cursor_path_without_prefix_str.as_ref()) {
|
||||
buffer.clone()
|
||||
} else {
|
||||
// Since the worktree scanner is disabled, manually refresh entries for the cursor path.
|
||||
|
|
@ -122,7 +130,9 @@ async fn cursor_position(
|
|||
|
||||
let cursor_path = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.find_project_path(&example.spec.cursor_path, cx)
|
||||
project
|
||||
.find_project_path(&example.spec.cursor_path, cx)
|
||||
.or_else(|| project.find_project_path(&cursor_path_without_prefix, cx))
|
||||
})
|
||||
.with_context(|| {
|
||||
format!(
|
||||
|
|
@ -282,9 +292,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
|
|||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !example.spec.uncommitted_diff.is_empty() {
|
||||
step_progress.set_substatus("applying diff");
|
||||
|
||||
// old examples had full paths in the uncommitted diff.
|
||||
let uncommitted_diff =
|
||||
strip_diff_path_prefix(&example.spec.uncommitted_diff, &repo_name.name);
|
||||
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
|
|
@ -292,9 +306,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
|
|||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
|
||||
stdin
|
||||
.write_all(example.spec.uncommitted_diff.as_bytes())
|
||||
.await?;
|
||||
stdin.write_all(uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ use collections::HashSet;
|
|||
use edit_prediction::EditPredictionStore;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
use gpui::{AppContext as _, Application};
|
||||
use gpui::{AppContext as _, Application, BackgroundExecutor};
|
||||
use zeta_prompt::ZetaVersion;
|
||||
|
||||
use reqwest_client::ReqwestClient;
|
||||
|
|
@ -279,6 +279,7 @@ async fn load_examples(
|
|||
http_client: Arc<dyn http_client::HttpClient>,
|
||||
args: &EpArgs,
|
||||
output_path: Option<&PathBuf>,
|
||||
background_executor: BackgroundExecutor,
|
||||
) -> anyhow::Result<Vec<Example>> {
|
||||
let mut captured_after_timestamps = Vec::new();
|
||||
let mut file_inputs = Vec::new();
|
||||
|
|
@ -312,6 +313,7 @@ async fn load_examples(
|
|||
http_client,
|
||||
&captured_after_timestamps,
|
||||
max_rows_per_timestamp,
|
||||
background_executor,
|
||||
)
|
||||
.await?;
|
||||
examples.append(&mut captured_examples);
|
||||
|
|
@ -465,8 +467,13 @@ fn main() {
|
|||
|
||||
cx.spawn(async move |cx| {
|
||||
let result = async {
|
||||
let mut examples =
|
||||
load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
|
||||
let mut examples = load_examples(
|
||||
app_state.client.http_client(),
|
||||
&args,
|
||||
output.as_ref(),
|
||||
cx.background_executor().clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
match &command {
|
||||
Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
use anyhow::{Context as _, Result};
|
||||
use flate2::read::GzDecoder;
|
||||
use gpui::BackgroundExecutor;
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request};
|
||||
use indoc::indoc;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value as JsonValue, json};
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{
|
||||
example::Example,
|
||||
|
|
@ -12,9 +16,12 @@ use crate::{
|
|||
use edit_prediction::example_spec::ExampleSpec;
|
||||
|
||||
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
|
||||
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
|
||||
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
|
||||
|
||||
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
|
||||
const POLL_INTERVAL: Duration = Duration::from_secs(2);
|
||||
const MAX_POLL_ATTEMPTS: usize = 120;
|
||||
|
||||
/// Parse an input token of the form `captured-after:{timestamp}`.
|
||||
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
|
||||
|
|
@ -25,6 +32,7 @@ pub async fn fetch_captured_examples_after(
|
|||
http_client: Arc<dyn HttpClient>,
|
||||
after_timestamps: &[String],
|
||||
max_rows_per_timestamp: usize,
|
||||
background_executor: BackgroundExecutor,
|
||||
) -> Result<Vec<Example>> {
|
||||
if after_timestamps.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
|
|
@ -70,13 +78,60 @@ pub async fn fetch_captured_examples_after(
|
|||
}
|
||||
});
|
||||
|
||||
let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
|
||||
let response = run_sql_with_polling(
|
||||
http_client.clone(),
|
||||
&base_url,
|
||||
&token,
|
||||
&request,
|
||||
&step_progress,
|
||||
background_executor.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
|
||||
let total_rows = response
|
||||
.result_set_meta_data
|
||||
.as_ref()
|
||||
.and_then(|m| m.num_rows)
|
||||
.unwrap_or(response.data.len() as i64);
|
||||
|
||||
let num_partitions = response
|
||||
.result_set_meta_data
|
||||
.as_ref()
|
||||
.map(|m| m.partition_info.len())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
|
||||
step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
|
||||
step_progress.set_substatus("parsing");
|
||||
|
||||
all_examples.extend(examples_from_response(&response)?);
|
||||
|
||||
if num_partitions > 1 {
|
||||
let statement_handle = response
|
||||
.statement_handle
|
||||
.as_ref()
|
||||
.context("response has multiple partitions but no statementHandle")?;
|
||||
|
||||
for partition in 1..num_partitions {
|
||||
step_progress.set_substatus(format!(
|
||||
"fetching partition {}/{}",
|
||||
partition + 1,
|
||||
num_partitions
|
||||
));
|
||||
|
||||
let partition_response = fetch_partition(
|
||||
http_client.clone(),
|
||||
&base_url,
|
||||
&token,
|
||||
statement_handle,
|
||||
partition,
|
||||
)
|
||||
.await?;
|
||||
|
||||
all_examples.extend(examples_from_response(&partition_response)?);
|
||||
}
|
||||
}
|
||||
|
||||
step_progress.set_substatus("done");
|
||||
}
|
||||
|
||||
|
|
@ -84,6 +139,7 @@ pub async fn fetch_captured_examples_after(
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct SnowflakeStatementResponse {
|
||||
#[serde(default)]
|
||||
data: Vec<Vec<JsonValue>>,
|
||||
|
|
@ -93,14 +149,25 @@ struct SnowflakeStatementResponse {
|
|||
code: Option<String>,
|
||||
#[serde(default)]
|
||||
message: Option<String>,
|
||||
#[serde(default)]
|
||||
statement_handle: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct SnowflakeResultSetMetaData {
|
||||
#[serde(default, rename = "rowType")]
|
||||
row_type: Vec<SnowflakeColumnMeta>,
|
||||
#[serde(default)]
|
||||
num_rows: Option<i64>,
|
||||
#[serde(default)]
|
||||
partition_info: Vec<SnowflakePartitionInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct SnowflakePartitionInfo {}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SnowflakeColumnMeta {
|
||||
#[serde(default)]
|
||||
|
|
@ -109,7 +176,7 @@ struct SnowflakeColumnMeta {
|
|||
|
||||
fn examples_from_response(
|
||||
response: &SnowflakeStatementResponse,
|
||||
) -> Result<impl Iterator<Item = Example>> {
|
||||
) -> Result<impl Iterator<Item = Example> + '_> {
|
||||
if let Some(code) = &response.code {
|
||||
if code != SNOWFLAKE_SUCCESS_CODE {
|
||||
anyhow::bail!(
|
||||
|
|
@ -169,6 +236,136 @@ fn examples_from_response(
|
|||
Ok(iter)
|
||||
}
|
||||
|
||||
async fn run_sql_with_polling(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
request: &serde_json::Value,
|
||||
step_progress: &crate::progress::StepProgress,
|
||||
background_executor: BackgroundExecutor,
|
||||
) -> Result<SnowflakeStatementResponse> {
|
||||
let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
|
||||
|
||||
if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
|
||||
let statement_handle = response
|
||||
.statement_handle
|
||||
.as_ref()
|
||||
.context("async query response missing statementHandle")?
|
||||
.clone();
|
||||
|
||||
for attempt in 1..=MAX_POLL_ATTEMPTS {
|
||||
step_progress.set_substatus(format!("polling ({attempt})"));
|
||||
|
||||
background_executor.timer(POLL_INTERVAL).await;
|
||||
|
||||
response =
|
||||
fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
|
||||
|
||||
if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
|
||||
anyhow::bail!(
|
||||
"query still running after {} poll attempts ({} seconds)",
|
||||
MAX_POLL_ATTEMPTS,
|
||||
MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn fetch_partition(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
statement_handle: &str,
|
||||
partition: usize,
|
||||
) -> Result<SnowflakeStatementResponse> {
|
||||
let url = format!(
|
||||
"{}/api/v2/statements/{}?partition={}",
|
||||
base_url.trim_end_matches('/'),
|
||||
statement_handle,
|
||||
partition
|
||||
);
|
||||
|
||||
let http_request = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri(url.as_str())
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.header(
|
||||
"X-Snowflake-Authorization-Token-Type",
|
||||
"PROGRAMMATIC_ACCESS_TOKEN",
|
||||
)
|
||||
.header("Accept", "application/json")
|
||||
.header("Accept-Encoding", "gzip")
|
||||
.body(AsyncBody::empty())?;
|
||||
|
||||
let response = http_client
|
||||
.send(http_request)
|
||||
.await
|
||||
.context("failed to send partition request to Snowflake SQL API")?;
|
||||
|
||||
let status = response.status();
|
||||
let content_encoding = response
|
||||
.headers()
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_lowercase());
|
||||
|
||||
let body_bytes = {
|
||||
use futures::AsyncReadExt as _;
|
||||
|
||||
let mut body = response.into_body();
|
||||
let mut bytes = Vec::new();
|
||||
body.read_to_end(&mut bytes)
|
||||
.await
|
||||
.context("failed to read Snowflake SQL API partition response body")?;
|
||||
bytes
|
||||
};
|
||||
|
||||
let body_bytes = if content_encoding.as_deref() == Some("gzip") {
|
||||
let mut decoder = GzDecoder::new(&body_bytes[..]);
|
||||
let mut decompressed = Vec::new();
|
||||
decoder
|
||||
.read_to_end(&mut decompressed)
|
||||
.context("failed to decompress gzip response")?;
|
||||
decompressed
|
||||
} else {
|
||||
body_bytes
|
||||
};
|
||||
|
||||
if !status.is_success() && status.as_u16() != 202 {
|
||||
let body_text = String::from_utf8_lossy(&body_bytes);
|
||||
anyhow::bail!(
|
||||
"snowflake sql api partition request http {}: {}",
|
||||
status.as_u16(),
|
||||
body_text
|
||||
);
|
||||
}
|
||||
|
||||
if body_bytes.is_empty() {
|
||||
anyhow::bail!(
|
||||
"snowflake sql api partition {} returned empty response body (http {})",
|
||||
partition,
|
||||
status.as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
|
||||
let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
|
||||
format!(
|
||||
"failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
|
||||
partition,
|
||||
status.as_u16(),
|
||||
body_preview
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_sql(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
base_url: &str,
|
||||
|
|
@ -209,7 +406,7 @@ async fn run_sql(
|
|||
bytes
|
||||
};
|
||||
|
||||
if !status.is_success() {
|
||||
if !status.is_success() && status.as_u16() != 202 {
|
||||
let body_text = String::from_utf8_lossy(&body_bytes);
|
||||
anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue