ep: Add settled data fetching from snowflake (#50326)

Closes #ISSUE

Before you mark this PR as ready for review, make sure that you have:
- [ ] Added a solid test coverage and/or screenshots from doing manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- N/A *or* Added/Fixed/Improved ...
This commit is contained in:
Ben Kunkle 2026-03-01 10:20:32 -06:00 committed by GitHub
parent 14358b711c
commit ceb9d83dd7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 372 additions and 2 deletions

View file

@ -55,6 +55,7 @@ use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
@ -132,6 +133,10 @@ Inputs can be file paths or special specifiers:
Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
These are predictions that were shown to users but rejected (useful for DPO training).
settled-after:{timestamp}
Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
These are examples from the edit prediction settled stream.
rated-after:{timestamp}
Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
These are predictions that users explicitly rated as positive or negative via the
@ -166,6 +171,9 @@ Examples:
# Read user-rated predictions
ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl
# Read settled stream examples
ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl
# Read only positively rated predictions
ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl
@ -635,6 +643,7 @@ async fn load_examples(
let mut captured_after_timestamps = Vec::new();
let mut rejected_after_timestamps = Vec::new();
let mut requested_after_timestamps = Vec::new();
let mut settled_after_timestamps = Vec::new();
let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
Vec::new();
let mut file_inputs = Vec::new();
@ -651,6 +660,8 @@ async fn load_examples(
pull_examples::parse_requested_after_input(input_string.as_ref())
{
requested_after_timestamps.push(timestamp.to_string());
} else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
settled_after_timestamps.push(timestamp.to_string());
} else if let Some((timestamp, rating_filter)) =
pull_examples::parse_rated_after_input(input_string.as_ref())
{
@ -718,6 +729,21 @@ async fn load_examples(
examples.append(&mut requested_examples);
}
if !settled_after_timestamps.is_empty() {
settled_after_timestamps.sort();
let mut settled_examples = fetch_settled_examples_after(
http_client.clone(),
&settled_after_timestamps,
max_rows_per_timestamp,
remaining_offset,
background_executor.clone(),
Some(MIN_CAPTURE_VERSION),
)
.await?;
examples.append(&mut settled_examples);
}
if !rated_after_inputs.is_empty() {
rated_after_inputs.sort();

View file

@ -5,24 +5,25 @@ use http_client::{AsyncBody, HttpClient, Method, Request};
use indoc::indoc;
use serde::Deserialize;
use serde_json::{Value as JsonValue, json};
use std::fmt::Write as _;
use std::io::Read;
use std::sync::Arc;
use std::time::Duration;
use telemetry_events::EditPredictionRating;
use zeta_prompt::ZetaPromptInput;
use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
use crate::example::Example;
use crate::progress::{InfoStyle, Progress, Step};
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
use std::fmt::Write as _;
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
@ -34,6 +35,7 @@ pub struct MinCaptureVersion {
}
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
@ -52,6 +54,11 @@ pub fn parse_requested_after_input(input: &str) -> Option<&str> {
input.strip_prefix("requested-after:")
}
/// Parse an input token of the form `settled-after:{timestamp}`.
pub fn parse_settled_after_input(input: &str) -> Option<&str> {
input.strip_prefix("settled-after:")
}
/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
/// or `rated-negative-after:{timestamp}`.
/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
@ -596,6 +603,163 @@ pub async fn fetch_requested_examples_after(
Ok(all_examples)
}
pub async fn fetch_settled_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
max_rows_per_timestamp: usize,
offset: usize,
background_executor: BackgroundExecutor,
min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
if after_timestamps.is_empty() {
return Ok(Vec::new());
}
let progress = Progress::global();
let token = std::env::var("EP_SNOWFLAKE_API_KEY")
.context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
"missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
)?;
let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
let mut all_examples = Vec::new();
for after_date in after_timestamps.iter() {
let step_progress_name = format!("settled>{after_date}");
let step_progress = progress.start(Step::PullExamples, &step_progress_name);
step_progress.set_substatus("querying");
let statement = indoc! {r#"
WITH requested AS (
SELECT
req.event_properties:request_id::string AS request_id,
req.device_id::string AS device_id,
req.time AS req_time,
req.time::string AS time,
req.event_properties:input AS input,
req.event_properties:format::string AS requested_format,
req.event_properties:output::string AS requested_output,
req.event_properties:zed_version::string AS zed_version
FROM events req
WHERE req.event_type = ?
AND req.event_properties:version = 'V3'
AND req.event_properties:input:can_collect_data = true
AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
)
SELECT
req.request_id AS request_id,
req.device_id AS device_id,
req.time AS time,
req.input AS input,
req.requested_output AS requested_output,
settled.event_properties:settled_editable_region::string AS settled_editable_region,
req.requested_format AS requested_format,
req.zed_version AS zed_version
FROM requested req
INNER JOIN events settled
ON req.request_id = settled.event_properties:request_id::string
WHERE settled.event_type = ?
ORDER BY req.req_time ASC
LIMIT ?
OFFSET ?
"#};
let _ = min_capture_version;
let request = json!({
"statement": statement,
"timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS,
"database": "EVENTS",
"schema": "PUBLIC",
"warehouse": "DBT",
"role": role,
"bindings": {
"1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
"2": { "type": "TEXT", "value": after_date },
"3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
"4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
"5": { "type": "FIXED", "value": offset.to_string() }
}
});
let response = run_sql_with_polling(
http_client.clone(),
&base_url,
&token,
&request,
&step_progress,
background_executor.clone(),
)
.await?;
let total_rows = response
.result_set_meta_data
.as_ref()
.and_then(|m| m.num_rows)
.unwrap_or(response.data.len() as i64);
let num_partitions = response
.result_set_meta_data
.as_ref()
.map(|m| m.partition_info.len())
.unwrap_or(1)
.max(1);
step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
step_progress.set_substatus("parsing");
let column_indices = get_column_indices(
&response.result_set_meta_data,
&[
"request_id",
"device_id",
"time",
"input",
"requested_output",
"settled_editable_region",
"requested_format",
"zed_version",
],
);
all_examples.extend(settled_examples_from_response(&response, &column_indices)?);
if num_partitions > 1 {
let statement_handle = response
.statement_handle
.as_ref()
.context("response has multiple partitions but no statementHandle")?;
for partition in 1..num_partitions {
step_progress.set_substatus(format!(
"fetching partition {}/{}",
partition + 1,
num_partitions
));
let partition_response = fetch_partition(
http_client.clone(),
&base_url,
&token,
statement_handle,
partition,
)
.await?;
all_examples.extend(settled_examples_from_response(
&partition_response,
&column_indices,
)?);
}
}
step_progress.set_substatus("done");
}
Ok(all_examples)
}
pub async fn fetch_rated_examples_after(
http_client: Arc<dyn HttpClient>,
inputs: &[(String, Option<EditPredictionRating>)],
@ -989,6 +1153,186 @@ fn requested_examples_from_response<'a>(
Ok(iter)
}
fn settled_examples_from_response<'a>(
response: &'a SnowflakeStatementResponse,
column_indices: &'a std::collections::HashMap<String, usize>,
) -> Result<impl Iterator<Item = Example> + 'a> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
"snowflake sql api returned error code={code} message={}",
response.message.as_deref().unwrap_or("<no message>")
);
}
}
let iter = response
.data
.iter()
.enumerate()
.filter_map(move |(row_index, data_row)| {
let get_value = |name: &str| -> Option<JsonValue> {
let index = column_indices.get(name).copied()?;
let value = data_row.get(index)?;
if value.is_null() {
None
} else {
Some(value.clone())
}
};
let get_string = |name: &str| -> Option<String> {
match get_value(name)? {
JsonValue::String(s) => Some(s),
other => Some(other.to_string()),
}
};
let parse_json_value = |_: &str, raw: Option<&JsonValue>| -> Option<JsonValue> {
let value = raw?;
match value {
JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
other => Some(other.clone()),
}
};
let request_id_str = get_string("request_id");
let device_id = get_string("device_id");
let time = get_string("time");
let input_raw = get_value("input");
let input_json = parse_json_value("input", input_raw.as_ref());
let input: Option<ZetaPromptInput> = input_json
.as_ref()
.and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
let requested_output = get_string("requested_output");
let settled_editable_region = get_string("settled_editable_region");
let requested_format =
get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
let zed_version = get_string("zed_version");
match (
request_id_str.clone(),
device_id.clone(),
time.clone(),
input.clone(),
requested_output.clone(),
settled_editable_region.clone(),
requested_format,
) {
(
Some(request_id),
Some(device_id),
Some(time),
Some(input),
Some(requested_output),
Some(settled_editable_region),
Some(requested_format),
) => Some(build_settled_example(
request_id,
device_id,
time,
input,
requested_output,
settled_editable_region,
requested_format,
zed_version,
)),
_ => {
let mut missing_fields = Vec::new();
if request_id_str.is_none() {
missing_fields.push("request_id");
}
if device_id.is_none() {
missing_fields.push("device_id");
}
if time.is_none() {
missing_fields.push("time");
}
if input_raw.is_none() || input_json.is_none() || input.is_none() {
missing_fields.push("input");
}
if requested_output.is_none() {
missing_fields.push("requested_output");
}
if settled_editable_region.is_none() {
missing_fields.push("settled_editable_region");
}
if requested_format.is_none() {
missing_fields.push("requested_format");
}
log::warn!(
"skipping settled row {row_index}: [{}]",
missing_fields.join(", "),
);
None
}
}
});
Ok(iter)
}
fn build_settled_example(
request_id: String,
device_id: String,
time: String,
input: ZetaPromptInput,
requested_output: String,
settled_editable_region: String,
requested_format: ZetaFormat,
zed_version: Option<String>,
) -> Example {
let requested_editable_range = input
.excerpt_ranges
.as_ref()
.map(|ranges| excerpt_range_for_format(requested_format, ranges).0)
.unwrap_or_else(|| input.editable_range_in_excerpt.clone());
let base_cursor_excerpt = input.cursor_excerpt.to_string();
let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
&& requested_editable_range.end <= base_cursor_excerpt.len();
let mut example = build_example_from_snowflake(
request_id.clone(),
device_id,
time,
input,
vec!["settled".to_string()],
None,
zed_version,
);
if !requested_range_is_valid {
log::warn!(
"skipping malformed requested range for request {}: requested={:?} (base_len={})",
request_id,
requested_editable_range,
base_cursor_excerpt.len(),
);
return example;
}
let settled_replacement = settled_editable_region.as_str();
let rejected_patch = build_output_patch(
&example.spec.cursor_path,
&base_cursor_excerpt,
&requested_editable_range,
&requested_output,
);
let expected_patch = build_output_patch(
&example.spec.cursor_path,
&base_cursor_excerpt,
&requested_editable_range,
settled_replacement,
);
example.spec.expected_patches = vec![expected_patch];
example.spec.rejected_patch = Some(rejected_patch);
example
}
fn rejected_examples_from_response<'a>(
response: &'a SnowflakeStatementResponse,
column_indices: &'a std::collections::HashMap<String, usize>,