ep: Combine PredictionProvider and ZetaVersion (#46896)

We can specify prompt version in the provider name itself, like this
`--provider zeta2:0113`.

This kind of tag will also be stored in the `provider` field of
jsonlines files.

This drops the `--version` parameter.


Release Notes:

- N/A
This commit is contained in:
Oleksiy Syvokon 2026-01-15 19:00:21 +02:00 committed by GitHub
parent 3ed6c68f3b
commit a10fdfd2b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 105 additions and 72 deletions

View file

@ -42,7 +42,7 @@ pub async fn run_format_prompt(
provider: args.provider,
});
}
PredictionProvider::Zeta2 => {
PredictionProvider::Zeta2(version) => {
step_progress.set_substatus("formatting zeta2 prompt");
let context_start = prompt_inputs.context_range.start;
@ -59,7 +59,7 @@ pub async fn run_format_prompt(
events: prompt_inputs.edit_history.clone(),
related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
};
let prompt = format_zeta_prompt(&input, args.version);
let prompt = format_zeta_prompt(&input, version);
let expected_output = zeta2_output_for_patch(
&input,
&example

View file

@ -25,7 +25,7 @@ use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
@ -152,47 +152,19 @@ impl Display for Command {
Command::ParseExample => write!(f, "parse-example"),
Command::LoadProject => write!(f, "load-project"),
Command::Context => write!(f, "context"),
Command::FormatPrompt(format_prompt_args) => write!(
f,
"format-prompt --prompt-format={}",
format_prompt_args
.provider
.to_possible_value()
.unwrap()
.get_name()
),
Command::Predict(predict_args) => {
write!(
f,
"predict --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
Command::FormatPrompt(args) => {
write!(f, "format-prompt --provider={}", args.provider)
}
Command::Score(predict_args) => {
write!(
f,
"score --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
Command::Predict(args) => {
write!(f, "predict --provider={}", args.provider)
}
Command::Score(args) => {
write!(f, "score --provider={}", args.provider)
}
Command::Distill => write!(f, "distill"),
Command::Eval(predict_args) => write!(
f,
"eval --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
),
Command::Eval(args) => {
write!(f, "eval --provider={}", args.provider)
}
Command::Synthesize(args) => {
write!(f, "synthesize --repos {}", args.repos.join(" "))
}
@ -205,43 +177,96 @@ impl Display for Command {
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
#[clap(long, short)]
#[clap(long, short('p'), default_value_t = PredictionProvider::default())]
provider: PredictionProvider,
#[clap(
long,
short,
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
value_parser = ZetaVersion::parse,
default_value_t = ZetaVersion::default(),
)]
version: ZetaVersion,
}
#[derive(Debug, Args, Clone)]
struct PredictArgs {
#[clap(long, short)]
#[clap(long, short('p'), default_value_t = PredictionProvider::default())]
provider: PredictionProvider,
#[clap(long, default_value_t = 1)]
repetitions: usize,
#[clap(
long,
short,
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
value_parser = ZetaVersion::parse,
)]
version: ZetaVersion,
}
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
Sweep,
Mercury,
Zeta1,
Zeta2,
Zeta2(ZetaVersion),
Teacher,
TeacherNonBatching,
}
impl Default for PredictionProvider {
fn default() -> Self {
PredictionProvider::Zeta2(ZetaVersion::default())
}
}
impl std::fmt::Display for PredictionProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PredictionProvider::Sweep => write!(f, "sweep"),
PredictionProvider::Mercury => write!(f, "mercury"),
PredictionProvider::Zeta1 => write!(f, "zeta1"),
PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
PredictionProvider::Teacher => write!(f, "teacher"),
PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"),
}
}
}
impl std::str::FromStr for PredictionProvider {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s_lower = s.to_lowercase();
match s_lower.as_str() {
"sweep" => Ok(PredictionProvider::Sweep),
"mercury" => Ok(PredictionProvider::Mercury),
"zeta1" => Ok(PredictionProvider::Zeta1),
// Handle both old format "zeta2" and new format with version
"zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())),
"teacher" => Ok(PredictionProvider::Teacher),
"teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
Ok(PredictionProvider::TeacherNonBatching)
}
_ if s_lower.starts_with("zeta2:") => {
let version_str = &s[6..];
let version = ZetaVersion::parse(version_str)?;
Ok(PredictionProvider::Zeta2(version))
}
_ => anyhow::bail!(
"unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
Available zeta versions:\n{}",
ZetaVersion::options_as_string()
),
}
}
}
impl Serialize for PredictionProvider {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for PredictionProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
/// Repository URLs (git@github.com:owner/repo or https://...)

View file

@ -31,7 +31,6 @@ pub async fn run_prediction(
) -> anyhow::Result<()> {
let provider = args.provider;
let repetition_count = args.repetitions;
let zeta_version = args.version;
if let Some(existing_prediction) = example.predictions.first() {
if existing_prediction.provider == provider {
@ -51,10 +50,7 @@ pub async fn run_prediction(
run_format_prompt(
example,
&FormatPromptArgs {
provider,
version: args.version,
},
&FormatPromptArgs { provider },
app_state.clone(),
cx,
)
@ -70,7 +66,7 @@ pub async fn run_prediction(
if matches!(
provider,
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
) {
step_progress.set_substatus("authenticating");
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
@ -95,9 +91,9 @@ pub async fn run_prediction(
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
version: zeta_version,
},
PredictionProvider::Zeta2(version) => {
edit_prediction::EditPredictionModel::Zeta2 { version }
}
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
@ -135,7 +131,7 @@ pub async fn run_prediction(
if let Some(prompt) = request.prompt {
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
if provider == PredictionProvider::Zeta2 {
if matches!(provider, PredictionProvider::Zeta2(_)) {
updated_example.prompt.get_or_insert(ExamplePrompt {
input: prompt,
expected_output: String::new(),

View file

@ -18,7 +18,19 @@ pub struct ZetaPromptInput {
pub related_files: Vec<RelatedFile>,
}
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
#[derive(
Default,
Clone,
Copy,
Debug,
PartialEq,
Eq,
Hash,
EnumIter,
IntoStaticStr,
Serialize,
Deserialize,
)]
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
V0112_MiddleAtEnd,
@ -54,7 +66,7 @@ impl ZetaVersion {
Ok(result)
}
fn options_as_string() -> String {
pub fn options_as_string() -> String {
ZetaVersion::iter()
.map(|version| format!("- {}\n", <&'static str>::from(version)))
.collect::<Vec<_>>()