// Mirror of https://github.com/zed-industries/zed.git (synced 2026-05-28 01:24:17 +00:00).
// Commit: when `--provider` is not provided, `ep` now uses whichever provider is
// recorded in the data. Release notes: N/A.
// File: Rust, 331 lines, 11 KiB.
use crate::{
|
|
FormatPromptArgs, PredictArgs, PredictionProvider,
|
|
anthropic_client::AnthropicClient,
|
|
example::{Example, ExamplePrediction, ExamplePrompt},
|
|
format_prompt::{TeacherPrompt, run_format_prompt},
|
|
headless::EpAppState,
|
|
load_project::run_load_project,
|
|
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
|
|
progress::{ExampleProgress, InfoStyle, Step},
|
|
retrieve_context::run_context_retrieval,
|
|
};
|
|
use anyhow::Context as _;
|
|
use edit_prediction::{DebugEvent, EditPredictionStore};
|
|
use futures::{FutureExt as _, StreamExt as _, future::Shared};
|
|
use gpui::{AppContext as _, AsyncApp, Task};
|
|
use std::{
|
|
fs,
|
|
sync::{
|
|
Arc, Mutex, OnceLock,
|
|
atomic::{AtomicUsize, Ordering::SeqCst},
|
|
},
|
|
};
|
|
use zeta_prompt::ZetaVersion;
|
|
|
|
/// Process-wide Anthropic client, created lazily on first use via `OnceLock::get_or_init`
/// (by `predict_anthropic` or `sync_batches`) and reused for the rest of the run.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
|
|
|
|
pub async fn run_prediction(
|
|
example: &mut Example,
|
|
args: &PredictArgs,
|
|
app_state: Arc<EpAppState>,
|
|
example_progress: &ExampleProgress,
|
|
mut cx: AsyncApp,
|
|
) -> anyhow::Result<()> {
|
|
let repetition_count = args.repetitions;
|
|
|
|
if let Some(existing_prediction) = example.predictions.first() {
|
|
let has_prediction = existing_prediction.actual_patch.is_some()
|
|
|| !existing_prediction.actual_output.is_empty();
|
|
if has_prediction {
|
|
match args.provider {
|
|
None => return Ok(()),
|
|
Some(provider) if existing_prediction.provider == provider => return Ok(()),
|
|
Some(_) => example.predictions.clear(),
|
|
}
|
|
}
|
|
}
|
|
|
|
let Some(provider) = args.provider else {
|
|
anyhow::bail!(
|
|
"No existing predictions found. Use --provider to specify which model to use for prediction."
|
|
);
|
|
};
|
|
|
|
run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
|
|
|
|
if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
|
|
provider
|
|
{
|
|
let _step_progress = example_progress.start(Step::Predict);
|
|
|
|
run_format_prompt(
|
|
example,
|
|
&FormatPromptArgs { provider },
|
|
app_state.clone(),
|
|
example_progress,
|
|
cx,
|
|
)
|
|
.await?;
|
|
|
|
let batched = matches!(provider, PredictionProvider::Teacher(..));
|
|
return predict_anthropic(example, repetition_count, version, batched).await;
|
|
}
|
|
|
|
run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
|
|
|
|
let step_progress = example_progress.start(Step::Predict);
|
|
|
|
if matches!(
|
|
provider,
|
|
PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
|
|
) {
|
|
step_progress.set_substatus("authenticating");
|
|
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
|
|
AUTHENTICATED
|
|
.get_or_init(|| {
|
|
let client = app_state.client.clone();
|
|
cx.spawn(async move |cx| {
|
|
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
|
|
eprintln!("Authentication failed: {}", e);
|
|
}
|
|
})
|
|
.shared()
|
|
})
|
|
.clone()
|
|
.await;
|
|
}
|
|
|
|
let ep_store = cx
|
|
.update(|cx| EditPredictionStore::try_global(cx))
|
|
.context("EditPredictionStore not initialized")?;
|
|
|
|
ep_store.update(&mut cx, |store, _cx| {
|
|
let model = match provider {
|
|
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
|
PredictionProvider::Zeta2(version) => {
|
|
edit_prediction::EditPredictionModel::Zeta2 { version }
|
|
}
|
|
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
|
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
|
PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => {
|
|
unreachable!()
|
|
}
|
|
};
|
|
store.set_edit_prediction_model(model);
|
|
});
|
|
step_progress.set_substatus("configuring model");
|
|
let state = example.state.as_ref().context("state must be set")?;
|
|
let run_dir = RUN_DIR.join(&example.spec.name);
|
|
|
|
let updated_example = Arc::new(Mutex::new(example.clone()));
|
|
let current_run_ix = Arc::new(AtomicUsize::new(0));
|
|
|
|
let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
|
|
let debug_task = cx.background_spawn({
|
|
let updated_example = updated_example.clone();
|
|
let current_run_ix = current_run_ix.clone();
|
|
let run_dir = run_dir.clone();
|
|
async move {
|
|
while let Some(event) = debug_rx.next().await {
|
|
let run_ix = current_run_ix.load(SeqCst);
|
|
let mut updated_example = updated_example.lock().unwrap();
|
|
|
|
let run_dir = if repetition_count > 1 {
|
|
run_dir.join(format!("{:03}", run_ix))
|
|
} else {
|
|
run_dir.clone()
|
|
};
|
|
|
|
match event {
|
|
DebugEvent::EditPredictionStarted(request) => {
|
|
assert_eq!(updated_example.predictions.len(), run_ix + 1);
|
|
|
|
if let Some(prompt) = request.prompt {
|
|
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
|
|
if matches!(provider, PredictionProvider::Zeta2(_)) {
|
|
updated_example.prompt.get_or_insert(ExamplePrompt {
|
|
input: prompt,
|
|
expected_output: String::new(),
|
|
provider,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
DebugEvent::EditPredictionFinished(request) => {
|
|
assert_eq!(updated_example.predictions.len(), run_ix + 1);
|
|
|
|
if let Some(output) = request.model_output {
|
|
fs::write(run_dir.join("prediction_response.md"), &output)?;
|
|
updated_example
|
|
.predictions
|
|
.last_mut()
|
|
.unwrap()
|
|
.actual_output = output;
|
|
}
|
|
if run_ix >= repetition_count {
|
|
break;
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
anyhow::Ok(())
|
|
}
|
|
});
|
|
|
|
for ix in 0..repetition_count {
|
|
current_run_ix.store(ix, SeqCst);
|
|
let run_dir = if repetition_count > 1 {
|
|
run_dir.join(format!("{:03}", ix))
|
|
} else {
|
|
run_dir.clone()
|
|
};
|
|
|
|
fs::create_dir_all(&run_dir)?;
|
|
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
|
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
|
|
}
|
|
#[cfg(unix)]
|
|
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
|
#[cfg(windows)]
|
|
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
|
|
|
updated_example
|
|
.lock()
|
|
.unwrap()
|
|
.predictions
|
|
.push(ExamplePrediction {
|
|
actual_patch: None,
|
|
actual_output: String::new(),
|
|
provider,
|
|
});
|
|
|
|
step_progress.set_substatus("requesting prediction");
|
|
let prediction = ep_store
|
|
.update(&mut cx, |store, cx| {
|
|
store.request_prediction(
|
|
&state.project,
|
|
&state.buffer,
|
|
state.cursor_position,
|
|
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
|
cx,
|
|
)
|
|
})
|
|
.await?;
|
|
|
|
let actual_patch = prediction.and_then(|prediction| {
|
|
let prediction = prediction.prediction.ok()?;
|
|
prediction
|
|
.edit_preview
|
|
.as_unified_diff(prediction.snapshot.file(), &prediction.edits)
|
|
});
|
|
|
|
let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());
|
|
|
|
updated_example
|
|
.lock()
|
|
.unwrap()
|
|
.predictions
|
|
.last_mut()
|
|
.unwrap()
|
|
.actual_patch = actual_patch;
|
|
|
|
if ix == repetition_count - 1 {
|
|
let (info, style) = if has_prediction {
|
|
("predicted", InfoStyle::Normal)
|
|
} else {
|
|
("no prediction", InfoStyle::Warning)
|
|
};
|
|
step_progress.set_info(info, style);
|
|
}
|
|
}
|
|
|
|
ep_store.update(&mut cx, |store, _| {
|
|
store.remove_project(&state.project);
|
|
});
|
|
debug_task.await?;
|
|
|
|
*example = Arc::into_inner(updated_example)
|
|
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
|
|
.into_inner()
|
|
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn predict_anthropic(
|
|
example: &mut Example,
|
|
_repetition_count: usize,
|
|
version: ZetaVersion,
|
|
batched: bool,
|
|
) -> anyhow::Result<()> {
|
|
let llm_model_name = "claude-sonnet-4-5";
|
|
let max_tokens = 16384;
|
|
let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
|
|
let client = if batched {
|
|
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
|
|
} else {
|
|
AnthropicClient::plain()
|
|
};
|
|
client.expect("Failed to create Anthropic client")
|
|
});
|
|
|
|
let prompt = example.prompt.as_ref().context("Prompt is required")?;
|
|
|
|
let messages = vec![anthropic::Message {
|
|
role: anthropic::Role::User,
|
|
content: vec![anthropic::RequestContent::Text {
|
|
text: prompt.input.clone(),
|
|
cache_control: None,
|
|
}],
|
|
}];
|
|
|
|
let Some(response) = llm_client
|
|
.generate(llm_model_name, max_tokens, messages)
|
|
.await?
|
|
else {
|
|
// Request stashed for batched processing
|
|
return Ok(());
|
|
};
|
|
|
|
let actual_output = response
|
|
.content
|
|
.into_iter()
|
|
.filter_map(|content| match content {
|
|
anthropic::ResponseContent::Text { text } => Some(text),
|
|
_ => None,
|
|
})
|
|
.collect::<Vec<String>>()
|
|
.join("\n");
|
|
|
|
let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;
|
|
|
|
let prediction = ExamplePrediction {
|
|
actual_patch: Some(actual_patch),
|
|
actual_output,
|
|
provider: if batched {
|
|
PredictionProvider::Teacher(version)
|
|
} else {
|
|
PredictionProvider::TeacherNonBatching(version)
|
|
},
|
|
};
|
|
|
|
example.predictions.push(prediction);
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
|
|
match provider {
|
|
Some(PredictionProvider::Teacher(..)) => {
|
|
let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
|
|
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
|
|
.expect("Failed to create Anthropic client")
|
|
});
|
|
llm_client
|
|
.sync_batches()
|
|
.await
|
|
.context("Failed to sync batches")?;
|
|
}
|
|
_ => (),
|
|
};
|
|
Ok(())
|
|
}
|