mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 05:51:14 +00:00
ep: Combine PredictionProvider and ZetaVersion (#46896)
We can specify prompt version in the provider name itself, like this `--provider zeta2:0113`. This kind of tag will also be stored in the `provider` field of jsonlines files. This drops the `--version` parameter. Release Notes: - N/A
This commit is contained in:
parent
3ed6c68f3b
commit
a10fdfd2b8
4 changed files with 105 additions and 72 deletions
|
|
@ -42,7 +42,7 @@ pub async fn run_format_prompt(
|
|||
provider: args.provider,
|
||||
});
|
||||
}
|
||||
PredictionProvider::Zeta2 => {
|
||||
PredictionProvider::Zeta2(version) => {
|
||||
step_progress.set_substatus("formatting zeta2 prompt");
|
||||
|
||||
let context_start = prompt_inputs.context_range.start;
|
||||
|
|
@ -59,7 +59,7 @@ pub async fn run_format_prompt(
|
|||
events: prompt_inputs.edit_history.clone(),
|
||||
related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
|
||||
};
|
||||
let prompt = format_zeta_prompt(&input, args.version);
|
||||
let prompt = format_zeta_prompt(&input, version);
|
||||
let expected_output = zeta2_output_for_patch(
|
||||
&input,
|
||||
&example
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ use gpui::{AppContext as _, Application, BackgroundExecutor};
|
|||
use zeta_prompt::ZetaVersion;
|
||||
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::fmt::Display;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
|
@ -152,47 +152,19 @@ impl Display for Command {
|
|||
Command::ParseExample => write!(f, "parse-example"),
|
||||
Command::LoadProject => write!(f, "load-project"),
|
||||
Command::Context => write!(f, "context"),
|
||||
Command::FormatPrompt(format_prompt_args) => write!(
|
||||
f,
|
||||
"format-prompt --prompt-format={}",
|
||||
format_prompt_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Predict(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"predict --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
Command::FormatPrompt(args) => {
|
||||
write!(f, "format-prompt --provider={}", args.provider)
|
||||
}
|
||||
Command::Score(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"score --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
Command::Predict(args) => {
|
||||
write!(f, "predict --provider={}", args.provider)
|
||||
}
|
||||
Command::Score(args) => {
|
||||
write!(f, "score --provider={}", args.provider)
|
||||
}
|
||||
Command::Distill => write!(f, "distill"),
|
||||
Command::Eval(predict_args) => write!(
|
||||
f,
|
||||
"eval --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Eval(args) => {
|
||||
write!(f, "eval --provider={}", args.provider)
|
||||
}
|
||||
Command::Synthesize(args) => {
|
||||
write!(f, "synthesize --repos {}", args.repos.join(" "))
|
||||
}
|
||||
|
|
@ -205,43 +177,96 @@ impl Display for Command {
|
|||
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long, short)]
|
||||
#[clap(long, short('p'), default_value_t = PredictionProvider::default())]
|
||||
provider: PredictionProvider,
|
||||
#[clap(
|
||||
long,
|
||||
short,
|
||||
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
|
||||
value_parser = ZetaVersion::parse,
|
||||
default_value_t = ZetaVersion::default(),
|
||||
)]
|
||||
version: ZetaVersion,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct PredictArgs {
|
||||
#[clap(long, short)]
|
||||
#[clap(long, short('p'), default_value_t = PredictionProvider::default())]
|
||||
provider: PredictionProvider,
|
||||
#[clap(long, default_value_t = 1)]
|
||||
repetitions: usize,
|
||||
#[clap(
|
||||
long,
|
||||
short,
|
||||
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
|
||||
value_parser = ZetaVersion::parse,
|
||||
)]
|
||||
version: ZetaVersion,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
enum PredictionProvider {
|
||||
Sweep,
|
||||
Mercury,
|
||||
Zeta1,
|
||||
Zeta2,
|
||||
Zeta2(ZetaVersion),
|
||||
Teacher,
|
||||
TeacherNonBatching,
|
||||
}
|
||||
|
||||
impl Default for PredictionProvider {
|
||||
fn default() -> Self {
|
||||
PredictionProvider::Zeta2(ZetaVersion::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PredictionProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PredictionProvider::Sweep => write!(f, "sweep"),
|
||||
PredictionProvider::Mercury => write!(f, "mercury"),
|
||||
PredictionProvider::Zeta1 => write!(f, "zeta1"),
|
||||
PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
|
||||
PredictionProvider::Teacher => write!(f, "teacher"),
|
||||
PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for PredictionProvider {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let s_lower = s.to_lowercase();
|
||||
match s_lower.as_str() {
|
||||
"sweep" => Ok(PredictionProvider::Sweep),
|
||||
"mercury" => Ok(PredictionProvider::Mercury),
|
||||
"zeta1" => Ok(PredictionProvider::Zeta1),
|
||||
// Handle both old format "zeta2" and new format with version
|
||||
"zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())),
|
||||
"teacher" => Ok(PredictionProvider::Teacher),
|
||||
"teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
|
||||
Ok(PredictionProvider::TeacherNonBatching)
|
||||
}
|
||||
_ if s_lower.starts_with("zeta2:") => {
|
||||
let version_str = &s[6..];
|
||||
let version = ZetaVersion::parse(version_str)?;
|
||||
Ok(PredictionProvider::Zeta2(version))
|
||||
}
|
||||
_ => anyhow::bail!(
|
||||
"unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
|
||||
For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
|
||||
Available zeta versions:\n{}",
|
||||
ZetaVersion::options_as_string()
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for PredictionProvider {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for PredictionProvider {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
s.parse().map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct SynthesizeArgs {
|
||||
/// Repository URLs (git@github.com:owner/repo or https://...)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ pub async fn run_prediction(
|
|||
) -> anyhow::Result<()> {
|
||||
let provider = args.provider;
|
||||
let repetition_count = args.repetitions;
|
||||
let zeta_version = args.version;
|
||||
|
||||
if let Some(existing_prediction) = example.predictions.first() {
|
||||
if existing_prediction.provider == provider {
|
||||
|
|
@ -51,10 +50,7 @@ pub async fn run_prediction(
|
|||
|
||||
run_format_prompt(
|
||||
example,
|
||||
&FormatPromptArgs {
|
||||
provider,
|
||||
version: args.version,
|
||||
},
|
||||
&FormatPromptArgs { provider },
|
||||
app_state.clone(),
|
||||
cx,
|
||||
)
|
||||
|
|
@ -70,7 +66,7 @@ pub async fn run_prediction(
|
|||
|
||||
if matches!(
|
||||
provider,
|
||||
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
|
||||
PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
|
||||
) {
|
||||
step_progress.set_substatus("authenticating");
|
||||
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
|
||||
|
|
@ -95,9 +91,9 @@ pub async fn run_prediction(
|
|||
ep_store.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
|
||||
version: zeta_version,
|
||||
},
|
||||
PredictionProvider::Zeta2(version) => {
|
||||
edit_prediction::EditPredictionModel::Zeta2 { version }
|
||||
}
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
|
|
@ -135,7 +131,7 @@ pub async fn run_prediction(
|
|||
|
||||
if let Some(prompt) = request.prompt {
|
||||
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
|
||||
if provider == PredictionProvider::Zeta2 {
|
||||
if matches!(provider, PredictionProvider::Zeta2(_)) {
|
||||
updated_example.prompt.get_or_insert(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: String::new(),
|
||||
|
|
|
|||
|
|
@ -18,7 +18,19 @@ pub struct ZetaPromptInput {
|
|||
pub related_files: Vec<RelatedFile>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
|
||||
#[derive(
|
||||
Default,
|
||||
Clone,
|
||||
Copy,
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Hash,
|
||||
EnumIter,
|
||||
IntoStaticStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum ZetaVersion {
|
||||
V0112_MiddleAtEnd,
|
||||
|
|
@ -54,7 +66,7 @@ impl ZetaVersion {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
fn options_as_string() -> String {
|
||||
pub fn options_as_string() -> String {
|
||||
ZetaVersion::iter()
|
||||
.map(|version| format!("- {}\n", <&'static str>::from(version)))
|
||||
.collect::<Vec<_>>()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue