mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-23 21:05:08 +00:00
zeta2: Build edit prediction prompt and process model output in client (#41870)
Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <agus@zed.dev> Co-authored-by: Ben Kunkle <ben@zed.dev> Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
This commit is contained in:
parent
fb87972f44
commit
784fdcaee3
32 changed files with 2198 additions and 2392 deletions
6
Cargo.lock
generated
6
Cargo.lock
generated
|
|
@ -39,6 +39,7 @@ dependencies = [
|
|||
"util",
|
||||
"uuid",
|
||||
"watch",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3198,7 +3199,9 @@ dependencies = [
|
|||
"indoc",
|
||||
"ordered-float 2.10.1",
|
||||
"rustc-hash 2.1.1",
|
||||
"schemars 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum 0.27.2",
|
||||
]
|
||||
|
||||
|
|
@ -21675,10 +21678,10 @@ dependencies = [
|
|||
"language_model",
|
||||
"log",
|
||||
"lsp",
|
||||
"open_ai",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"release_channel",
|
||||
"schemars 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
|
|
@ -21687,6 +21690,7 @@ dependencies = [
|
|||
"uuid",
|
||||
"workspace",
|
||||
"worktree",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -56,3 +56,4 @@ rand.workspace = true
|
|||
tempfile.workspace = true
|
||||
util.workspace = true
|
||||
settings.workspace = true
|
||||
zlog.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
pub mod predict_edits_v3;
|
||||
pub mod udiff;
|
||||
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use chrono::Duration;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::Display,
|
||||
fmt::{Display, Write as _},
|
||||
ops::{Add, Range, Sub},
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
|
|
@ -11,7 +11,14 @@ use uuid::Uuid;
|
|||
|
||||
use crate::PredictEditsGitInfo;
|
||||
|
||||
// TODO: snippet ordering within file / relative to excerpt
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlanContextRetrievalRequest {
|
||||
pub excerpt: String,
|
||||
pub excerpt_path: Arc<Path>,
|
||||
pub excerpt_line_range: Range<Line>,
|
||||
pub cursor_file_max_row: Line,
|
||||
pub events: Vec<Event>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictEditsRequest {
|
||||
|
|
@ -125,15 +132,15 @@ impl Display for Event {
|
|||
write!(
|
||||
f,
|
||||
"// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
|
||||
old_path.display(),
|
||||
new_path.display()
|
||||
DiffPathFmt(old_path),
|
||||
DiffPathFmt(new_path)
|
||||
)
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
"--- a/{}\n+++ b/{}\n{diff}",
|
||||
old_path.display(),
|
||||
new_path.display()
|
||||
DiffPathFmt(old_path),
|
||||
DiffPathFmt(new_path)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -141,6 +148,24 @@ impl Display for Event {
|
|||
}
|
||||
}
|
||||
|
||||
/// Formats a [`Path`] for use in diffs, always using `/` as the separator
/// between components regardless of the host platform.
///
/// Components are rendered lossily (invalid Unicode is replaced) and joined
/// with a single `/`. A root-directory component is itself rendered as `/`
/// and is not followed by an extra separator, so `/foo/bar` is written as
/// `/foo/bar` (not `//foo/bar`) and a Windows root is not written as `\`.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // True once a component has been written that must be separated from
        // whatever follows it.
        let mut needs_separator = false;
        for component in self.0.components() {
            if matches!(component, std::path::Component::RootDir) {
                // The root already *is* the separator; emitting it through the
                // generic branch would double it on Unix and leak `\` on Windows.
                f.write_str("/")?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_str("/")?;
            }
            write!(f, "{}", component.as_os_str().to_string_lossy())?;
            needs_separator = true;
        }
        Ok(())
    }
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Signature {
|
||||
pub text: String,
|
||||
|
|
|
|||
|
|
@ -1,294 +0,0 @@
|
|||
use std::{borrow::Cow, fmt::Display};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum DiffLine<'a> {
|
||||
OldPath { path: Cow<'a, str> },
|
||||
NewPath { path: Cow<'a, str> },
|
||||
HunkHeader(Option<HunkLocation>),
|
||||
Context(&'a str),
|
||||
Deletion(&'a str),
|
||||
Addition(&'a str),
|
||||
Garbage(&'a str),
|
||||
}
|
||||
|
||||
/// Line numbers carried by a `@@ -a,b +c,d @@` hunk header.
///
/// Start lines are stored zero-based: parsing subtracts one (saturating at
/// zero) from the header's numbers and formatting adds one back.
#[derive(PartialEq, Debug)]
pub struct HunkLocation {
    /// Zero-based first line of the hunk in the old file.
    start_line_old: u32,
    /// Number of old-file lines the hunk spans.
    count_old: u32,
    /// Zero-based first line of the hunk in the new file.
    start_line_new: u32,
    /// Number of new-file lines the hunk spans.
    count_new: u32,
}
|
||||
|
||||
impl<'a> DiffLine<'a> {
|
||||
pub fn parse(line: &'a str) -> Self {
|
||||
Self::try_parse(line).unwrap_or(Self::Garbage(line))
|
||||
}
|
||||
|
||||
fn try_parse(line: &'a str) -> Option<Self> {
|
||||
if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
|
||||
let path = parse_header_path("a/", header);
|
||||
Some(Self::OldPath { path })
|
||||
} else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) {
|
||||
Some(Self::NewPath {
|
||||
path: parse_header_path("b/", header),
|
||||
})
|
||||
} else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) {
|
||||
if header.starts_with("...") {
|
||||
return Some(Self::HunkHeader(None));
|
||||
}
|
||||
|
||||
let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?;
|
||||
let mut parts = header.split_ascii_whitespace();
|
||||
let count_old = parts.next()?;
|
||||
let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?;
|
||||
|
||||
Some(Self::HunkHeader(Some(HunkLocation {
|
||||
start_line_old: start_line_old.parse::<u32>().ok()?.saturating_sub(1),
|
||||
count_old: count_old.parse().ok()?,
|
||||
start_line_new: start_line_new.parse::<u32>().ok()?.saturating_sub(1),
|
||||
count_new: count_new.parse().ok()?,
|
||||
})))
|
||||
} else if let Some(deleted_header) = line.strip_prefix("-") {
|
||||
Some(Self::Deletion(deleted_header))
|
||||
} else if line.is_empty() {
|
||||
Some(Self::Context(""))
|
||||
} else if let Some(context) = line.strip_prefix(" ") {
|
||||
Some(Self::Context(context))
|
||||
} else {
|
||||
Some(Self::Addition(line.strip_prefix("+")?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Display for DiffLine<'a> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DiffLine::OldPath { path } => write!(f, "--- {path}"),
|
||||
DiffLine::NewPath { path } => write!(f, "+++ {path}"),
|
||||
DiffLine::HunkHeader(Some(hunk_location)) => {
|
||||
write!(
|
||||
f,
|
||||
"@@ -{},{} +{},{} @@",
|
||||
hunk_location.start_line_old + 1,
|
||||
hunk_location.count_old,
|
||||
hunk_location.start_line_new + 1,
|
||||
hunk_location.count_new
|
||||
)
|
||||
}
|
||||
DiffLine::HunkHeader(None) => write!(f, "@@ ... @@"),
|
||||
DiffLine::Context(content) => write!(f, " {content}"),
|
||||
DiffLine::Deletion(content) => write!(f, "-{content}"),
|
||||
DiffLine::Addition(content) => write!(f, "+{content}"),
|
||||
DiffLine::Garbage(line) => write!(f, "{line}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the path from the text following `---` / `+++` in a file header.
///
/// The path ends at the first ASCII whitespace that is not inside double
/// quotes, so trailing data such as a timestamp is dropped. `strip_prefix`
/// (e.g. `"a/"`) is removed once if the path starts with it. Double quotes
/// are removed, and a backslash makes the character after it literal.
///
/// Borrows from `header` when no quote or backslash is present; otherwise
/// returns an owned, unquoted copy.
fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
    let needs_unquoting = header.contains(['"', '\\']);
    if !needs_unquoting {
        let token = header.split_ascii_whitespace().next().unwrap_or(header);
        return Cow::Borrowed(token.strip_prefix(strip_prefix).unwrap_or(token));
    }

    let mut unquoted = String::with_capacity(header.len());
    // `Some` until the prefix has been removed, so it is stripped at most once.
    let mut pending_prefix = Some(strip_prefix);
    let mut quoted = false;
    let mut remaining = header.chars();

    while let Some(current) = remaining.next() {
        match current {
            '"' => quoted = !quoted,
            '\\' => match remaining.next() {
                Some(escaped) => unquoted.push(escaped),
                // A trailing backslash has nothing to escape; stop here.
                None => break,
            },
            c if c.is_ascii_whitespace() && !quoted => break,
            c => unquoted.push(c),
        }

        // The prefix may only be recognized after unquoting (`"a/foo bar"`),
        // so it is checked against the text accumulated so far.
        if pending_prefix == Some(unquoted.as_str()) {
            pending_prefix = None;
            unquoted.clear();
        }
    }

    Cow::Owned(unquoted)
}
|
||||
|
||||
/// Strips leading ASCII whitespace from `header`, requiring that there is some.
///
/// Returns the remainder when at least one whitespace character was removed,
/// and `None` when `header` does not start with whitespace (including when it
/// is empty).
fn eat_required_whitespace(header: &str) -> Option<&str> {
    let rest = header.trim_ascii_start();
    // Equal lengths mean nothing was trimmed, i.e. the whitespace was missing.
    (rest.len() != header.len()).then_some(rest)
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
|
||||
#[test]
|
||||
fn parse_lines_simple() {
|
||||
let input = indoc! {"
|
||||
diff --git a/text.txt b/text.txt
|
||||
index 86c770d..a1fd855 100644
|
||||
--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -1,2 +1,3 @@
|
||||
context
|
||||
-deleted
|
||||
+inserted
|
||||
garbage
|
||||
|
||||
--- b/file.txt
|
||||
+++ a/file.txt
|
||||
"};
|
||||
|
||||
let lines = input.lines().map(DiffLine::parse).collect::<Vec<_>>();
|
||||
|
||||
pretty_assertions::assert_eq!(
|
||||
lines,
|
||||
&[
|
||||
DiffLine::Garbage("diff --git a/text.txt b/text.txt"),
|
||||
DiffLine::Garbage("index 86c770d..a1fd855 100644"),
|
||||
DiffLine::OldPath {
|
||||
path: "file.txt".into()
|
||||
},
|
||||
DiffLine::NewPath {
|
||||
path: "file.txt".into()
|
||||
},
|
||||
DiffLine::HunkHeader(Some(HunkLocation {
|
||||
start_line_old: 0,
|
||||
count_old: 2,
|
||||
start_line_new: 0,
|
||||
count_new: 3
|
||||
})),
|
||||
DiffLine::Context("context"),
|
||||
DiffLine::Deletion("deleted"),
|
||||
DiffLine::Addition("inserted"),
|
||||
DiffLine::Garbage("garbage"),
|
||||
DiffLine::Context(""),
|
||||
DiffLine::OldPath {
|
||||
path: "b/file.txt".into()
|
||||
},
|
||||
DiffLine::NewPath {
|
||||
path: "a/file.txt".into()
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_header_extra_space() {
|
||||
let options = ["--- file", "--- file", "---\tfile"];
|
||||
|
||||
for option in options {
|
||||
pretty_assertions::assert_eq!(
|
||||
DiffLine::parse(option),
|
||||
DiffLine::OldPath {
|
||||
path: "file".into()
|
||||
},
|
||||
"{option}",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hunk_header_extra_space() {
|
||||
let options = [
|
||||
"@@ -1,2 +1,3 @@",
|
||||
"@@ -1,2 +1,3 @@",
|
||||
"@@\t-1,2\t+1,3\t@@",
|
||||
"@@ -1,2 +1,3 @@",
|
||||
"@@ -1,2 +1,3 @@",
|
||||
"@@ -1,2 +1,3 @@",
|
||||
"@@ -1,2 +1,3 @@ garbage",
|
||||
];
|
||||
|
||||
for option in options {
|
||||
pretty_assertions::assert_eq!(
|
||||
DiffLine::parse(option),
|
||||
DiffLine::HunkHeader(Some(HunkLocation {
|
||||
start_line_old: 0,
|
||||
count_old: 2,
|
||||
start_line_new: 0,
|
||||
count_new: 3
|
||||
})),
|
||||
"{option}",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hunk_header_without_location() {
|
||||
pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_path() {
|
||||
assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt");
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "foo/bar/baz.txt"),
|
||||
"foo/bar/baz.txt"
|
||||
);
|
||||
assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt");
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "a/foo/bar/baz.txt"),
|
||||
"foo/bar/baz.txt"
|
||||
);
|
||||
|
||||
// Extra
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "a/foo/bar/baz.txt 2025"),
|
||||
"foo/bar/baz.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "a/foo/bar/baz.txt\t2025"),
|
||||
"foo/bar/baz.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "a/foo/bar/baz.txt \""),
|
||||
"foo/bar/baz.txt"
|
||||
);
|
||||
|
||||
// Quoted
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""),
|
||||
"foo/bar/baz quox.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""),
|
||||
"foo/bar/baz quox.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "\"foo/bar/baz quox.txt\""),
|
||||
"foo/bar/baz quox.txt"
|
||||
);
|
||||
assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷");
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"),
|
||||
"foo/bar/baz quox.txt"
|
||||
);
|
||||
// unescaped quotes are dropped
|
||||
assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar");
|
||||
|
||||
// Escaped
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""),
|
||||
"foo/\"bar\"/baz.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""),
|
||||
"C:\\Projects\\My App\\old file.txt"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -17,5 +17,7 @@ cloud_llm_client.workspace = true
|
|||
indoc.workspace = true
|
||||
ordered-float.workspace = true
|
||||
rustc-hash.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
//! Zeta2 prompt planning and generation code shared with cloud.
|
||||
pub mod retrieval_prompt;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
self, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration,
|
||||
self, DiffPathFmt, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use ordered_float::OrderedFloat;
|
||||
|
|
@ -212,7 +213,7 @@ pub fn write_codeblock<'a>(
|
|||
include_line_numbers: bool,
|
||||
output: &'a mut String,
|
||||
) {
|
||||
writeln!(output, "`````{}", path.display()).unwrap();
|
||||
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
|
||||
write_excerpts(
|
||||
excerpts,
|
||||
sorted_insertions,
|
||||
|
|
@ -275,7 +276,7 @@ pub fn write_excerpts<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
|
||||
pub fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
|
||||
if events.is_empty() {
|
||||
return;
|
||||
};
|
||||
|
|
|
|||
92
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
Normal file
92
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
use anyhow::Result;
|
||||
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
|
||||
use indoc::indoc;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, sync::LazyLock};
|
||||
|
||||
use crate::{push_events, write_codeblock};
|
||||
|
||||
pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
|
||||
let mut prompt = SEARCH_INSTRUCTIONS.to_string();
|
||||
|
||||
if !request.events.is_empty() {
|
||||
writeln!(&mut prompt, "## User Edits\n")?;
|
||||
push_events(&mut prompt, &request.events);
|
||||
}
|
||||
|
||||
writeln!(&mut prompt, "## Excerpt around the cursor\n")?;
|
||||
write_codeblock(
|
||||
&request.excerpt_path,
|
||||
&[Excerpt {
|
||||
start_line: request.excerpt_line_range.start,
|
||||
text: request.excerpt.into(),
|
||||
}],
|
||||
&[],
|
||||
request.cursor_file_max_row,
|
||||
true,
|
||||
&mut prompt,
|
||||
);
|
||||
|
||||
writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
/// Search for relevant code
|
||||
///
|
||||
/// For the best results, run multiple queries at once with a single invocation of this tool.
|
||||
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub struct SearchToolInput {
|
||||
/// An array of queries to run for gathering context relevant to the next prediction
|
||||
#[schemars(length(max = 5))]
|
||||
pub queries: Box<[SearchToolQuery]>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SearchToolQuery {
|
||||
/// A glob pattern to match file paths in the codebase
|
||||
pub glob: String,
|
||||
/// A regular expression to match content within the files matched by the glob pattern
|
||||
pub regex: String,
|
||||
}
|
||||
|
||||
pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
|
||||
let schema = schemars::schema_for!(SearchToolInput);
|
||||
|
||||
let description = schema
|
||||
.get("description")
|
||||
.and_then(|description| description.as_str())
|
||||
.unwrap()
|
||||
.to_string();
|
||||
|
||||
(schema.into(), description)
|
||||
});
|
||||
|
||||
pub const TOOL_NAME: &str = "search";
|
||||
|
||||
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
|
||||
## Task
|
||||
|
||||
You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
|
||||
that will serve as context for predicting the next required edit.
|
||||
|
||||
**Your task:**
|
||||
- Analyze the user's recent edits and current cursor context
|
||||
- Use the `search` tool to find code that may be relevant for predicting the next edit
|
||||
- Focus on finding:
|
||||
- Code patterns that might need similar changes based on the recent edits
|
||||
- Functions, variables, types, and constants referenced in the current cursor context
|
||||
- Related implementations, usages, or dependencies that may require consistent updates
|
||||
|
||||
**Important constraints:**
|
||||
- This conversation has exactly 2 turns
|
||||
- You must make ALL search queries in your first response via the `search` tool
|
||||
- All queries will be executed in parallel and results returned together
|
||||
- In the second turn, you will select the most relevant results via the `select` tool.
|
||||
"#};
|
||||
|
||||
const TOOL_USE_REMINDER: &str = indoc! {"
|
||||
--
|
||||
Use the `search` tool now
|
||||
"};
|
||||
|
|
@ -34,7 +34,7 @@ struct CurrentCompletion {
|
|||
snapshot: BufferSnapshot,
|
||||
/// The edits that should be applied to transform the original text into the predicted text.
|
||||
/// Each edit is a range in the buffer and the text to replace it with.
|
||||
edits: Arc<[(Range<Anchor>, String)]>,
|
||||
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
/// Preview of how the buffer will look after applying the edits.
|
||||
edit_preview: EditPreview,
|
||||
}
|
||||
|
|
@ -42,7 +42,7 @@ struct CurrentCompletion {
|
|||
impl CurrentCompletion {
|
||||
/// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
|
||||
/// Returns None if the user's edits conflict with the predicted edits.
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
|
||||
}
|
||||
}
|
||||
|
|
@ -281,8 +281,8 @@ impl EditPredictionProvider for CodestralCompletionProvider {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let edits: Arc<[(Range<Anchor>, String)]> =
|
||||
vec![(cursor_position..cursor_position, completion_text)].into();
|
||||
let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
|
||||
vec![(cursor_position..cursor_position, completion_text.into())].into();
|
||||
let edit_preview = buffer
|
||||
.read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
|
||||
.await;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use std::ops::Range;
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use client::EditPredictionUsage;
|
||||
use gpui::{App, Context, Entity, SharedString};
|
||||
|
|
@ -19,7 +19,7 @@ pub enum EditPrediction {
|
|||
/// Edits within the buffer that requested the prediction
|
||||
Local {
|
||||
id: Option<SharedString>,
|
||||
edits: Vec<(Range<language::Anchor>, String)>,
|
||||
edits: Vec<(Range<language::Anchor>, Arc<str>)>,
|
||||
edit_preview: Option<language::EditPreview>,
|
||||
},
|
||||
/// Jump to a different file from the one that requested the prediction
|
||||
|
|
@ -248,8 +248,8 @@ where
|
|||
pub fn interpolate_edits(
|
||||
old_snapshot: &BufferSnapshot,
|
||||
new_snapshot: &BufferSnapshot,
|
||||
current_edits: &[(Range<Anchor>, String)],
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
current_edits: &[(Range<Anchor>, Arc<str>)],
|
||||
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
let mut edits = Vec::new();
|
||||
|
||||
let mut model_edits = current_edits.iter().peekable();
|
||||
|
|
@ -274,7 +274,7 @@ pub fn interpolate_edits(
|
|||
if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
|
||||
if !model_suffix.is_empty() {
|
||||
let anchor = old_snapshot.anchor_after(user_edit.old.end);
|
||||
edits.push((anchor..anchor, model_suffix.to_string()));
|
||||
edits.push((anchor..anchor, model_suffix.into()));
|
||||
}
|
||||
|
||||
model_edits.next();
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use edit_prediction::EditPredictionProvider;
|
|||
use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
|
||||
use indoc::indoc;
|
||||
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
|
||||
use std::ops::Range;
|
||||
use std::{ops::Range, sync::Arc};
|
||||
use text::{Point, ToOffset};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -24,7 +24,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
|
|||
|
||||
assert_editor_active_edit_completion(&mut cx, |_, edits| {
|
||||
assert_eq!(edits.len(), 1);
|
||||
assert_eq!(edits[0].1.as_str(), "-273.15");
|
||||
assert_eq!(edits[0].1.as_ref(), "-273.15");
|
||||
});
|
||||
|
||||
accept_completion(&mut cx);
|
||||
|
|
@ -46,7 +46,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) {
|
|||
|
||||
assert_editor_active_edit_completion(&mut cx, |_, edits| {
|
||||
assert_eq!(edits.len(), 1);
|
||||
assert_eq!(edits[0].1.as_str(), "3.14159");
|
||||
assert_eq!(edits[0].1.as_ref(), "3.14159");
|
||||
});
|
||||
|
||||
accept_completion(&mut cx);
|
||||
|
|
@ -330,7 +330,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA
|
|||
|
||||
fn assert_editor_active_edit_completion(
|
||||
cx: &mut EditorTestContext,
|
||||
assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, String)>),
|
||||
assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, Arc<str>)>),
|
||||
) {
|
||||
cx.editor(|editor, _, cx| {
|
||||
let completion_state = editor
|
||||
|
|
|
|||
|
|
@ -616,7 +616,7 @@ pub(crate) enum EditDisplayMode {
|
|||
|
||||
enum EditPrediction {
|
||||
Edit {
|
||||
edits: Vec<(Range<Anchor>, String)>,
|
||||
edits: Vec<(Range<Anchor>, Arc<str>)>,
|
||||
edit_preview: Option<EditPreview>,
|
||||
display_mode: EditDisplayMode,
|
||||
snapshot: BufferSnapshot,
|
||||
|
|
@ -7960,7 +7960,7 @@ impl Editor {
|
|||
let inlay = Inlay::edit_prediction(
|
||||
post_inc(&mut self.next_inlay_id),
|
||||
range.start,
|
||||
new_text.as_str(),
|
||||
new_text.as_ref(),
|
||||
);
|
||||
inlay_ids.push(inlay.id);
|
||||
inlays.push(inlay);
|
||||
|
|
@ -8982,7 +8982,7 @@ impl Editor {
|
|||
newest_selection_head: Option<DisplayPoint>,
|
||||
editor_width: Pixels,
|
||||
style: &EditorStyle,
|
||||
edits: &Vec<(Range<Anchor>, String)>,
|
||||
edits: &Vec<(Range<Anchor>, Arc<str>)>,
|
||||
edit_preview: &Option<language::EditPreview>,
|
||||
snapshot: &language::BufferSnapshot,
|
||||
window: &mut Window,
|
||||
|
|
@ -24382,25 +24382,20 @@ impl InvalidationRegion for SnippetState {
|
|||
|
||||
fn edit_prediction_edit_text(
|
||||
current_snapshot: &BufferSnapshot,
|
||||
edits: &[(Range<Anchor>, String)],
|
||||
edits: &[(Range<Anchor>, impl AsRef<str>)],
|
||||
edit_preview: &EditPreview,
|
||||
include_deletions: bool,
|
||||
cx: &App,
|
||||
) -> HighlightedText {
|
||||
let edits = edits
|
||||
.iter()
|
||||
.map(|(anchor, text)| {
|
||||
(
|
||||
anchor.start.text_anchor..anchor.end.text_anchor,
|
||||
text.clone(),
|
||||
)
|
||||
})
|
||||
.map(|(anchor, text)| (anchor.start.text_anchor..anchor.end.text_anchor, text))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx)
|
||||
}
|
||||
|
||||
fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, String)], cx: &App) -> HighlightedText {
|
||||
fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, Arc<str>)], cx: &App) -> HighlightedText {
|
||||
// Fallback for providers that don't provide edit_preview (like Copilot/Supermaven)
|
||||
// Just show the raw edit text with basic styling
|
||||
let mut text = String::new();
|
||||
|
|
@ -24793,7 +24788,7 @@ impl Focusable for BreakpointPromptEditor {
|
|||
}
|
||||
|
||||
fn all_edits_insertions_or_deletions(
|
||||
edits: &Vec<(Range<Anchor>, String)>,
|
||||
edits: &Vec<(Range<Anchor>, Arc<str>)>,
|
||||
snapshot: &MultiBufferSnapshot,
|
||||
) -> bool {
|
||||
let mut all_insertions = true;
|
||||
|
|
|
|||
|
|
@ -22915,7 +22915,7 @@ async fn assert_highlighted_edits(
|
|||
let text_anchor_edits = edits
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit))
|
||||
.map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit.into()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let edit_preview = window
|
||||
|
|
|
|||
|
|
@ -720,7 +720,7 @@ impl EditPreview {
|
|||
pub fn highlight_edits(
|
||||
&self,
|
||||
current_snapshot: &BufferSnapshot,
|
||||
edits: &[(Range<Anchor>, String)],
|
||||
edits: &[(Range<Anchor>, impl AsRef<str>)],
|
||||
include_deletions: bool,
|
||||
cx: &App,
|
||||
) -> HighlightedText {
|
||||
|
|
@ -747,7 +747,8 @@ impl EditPreview {
|
|||
.end
|
||||
.bias_right(&self.old_snapshot)
|
||||
.to_offset(&self.applied_edits_snapshot);
|
||||
let edit_start_in_preview_snapshot = edit_new_end_in_preview_snapshot - edit_text.len();
|
||||
let edit_start_in_preview_snapshot =
|
||||
edit_new_end_in_preview_snapshot - edit_text.as_ref().len();
|
||||
|
||||
let unchanged_range_in_preview_snapshot =
|
||||
offset_in_preview_snapshot..edit_start_in_preview_snapshot;
|
||||
|
|
@ -772,7 +773,7 @@ impl EditPreview {
|
|||
);
|
||||
}
|
||||
|
||||
if !edit_text.is_empty() {
|
||||
if !edit_text.as_ref().is_empty() {
|
||||
highlighted_text.add_text_from_buffer_range(
|
||||
edit_start_in_preview_snapshot..edit_new_end_in_preview_snapshot,
|
||||
&self.applied_edits_snapshot,
|
||||
|
|
@ -796,7 +797,7 @@ impl EditPreview {
|
|||
highlighted_text.build()
|
||||
}
|
||||
|
||||
fn compute_visible_range(&self, edits: &[(Range<Anchor>, String)]) -> Option<Range<usize>> {
|
||||
fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<usize>> {
|
||||
let (first, _) = edits.first()?;
|
||||
let (last, _) = edits.last()?;
|
||||
|
||||
|
|
@ -1130,7 +1131,7 @@ impl Buffer {
|
|||
|
||||
pub fn preview_edits(
|
||||
&self,
|
||||
edits: Arc<[(Range<Anchor>, String)]>,
|
||||
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
cx: &App,
|
||||
) -> Task<EditPreview> {
|
||||
let registry = self.language_registry();
|
||||
|
|
|
|||
|
|
@ -3120,15 +3120,13 @@ async fn test_preview_edits(cx: &mut TestAppContext) {
|
|||
.map(|(range, text)| {
|
||||
(
|
||||
buffer.anchor_before(range.start)..buffer.anchor_after(range.end),
|
||||
text.to_string(),
|
||||
text.into(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.collect::<Arc<[_]>>()
|
||||
});
|
||||
let edit_preview = buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
buffer.preview_edits(edits.clone().into(), cx)
|
||||
})
|
||||
.read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
|
||||
.await;
|
||||
let highlighted_edits = cx.read(|cx| {
|
||||
edit_preview.highlight_edits(&buffer.read(cx).snapshot(), &edits, include_deletions, cx)
|
||||
|
|
|
|||
|
|
@ -293,7 +293,7 @@ pub struct FunctionDefinition {
|
|||
pub parameters: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum RequestMessage {
|
||||
Assistant {
|
||||
|
|
@ -366,25 +366,42 @@ pub struct ImageUrl {
|
|||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(flatten)]
|
||||
pub content: ToolCallContent,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ToolCallContent {
|
||||
Function { function: FunctionContent },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct FunctionContent {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct Response {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: RequestMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessageDelta {
|
||||
pub role: Option<Role>,
|
||||
|
|
@ -410,7 +427,7 @@ pub struct FunctionChunk {
|
|||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u64,
|
||||
pub completion_tokens: u64,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ use language::{Anchor, Buffer, BufferSnapshot};
|
|||
use std::{
|
||||
ops::{AddAssign, Range},
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use text::{ToOffset, ToPoint};
|
||||
|
|
@ -51,7 +52,7 @@ fn completion_from_diff(
|
|||
) -> EditPrediction {
|
||||
let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
|
||||
|
||||
let mut edits: Vec<(Range<language::Anchor>, String)> = Vec::new();
|
||||
let mut edits: Vec<(Range<language::Anchor>, Arc<str>)> = Vec::new();
|
||||
|
||||
let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
|
||||
let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
|
||||
|
|
@ -70,7 +71,10 @@ fn completion_from_diff(
|
|||
if k != 0 {
|
||||
let offset = snapshot.anchor_after(offset);
|
||||
// the range from the current position to item is an inlay.
|
||||
let edit = (offset..offset, completion_graphemes[i..i + k].join(""));
|
||||
let edit = (
|
||||
offset..offset,
|
||||
completion_graphemes[i..i + k].join("").into(),
|
||||
);
|
||||
edits.push(edit);
|
||||
}
|
||||
i += k + 1;
|
||||
|
|
@ -90,7 +94,7 @@ fn completion_from_diff(
|
|||
// there is leftover completion text, so drop it as an inlay.
|
||||
let edit_range = offset..offset;
|
||||
let edit_text = completion_graphemes[i..].join("");
|
||||
edits.push((edit_range, edit_text));
|
||||
edits.push((edit_range, edit_text.into()));
|
||||
}
|
||||
|
||||
EditPrediction::Local {
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ pub struct EditPrediction {
|
|||
path: Arc<Path>,
|
||||
excerpt_range: Range<usize>,
|
||||
cursor_offset: usize,
|
||||
edits: Arc<[(Range<Anchor>, String)]>,
|
||||
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
snapshot: BufferSnapshot,
|
||||
edit_preview: EditPreview,
|
||||
input_outline: Arc<str>,
|
||||
|
|
@ -150,7 +150,7 @@ impl EditPrediction {
|
|||
.duration_since(self.buffer_snapshotted_at)
|
||||
}
|
||||
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
|
||||
}
|
||||
}
|
||||
|
|
@ -711,7 +711,7 @@ impl Zeta {
|
|||
cx.spawn(async move |cx| {
|
||||
let output_excerpt: Arc<str> = output_excerpt.into();
|
||||
|
||||
let edits: Arc<[(Range<Anchor>, String)]> = cx
|
||||
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
|
||||
.background_spawn({
|
||||
let output_excerpt = output_excerpt.clone();
|
||||
let editable_range = editable_range.clone();
|
||||
|
|
@ -725,7 +725,7 @@ impl Zeta {
|
|||
let edits = edits.clone();
|
||||
move |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
let edits: Arc<[(Range<Anchor>, String)]> =
|
||||
let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
|
||||
edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
|
||||
.into();
|
||||
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
|
||||
|
|
@ -759,7 +759,7 @@ impl Zeta {
|
|||
output_excerpt: Arc<str>,
|
||||
editable_range: Range<usize>,
|
||||
snapshot: &BufferSnapshot,
|
||||
) -> Result<Vec<(Range<Anchor>, String)>> {
|
||||
) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
let content = output_excerpt.replace(CURSOR_MARKER, "");
|
||||
|
||||
let start_markers = content
|
||||
|
|
@ -817,7 +817,7 @@ impl Zeta {
|
|||
new_text: &str,
|
||||
offset: usize,
|
||||
snapshot: &BufferSnapshot,
|
||||
) -> Vec<(Range<Anchor>, String)> {
|
||||
) -> Vec<(Range<Anchor>, Arc<str>)> {
|
||||
text_diff(&old_text, new_text)
|
||||
.into_iter()
|
||||
.map(|(mut old_range, new_text)| {
|
||||
|
|
@ -836,7 +836,7 @@ impl Zeta {
|
|||
);
|
||||
old_range.end = old_range.end.saturating_sub(suffix_len);
|
||||
|
||||
let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
|
||||
let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
|
||||
let range = if old_range.is_empty() {
|
||||
let anchor = snapshot.anchor_after(old_range.start);
|
||||
anchor..anchor
|
||||
|
|
@ -1183,7 +1183,7 @@ impl CurrentEditPrediction {
|
|||
if old_edits.len() == 1 && new_edits.len() == 1 {
|
||||
let (old_range, old_text) = &old_edits[0];
|
||||
let (new_range, new_text) = &new_edits[0];
|
||||
new_range == old_range && new_text.starts_with(old_text)
|
||||
new_range == old_range && new_text.starts_with(old_text.as_ref())
|
||||
} else {
|
||||
true
|
||||
}
|
||||
|
|
@ -1599,13 +1599,8 @@ mod tests {
|
|||
#[gpui::test]
|
||||
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
|
||||
let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
|
||||
to_completion_edits(
|
||||
[(2..5, "REM".to_string()), (9..11, "".to_string())],
|
||||
&buffer,
|
||||
cx,
|
||||
)
|
||||
.into()
|
||||
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
|
||||
to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
|
||||
});
|
||||
|
||||
let edit_preview = cx
|
||||
|
|
@ -1635,7 +1630,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
|
||||
vec![(2..5, "REM".into()), (9..11, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
|
||||
|
|
@ -1645,7 +1640,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
|
||||
vec![(2..2, "REM".into()), (6..8, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
||||
|
|
@ -1655,7 +1650,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
|
||||
vec![(2..5, "REM".into()), (9..11, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
|
||||
|
|
@ -1665,7 +1660,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
|
||||
vec![(3..3, "EM".into()), (7..9, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
|
||||
|
|
@ -1675,7 +1670,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
|
||||
vec![(4..4, "M".into()), (8..10, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
|
||||
|
|
@ -1685,7 +1680,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(9..11, "".to_string())]
|
||||
vec![(9..11, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
|
||||
|
|
@ -1695,7 +1690,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
|
||||
vec![(4..4, "M".into()), (8..10, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
|
||||
|
|
@ -1705,7 +1700,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(4..4, "M".to_string())]
|
||||
vec![(4..4, "M".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
|
||||
|
|
@ -2211,10 +2206,10 @@ mod tests {
|
|||
}
|
||||
|
||||
fn to_completion_edits(
|
||||
iterator: impl IntoIterator<Item = (Range<usize>, String)>,
|
||||
iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
|
||||
buffer: &Entity<Buffer>,
|
||||
cx: &App,
|
||||
) -> Vec<(Range<Anchor>, String)> {
|
||||
) -> Vec<(Range<Anchor>, Arc<str>)> {
|
||||
let buffer = buffer.read(cx);
|
||||
iterator
|
||||
.into_iter()
|
||||
|
|
@ -2228,10 +2223,10 @@ mod tests {
|
|||
}
|
||||
|
||||
fn from_completion_edits(
|
||||
editor_edits: &[(Range<Anchor>, String)],
|
||||
editor_edits: &[(Range<Anchor>, Arc<str>)],
|
||||
buffer: &Entity<Buffer>,
|
||||
cx: &App,
|
||||
) -> Vec<(Range<usize>, String)> {
|
||||
) -> Vec<(Range<usize>, Arc<str>)> {
|
||||
let buffer = buffer.read(cx);
|
||||
editor_edits
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ indoc.workspace = true
|
|||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
open_ai.workspace = true
|
||||
project.workspace = true
|
||||
release_channel.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
|
@ -50,3 +50,4 @@ language_model = { workspace = true, features = ["test-support"] }
|
|||
pretty_assertions.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,17 +1,11 @@
|
|||
use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use gpui::{App, AsyncApp, Entity};
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
|
||||
};
|
||||
use project::Project;
|
||||
use util::ResultExt;
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct EditPredictionId(Uuid);
|
||||
pub struct EditPredictionId(pub Uuid);
|
||||
|
||||
impl Into<Uuid> for EditPredictionId {
|
||||
fn into(self) -> Uuid {
|
||||
|
|
@ -34,8 +28,7 @@ impl std::fmt::Display for EditPredictionId {
|
|||
#[derive(Clone)]
|
||||
pub struct EditPrediction {
|
||||
pub id: EditPredictionId,
|
||||
pub path: Arc<Path>,
|
||||
pub edits: Arc<[(Range<Anchor>, String)]>,
|
||||
pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
pub snapshot: BufferSnapshot,
|
||||
pub edit_preview: EditPreview,
|
||||
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
|
||||
|
|
@ -43,90 +36,43 @@ pub struct EditPrediction {
|
|||
}
|
||||
|
||||
impl EditPrediction {
|
||||
pub async fn from_response(
|
||||
response: predict_edits_v3::PredictEditsResponse,
|
||||
active_buffer_old_snapshot: &TextBufferSnapshot,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
pub async fn new(
|
||||
id: EditPredictionId,
|
||||
edited_buffer: &Entity<Buffer>,
|
||||
edited_buffer_snapshot: &BufferSnapshot,
|
||||
edits: Vec<(Range<Anchor>, Arc<str>)>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<Self> {
|
||||
// TODO only allow cloud to return one path
|
||||
let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
|
||||
return None;
|
||||
};
|
||||
let (edits, snapshot, edit_preview_task) = edited_buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
let edits: Arc<[_]> =
|
||||
interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
|
||||
|
||||
let is_same_path = active_buffer
|
||||
.read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
|
||||
.ok()?;
|
||||
|
||||
let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
|
||||
active_buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
|
||||
let edits: Arc<[_]> =
|
||||
interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
|
||||
|
||||
Some((
|
||||
active_buffer.clone(),
|
||||
edits.clone(),
|
||||
new_snapshot,
|
||||
buffer.preview_edits(edits, cx),
|
||||
))
|
||||
})
|
||||
.ok()??
|
||||
} else {
|
||||
let buffer_handle = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(&path, cx)
|
||||
.context("Failed to find project path for zeta edit")?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})
|
||||
.ok()?
|
||||
.log_err()?
|
||||
.await
|
||||
.context("Failed to open buffer for zeta edit")
|
||||
.log_err()?;
|
||||
|
||||
buffer_handle
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let edits = edits_from_response(&response.edits, &snapshot);
|
||||
if edits.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((
|
||||
buffer_handle.clone(),
|
||||
edits.clone(),
|
||||
snapshot,
|
||||
buffer.preview_edits(edits, cx),
|
||||
))
|
||||
})
|
||||
.ok()??
|
||||
};
|
||||
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
|
||||
})
|
||||
.ok()??;
|
||||
|
||||
let edit_preview = edit_preview_task.await;
|
||||
|
||||
Some(EditPrediction {
|
||||
id: EditPredictionId(response.request_id),
|
||||
path,
|
||||
id,
|
||||
edits,
|
||||
snapshot,
|
||||
edit_preview,
|
||||
buffer,
|
||||
buffer: edited_buffer.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn interpolate(
|
||||
&self,
|
||||
new_snapshot: &TextBufferSnapshot,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
|
||||
}
|
||||
|
||||
pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
|
||||
buffer_path_eq(buffer, &self.path, cx)
|
||||
pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
|
||||
self.snapshot.remote_id() == buffer.remote_id()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -134,21 +80,16 @@ impl std::fmt::Debug for EditPrediction {
|
|||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("EditPrediction")
|
||||
.field("id", &self.id)
|
||||
.field("path", &self.path)
|
||||
.field("edits", &self.edits)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
|
||||
buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
|
||||
}
|
||||
|
||||
pub fn interpolate_edits(
|
||||
old_snapshot: &TextBufferSnapshot,
|
||||
new_snapshot: &TextBufferSnapshot,
|
||||
current_edits: Arc<[(Range<Anchor>, String)]>,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
let mut edits = Vec::new();
|
||||
|
||||
let mut model_edits = current_edits.iter().peekable();
|
||||
|
|
@ -173,7 +114,7 @@ pub fn interpolate_edits(
|
|||
if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
|
||||
if !model_suffix.is_empty() {
|
||||
let anchor = old_snapshot.anchor_after(user_edit.old.end);
|
||||
edits.push((anchor..anchor, model_suffix.to_string()));
|
||||
edits.push((anchor..anchor, model_suffix.into()));
|
||||
}
|
||||
|
||||
model_edits.next();
|
||||
|
|
@ -190,135 +131,17 @@ pub fn interpolate_edits(
|
|||
if edits.is_empty() { None } else { Some(edits) }
|
||||
}
|
||||
|
||||
pub fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> {
|
||||
language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0)
|
||||
}
|
||||
|
||||
fn edits_from_response(
|
||||
edits: &[predict_edits_v3::Edit],
|
||||
snapshot: &TextBufferSnapshot,
|
||||
) -> Arc<[(Range<Anchor>, String)]> {
|
||||
edits
|
||||
.iter()
|
||||
.flat_map(|edit| {
|
||||
let point_range = line_range_to_point_range(edit.range.clone());
|
||||
let offset = point_range.to_offset(snapshot).start;
|
||||
let old_text = snapshot.text_for_range(point_range);
|
||||
|
||||
excerpt_edits_from_response(
|
||||
old_text.collect::<Cow<str>>(),
|
||||
&edit.content,
|
||||
offset,
|
||||
&snapshot,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into()
|
||||
}
|
||||
|
||||
fn excerpt_edits_from_response(
|
||||
old_text: Cow<str>,
|
||||
new_text: &str,
|
||||
offset: usize,
|
||||
snapshot: &TextBufferSnapshot,
|
||||
) -> impl Iterator<Item = (Range<Anchor>, String)> {
|
||||
text_diff(&old_text, new_text)
|
||||
.into_iter()
|
||||
.map(move |(mut old_range, new_text)| {
|
||||
old_range.start += offset;
|
||||
old_range.end += offset;
|
||||
|
||||
let prefix_len = common_prefix(
|
||||
snapshot.chars_for_range(old_range.clone()),
|
||||
new_text.chars(),
|
||||
);
|
||||
old_range.start += prefix_len;
|
||||
|
||||
let suffix_len = common_prefix(
|
||||
snapshot.reversed_chars_for_range(old_range.clone()),
|
||||
new_text[prefix_len..].chars().rev(),
|
||||
);
|
||||
old_range.end = old_range.end.saturating_sub(suffix_len);
|
||||
|
||||
let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
|
||||
let range = if old_range.is_empty() {
|
||||
let anchor = snapshot.anchor_after(old_range.start);
|
||||
anchor..anchor
|
||||
} else {
|
||||
snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
|
||||
};
|
||||
(range, new_text)
|
||||
})
|
||||
}
|
||||
|
||||
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
|
||||
a.zip(b)
|
||||
.take_while(|(a, b)| a == b)
|
||||
.map(|(a, _)| a.len_utf8())
|
||||
.sum()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::*;
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use edit_prediction_context::Line;
|
||||
use gpui::{App, Entity, TestAppContext, prelude::*};
|
||||
use indoc::indoc;
|
||||
use language::{Buffer, ToOffset as _};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_compute_edits(cx: &mut TestAppContext) {
|
||||
let old = indoc! {r#"
|
||||
fn main() {
|
||||
let args =
|
||||
println!("{}", args[1])
|
||||
}
|
||||
"#};
|
||||
|
||||
let new = indoc! {r#"
|
||||
fn main() {
|
||||
let args = std::env::args();
|
||||
println!("{}", args[1]);
|
||||
}
|
||||
"#};
|
||||
|
||||
let buffer = cx.new(|cx| Buffer::local(old, cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
// TODO cover more cases when multi-file is supported
|
||||
let big_edits = vec![predict_edits_v3::Edit {
|
||||
path: PathBuf::from("test.txt").into(),
|
||||
range: Line(0)..Line(old.lines().count() as u32),
|
||||
content: new.into(),
|
||||
}];
|
||||
|
||||
let edits = edits_from_response(&big_edits, &snapshot);
|
||||
assert_eq!(edits.len(), 2);
|
||||
assert_eq!(
|
||||
edits[0].0.to_point(&snapshot).start,
|
||||
language::Point::new(1, 14)
|
||||
);
|
||||
assert_eq!(edits[0].1, " std::env::args();");
|
||||
assert_eq!(
|
||||
edits[1].0.to_point(&snapshot).start,
|
||||
language::Point::new(2, 27)
|
||||
);
|
||||
assert_eq!(edits[1].1, ";");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
|
||||
let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
|
||||
to_prediction_edits(
|
||||
[(2..5, "REM".to_string()), (9..11, "".to_string())],
|
||||
&buffer,
|
||||
cx,
|
||||
)
|
||||
.into()
|
||||
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
|
||||
to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
|
||||
});
|
||||
|
||||
let edit_preview = cx
|
||||
|
|
@ -329,7 +152,6 @@ mod tests {
|
|||
id: EditPredictionId(Uuid::new_v4()),
|
||||
edits,
|
||||
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
|
||||
path: Path::new("test.txt").into(),
|
||||
buffer: buffer.clone(),
|
||||
edit_preview,
|
||||
};
|
||||
|
|
@ -341,7 +163,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
|
||||
vec![(2..5, "REM".into()), (9..11, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
|
||||
|
|
@ -351,7 +173,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
|
||||
vec![(2..2, "REM".into()), (6..8, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
||||
|
|
@ -361,7 +183,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
|
||||
vec![(2..5, "REM".into()), (9..11, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
|
||||
|
|
@ -371,7 +193,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
|
||||
vec![(3..3, "EM".into()), (7..9, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
|
||||
|
|
@ -381,7 +203,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
|
||||
vec![(4..4, "M".into()), (8..10, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
|
||||
|
|
@ -391,7 +213,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(9..11, "".to_string())]
|
||||
vec![(9..11, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
|
||||
|
|
@ -401,7 +223,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
|
||||
vec![(4..4, "M".into()), (8..10, "".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
|
||||
|
|
@ -411,7 +233,7 @@ mod tests {
|
|||
&buffer,
|
||||
cx
|
||||
),
|
||||
vec![(4..4, "M".to_string())]
|
||||
vec![(4..4, "M".into())]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
|
||||
|
|
@ -420,10 +242,10 @@ mod tests {
|
|||
}
|
||||
|
||||
fn to_prediction_edits(
|
||||
iterator: impl IntoIterator<Item = (Range<usize>, String)>,
|
||||
iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
|
||||
buffer: &Entity<Buffer>,
|
||||
cx: &App,
|
||||
) -> Vec<(Range<Anchor>, String)> {
|
||||
) -> Vec<(Range<Anchor>, Arc<str>)> {
|
||||
let buffer = buffer.read(cx);
|
||||
iterator
|
||||
.into_iter()
|
||||
|
|
@ -437,10 +259,10 @@ mod tests {
|
|||
}
|
||||
|
||||
fn from_prediction_edits(
|
||||
editor_edits: &[(Range<Anchor>, String)],
|
||||
editor_edits: &[(Range<Anchor>, Arc<str>)],
|
||||
buffer: &Entity<Buffer>,
|
||||
cx: &App,
|
||||
) -> Vec<(Range<usize>, String)> {
|
||||
) -> Vec<(Range<usize>, Arc<str>)> {
|
||||
let buffer = buffer.read(cx);
|
||||
editor_edits
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -1,717 +0,0 @@
|
|||
use std::{
|
||||
cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
|
||||
ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_zeta2_prompt::write_codeblock;
|
||||
use collections::HashMap;
|
||||
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
|
||||
use futures::{
|
||||
StreamExt,
|
||||
channel::mpsc::{self, UnboundedSender},
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use indoc::indoc;
|
||||
use language::{
|
||||
Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUse, MessageContent, Role,
|
||||
};
|
||||
use project::{
|
||||
Project, WorktreeSettings,
|
||||
search::{SearchQuery, SearchResult},
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::{
|
||||
ResultExt as _,
|
||||
paths::{PathMatcher, PathStyle},
|
||||
};
|
||||
use workspace::item::Settings as _;
|
||||
|
||||
const SEARCH_PROMPT: &str = indoc! {r#"
|
||||
## Task
|
||||
|
||||
You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
|
||||
that will serve as context for predicting the next required edit.
|
||||
|
||||
**Your task:**
|
||||
- Analyze the user's recent edits and current cursor context
|
||||
- Use the `search` tool to find code that may be relevant for predicting the next edit
|
||||
- Focus on finding:
|
||||
- Code patterns that might need similar changes based on the recent edits
|
||||
- Functions, variables, types, and constants referenced in the current cursor context
|
||||
- Related implementations, usages, or dependencies that may require consistent updates
|
||||
|
||||
**Important constraints:**
|
||||
- This conversation has exactly 2 turns
|
||||
- You must make ALL search queries in your first response via the `search` tool
|
||||
- All queries will be executed in parallel and results returned together
|
||||
- In the second turn, you will select the most relevant results via the `select` tool.
|
||||
|
||||
## User Edits
|
||||
|
||||
{edits}
|
||||
|
||||
## Current cursor context
|
||||
|
||||
`````{current_file_path}
|
||||
{cursor_excerpt}
|
||||
`````
|
||||
|
||||
--
|
||||
Use the `search` tool now
|
||||
"#};
|
||||
|
||||
const SEARCH_TOOL_NAME: &str = "search";
|
||||
|
||||
/// Search for relevant code
|
||||
///
|
||||
/// For the best results, run multiple queries at once with a single invocation of this tool.
|
||||
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub struct SearchToolInput {
|
||||
/// An array of queries to run for gathering context relevant to the next prediction
|
||||
#[schemars(length(max = 5))]
|
||||
pub queries: Box<[SearchToolQuery]>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SearchToolQuery {
|
||||
/// A glob pattern to match file paths in the codebase
|
||||
pub glob: String,
|
||||
/// A regular expression to match content within the files matched by the glob pattern
|
||||
pub regex: String,
|
||||
}
|
||||
|
||||
const RESULTS_MESSAGE: &str = indoc! {"
|
||||
Here are the results of your queries combined and grouped by file:
|
||||
|
||||
"};
|
||||
|
||||
const SELECT_TOOL_NAME: &str = "select";
|
||||
|
||||
const SELECT_PROMPT: &str = indoc! {"
|
||||
Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
|
||||
Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
|
||||
Include up to 200 lines in total.
|
||||
"};
|
||||
|
||||
/// Select line ranges from search results
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct SelectToolInput {
|
||||
/// The line ranges to select from search results.
|
||||
ranges: Vec<SelectLineRange>,
|
||||
}
|
||||
|
||||
/// A specific line range to select from a file
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct SelectLineRange {
|
||||
/// The file path containing the lines to select
|
||||
/// Exactly as it appears in the search result codeblocks.
|
||||
path: PathBuf,
|
||||
/// The starting line number (1-based)
|
||||
#[schemars(range(min = 1))]
|
||||
start_line: u32,
|
||||
/// The ending line number (1-based, inclusive)
|
||||
#[schemars(range(min = 1))]
|
||||
end_line: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct LlmContextOptions {
|
||||
pub excerpt: EditPredictionExcerptOptions,
|
||||
}
|
||||
|
||||
pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
|
||||
|
||||
pub fn find_related_excerpts(
|
||||
buffer: Entity<language::Buffer>,
|
||||
cursor_position: Anchor,
|
||||
project: &Entity<Project>,
|
||||
mut edit_history_unified_diff: String,
|
||||
options: &LlmContextOptions,
|
||||
debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
|
||||
cx: &App,
|
||||
) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
|
||||
let language_model_registry = LanguageModelRegistry::global(cx);
|
||||
let Some(model) = language_model_registry
|
||||
.read(cx)
|
||||
.available_models(cx)
|
||||
.find(|model| {
|
||||
model.provider_id() == MODEL_PROVIDER_ID
|
||||
&& model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
|
||||
// model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
|
||||
// model.provider_id() == LanguageModelProviderId::new("ollama")
|
||||
// && model.id() == LanguageModelId("gpt-oss:20b".into())
|
||||
})
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("could not find context model")));
|
||||
};
|
||||
|
||||
if edit_history_unified_diff.is_empty() {
|
||||
edit_history_unified_diff.push_str("(No user edits yet)");
|
||||
}
|
||||
|
||||
// TODO [zeta2] include breadcrumbs?
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let cursor_point = cursor_position.to_point(&snapshot);
|
||||
let Some(cursor_excerpt) =
|
||||
EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
|
||||
else {
|
||||
return Task::ready(Ok(HashMap::default()));
|
||||
};
|
||||
|
||||
let current_file_path = snapshot
|
||||
.file()
|
||||
.map(|f| f.full_path(cx).display().to_string())
|
||||
.unwrap_or_else(|| "untitled".to_string());
|
||||
|
||||
let prompt = SEARCH_PROMPT
|
||||
.replace("{edits}", &edit_history_unified_diff)
|
||||
.replace("{current_file_path}", &current_file_path)
|
||||
.replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
|
||||
ZetaContextRetrievalStartedDebugInfo {
|
||||
project: project.clone(),
|
||||
timestamp: Instant::now(),
|
||||
search_prompt: prompt.clone(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let path_style = project.read(cx).path_style(cx);
|
||||
|
||||
let exclude_matcher = {
|
||||
let global_settings = WorktreeSettings::get_global(cx);
|
||||
let exclude_patterns = global_settings
|
||||
.file_scan_exclusions
|
||||
.sources()
|
||||
.iter()
|
||||
.chain(global_settings.private_files.sources().iter());
|
||||
|
||||
match PathMatcher::new(exclude_patterns, path_style) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(err) => {
|
||||
return Task::ready(Err(anyhow!(err)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let project = project.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let initial_prompt_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![prompt.into()],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
let mut search_stream = request_tool_call::<SearchToolInput>(
|
||||
vec![initial_prompt_message.clone()],
|
||||
SEARCH_TOOL_NAME,
|
||||
&model,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
|
||||
select_request_messages.push(initial_prompt_message);
|
||||
|
||||
let mut regex_by_glob: HashMap<String, String> = HashMap::default();
|
||||
let mut search_calls = Vec::new();
|
||||
|
||||
while let Some(event) = search_stream.next().await {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
if !tool_use.is_input_complete {
|
||||
continue;
|
||||
}
|
||||
|
||||
if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
|
||||
let input =
|
||||
serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
|
||||
|
||||
for query in input.queries {
|
||||
let regex = regex_by_glob.entry(query.glob).or_default();
|
||||
if !regex.is_empty() {
|
||||
regex.push('|');
|
||||
}
|
||||
regex.push_str(&query.regex);
|
||||
}
|
||||
|
||||
search_calls.push(tool_use);
|
||||
} else {
|
||||
log::warn!(
|
||||
"context gathering model tried to use unknown tool: {}",
|
||||
tool_use.name
|
||||
);
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::Text(txt) => {
|
||||
if let Some(LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
..
|
||||
}) = select_request_messages.last_mut()
|
||||
{
|
||||
if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
|
||||
existing_text.push_str(&txt);
|
||||
} else {
|
||||
content.push(MessageContent::Text(txt));
|
||||
}
|
||||
} else {
|
||||
select_request_messages.push(LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::Text(txt)],
|
||||
cache: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::Thinking { text, signature } => {
|
||||
if let Some(LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
..
|
||||
}) = select_request_messages.last_mut()
|
||||
{
|
||||
if let Some(MessageContent::Thinking {
|
||||
text: existing_text,
|
||||
signature: existing_signature,
|
||||
}) = content.last_mut()
|
||||
{
|
||||
existing_text.push_str(&text);
|
||||
*existing_signature = signature;
|
||||
} else {
|
||||
content.push(MessageContent::Thinking { text, signature });
|
||||
}
|
||||
} else {
|
||||
select_request_messages.push(LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::Thinking { text, signature }],
|
||||
cache: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::RedactedThinking { data } => {
|
||||
if let Some(LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
..
|
||||
}) = select_request_messages.last_mut()
|
||||
{
|
||||
if let Some(MessageContent::RedactedThinking(existing_data)) =
|
||||
content.last_mut()
|
||||
{
|
||||
existing_data.push_str(&data);
|
||||
} else {
|
||||
content.push(MessageContent::RedactedThinking(data));
|
||||
}
|
||||
} else {
|
||||
select_request_messages.push(LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::RedactedThinking(data)],
|
||||
cache: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
|
||||
log::error!("{ev:?}");
|
||||
}
|
||||
ev => {
|
||||
log::trace!("context search event: {ev:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let search_tool_use = if search_calls.is_empty() {
|
||||
log::warn!("context model ran 0 searches");
|
||||
return anyhow::Ok(Default::default());
|
||||
} else if search_calls.len() == 1 {
|
||||
search_calls.swap_remove(0)
|
||||
} else {
|
||||
// In theory, the model could perform multiple search calls
|
||||
// Dealing with them separately is not worth it when it doesn't happen in practice.
|
||||
// If it were to happen, here we would combine them into one.
|
||||
// The second request doesn't need to know it was actually two different calls ;)
|
||||
let input = serde_json::to_value(&SearchToolInput {
|
||||
queries: regex_by_glob
|
||||
.iter()
|
||||
.map(|(glob, regex)| SearchToolQuery {
|
||||
glob: glob.clone(),
|
||||
regex: regex.clone(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
LanguageModelToolUse {
|
||||
id: search_calls.swap_remove(0).id,
|
||||
name: SELECT_TOOL_NAME.into(),
|
||||
raw_input: serde_json::to_string(&input).unwrap_or_default(),
|
||||
input,
|
||||
is_input_complete: true,
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
|
||||
ZetaSearchQueryDebugInfo {
|
||||
project: project.clone(),
|
||||
timestamp: Instant::now(),
|
||||
queries: regex_by_glob
|
||||
.iter()
|
||||
.map(|(glob, regex)| SearchToolQuery {
|
||||
glob: glob.clone(),
|
||||
regex: regex.clone(),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (results_tx, mut results_rx) = mpsc::unbounded();
|
||||
|
||||
for (glob, regex) in regex_by_glob {
|
||||
let exclude_matcher = exclude_matcher.clone();
|
||||
let results_tx = results_tx.clone();
|
||||
let project = project.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
run_query(
|
||||
&glob,
|
||||
®ex,
|
||||
results_tx.clone(),
|
||||
path_style,
|
||||
exclude_matcher,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
})
|
||||
.detach()
|
||||
}
|
||||
drop(results_tx);
|
||||
|
||||
struct ResultBuffer {
|
||||
buffer: Entity<Buffer>,
|
||||
snapshot: TextBufferSnapshot,
|
||||
}
|
||||
|
||||
let (result_buffers_by_path, merged_result) = cx
|
||||
.background_spawn(async move {
|
||||
let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
|
||||
HashMap::default();
|
||||
|
||||
while let Some((buffer, matched)) = results_rx.next().await {
|
||||
match excerpts_by_buffer.entry(buffer) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
let entry = entry.get_mut();
|
||||
entry.full_path = matched.full_path;
|
||||
entry.snapshot = matched.snapshot;
|
||||
entry.line_ranges.extend(matched.line_ranges);
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(matched);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut result_buffers_by_path = HashMap::default();
|
||||
let mut merged_result = RESULTS_MESSAGE.to_string();
|
||||
|
||||
for (buffer, mut matched) in excerpts_by_buffer {
|
||||
matched
|
||||
.line_ranges
|
||||
.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
|
||||
|
||||
write_codeblock(
|
||||
&matched.full_path,
|
||||
merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
|
||||
&[],
|
||||
Line(matched.snapshot.max_point().row),
|
||||
true,
|
||||
&mut merged_result,
|
||||
);
|
||||
|
||||
result_buffers_by_path.insert(
|
||||
matched.full_path,
|
||||
ResultBuffer {
|
||||
buffer,
|
||||
snapshot: matched.snapshot.text,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
(result_buffers_by_path, merged_result)
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
|
||||
ZetaContextRetrievalDebugInfo {
|
||||
project: project.clone(),
|
||||
timestamp: Instant::now(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let tool_result = LanguageModelToolResult {
|
||||
tool_use_id: search_tool_use.id.clone(),
|
||||
tool_name: SEARCH_TOOL_NAME.into(),
|
||||
is_error: false,
|
||||
content: merged_result.into(),
|
||||
output: None,
|
||||
};
|
||||
|
||||
select_request_messages.extend([
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(search_tool_use)],
|
||||
cache: false,
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::ToolResult(tool_result)],
|
||||
cache: false,
|
||||
},
|
||||
]);
|
||||
|
||||
if result_buffers_by_path.is_empty() {
|
||||
log::trace!("context gathering queries produced no results");
|
||||
return anyhow::Ok(HashMap::default());
|
||||
}
|
||||
|
||||
select_request_messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![SELECT_PROMPT.into()],
|
||||
cache: false,
|
||||
});
|
||||
|
||||
let mut select_stream = request_tool_call::<SelectToolInput>(
|
||||
select_request_messages,
|
||||
SELECT_TOOL_NAME,
|
||||
&model,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let mut selected_ranges = Vec::new();
|
||||
|
||||
while let Some(event) = select_stream.next().await {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
if !tool_use.is_input_complete {
|
||||
continue;
|
||||
}
|
||||
|
||||
if tool_use.name.as_ref() == SELECT_TOOL_NAME {
|
||||
let call =
|
||||
serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
|
||||
selected_ranges.extend(call.ranges);
|
||||
} else {
|
||||
log::warn!(
|
||||
"context gathering model tried to use unknown tool: {}",
|
||||
tool_use.name
|
||||
);
|
||||
}
|
||||
}
|
||||
ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
|
||||
log::error!("{ev:?}");
|
||||
}
|
||||
ev => {
|
||||
log::trace!("context select event: {ev:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
|
||||
ZetaContextRetrievalDebugInfo {
|
||||
project: project.clone(),
|
||||
timestamp: Instant::now(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
if selected_ranges.is_empty() {
|
||||
log::trace!("context gathering selected no ranges")
|
||||
}
|
||||
|
||||
selected_ranges.sort_unstable_by(|a, b| {
|
||||
a.start_line
|
||||
.cmp(&b.start_line)
|
||||
.then(b.end_line.cmp(&a.end_line))
|
||||
});
|
||||
|
||||
let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
|
||||
|
||||
for selected_range in selected_ranges {
|
||||
if let Some(ResultBuffer { buffer, snapshot }) =
|
||||
result_buffers_by_path.get(&selected_range.path)
|
||||
{
|
||||
let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
|
||||
let end_point =
|
||||
snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
|
||||
let range =
|
||||
snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
|
||||
|
||||
related_excerpts_by_buffer
|
||||
.entry(buffer.clone())
|
||||
.or_default()
|
||||
.push(range);
|
||||
} else {
|
||||
log::warn!(
|
||||
"selected path that wasn't included in search results: {}",
|
||||
selected_range.path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(related_excerpts_by_buffer)
|
||||
})
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
async fn request_tool_call<T: JsonSchema>(
|
||||
messages: Vec<LanguageModelRequestMessage>,
|
||||
tool_name: &'static str,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
|
||||
{
|
||||
let schema = schemars::schema_for!(T);
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
messages,
|
||||
tools: vec![LanguageModelRequestTool {
|
||||
name: tool_name.into(),
|
||||
description: schema
|
||||
.get("description")
|
||||
.and_then(|description| description.as_str())
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
input_schema: serde_json::to_value(schema).unwrap(),
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Ok(model.stream_completion(request, cx).await?)
|
||||
}
|
||||
|
||||
/// Passed as `min_bytes` when `run_query` selects an excerpt around a search match.
const MIN_EXCERPT_LEN: usize = 16;
|
||||
/// Upper bound for `max_bytes` when `run_query` selects an excerpt around a search match.
const MAX_EXCERPT_LEN: usize = 768;
|
||||
/// Byte budget `run_query` checks before selecting further excerpts (five maximum-size excerpts).
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
|
||||
|
||||
/// Search matches found in one buffer, as sent by `run_query` over its results channel.
struct MatchedBuffer {
|
||||
// Snapshot of the buffer taken when the match was processed.
snapshot: BufferSnapshot,
|
||||
// Line ranges of the excerpts selected around each match (empty ranges are not stored).
line_ranges: Vec<Range<Line>>,
|
||||
// Full path of the file backing the buffer.
full_path: PathBuf,
|
||||
}
|
||||
|
||||
async fn run_query(
|
||||
glob: &str,
|
||||
regex: &str,
|
||||
results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
|
||||
path_style: PathStyle,
|
||||
exclude_matcher: PathMatcher,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let include_matcher = PathMatcher::new(vec![glob], path_style)?;
|
||||
|
||||
let query = SearchQuery::regex(
|
||||
regex,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
include_matcher,
|
||||
exclude_matcher,
|
||||
true,
|
||||
None,
|
||||
)?;
|
||||
|
||||
let results = project.update(cx, |project, cx| project.search(query, cx))?;
|
||||
futures::pin_mut!(results);
|
||||
|
||||
let mut total_bytes = 0;
|
||||
|
||||
while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
|
||||
if ranges.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
|
||||
Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
|
||||
})?
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let results_tx = results_tx.clone();
|
||||
cx.background_spawn(async move {
|
||||
let mut line_ranges = Vec::with_capacity(ranges.len());
|
||||
|
||||
for range in ranges {
|
||||
let offset_range = range.to_offset(&snapshot);
|
||||
let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
|
||||
|
||||
if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
|
||||
break;
|
||||
}
|
||||
|
||||
let excerpt = EditPredictionExcerpt::select_from_buffer(
|
||||
query_point,
|
||||
&snapshot,
|
||||
&EditPredictionExcerptOptions {
|
||||
max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
|
||||
min_bytes: MIN_EXCERPT_LEN,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
if let Some(excerpt) = excerpt {
|
||||
total_bytes += excerpt.range.len();
|
||||
if !excerpt.line_range.is_empty() {
|
||||
line_ranges.push(excerpt.line_range);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results_tx
|
||||
.unbounded_send((
|
||||
buffer,
|
||||
MatchedBuffer {
|
||||
snapshot,
|
||||
line_ranges,
|
||||
full_path,
|
||||
},
|
||||
))
|
||||
.log_err();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
194
crates/zeta2/src/retrieval_search.rs
Normal file
194
crates/zeta2/src/retrieval_search.rs
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
|
||||
use futures::{
|
||||
StreamExt,
|
||||
channel::mpsc::{self, UnboundedSender},
|
||||
};
|
||||
use gpui::{AppContext, AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _};
|
||||
use project::{
|
||||
Project, WorktreeSettings,
|
||||
search::{SearchQuery, SearchResult},
|
||||
};
|
||||
use util::{
|
||||
ResultExt as _,
|
||||
paths::{PathMatcher, PathStyle},
|
||||
};
|
||||
use workspace::item::Settings as _;
|
||||
|
||||
pub async fn run_retrieval_searches(
|
||||
project: Entity<Project>,
|
||||
regex_by_glob: HashMap<String, String>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
|
||||
let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
|
||||
let global_settings = WorktreeSettings::get_global(cx);
|
||||
let exclude_patterns = global_settings
|
||||
.file_scan_exclusions
|
||||
.sources()
|
||||
.iter()
|
||||
.chain(global_settings.private_files.sources().iter());
|
||||
let path_style = project.path_style(cx);
|
||||
anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
|
||||
})??;
|
||||
|
||||
let (results_tx, mut results_rx) = mpsc::unbounded();
|
||||
|
||||
for (glob, regex) in regex_by_glob {
|
||||
let exclude_matcher = exclude_matcher.clone();
|
||||
let results_tx = results_tx.clone();
|
||||
let project = project.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
run_query(
|
||||
&glob,
|
||||
®ex,
|
||||
results_tx.clone(),
|
||||
path_style,
|
||||
exclude_matcher,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
})
|
||||
.detach()
|
||||
}
|
||||
drop(results_tx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
|
||||
let mut snapshots = HashMap::default();
|
||||
|
||||
let mut total_bytes = 0;
|
||||
'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
|
||||
snapshots.insert(buffer.entity_id(), snapshot);
|
||||
let existing = results.entry(buffer).or_default();
|
||||
existing.reserve(excerpts.len());
|
||||
|
||||
for (range, size) in excerpts {
|
||||
// Blunt trimming of the results until we have a proper algorithmic filtering step
|
||||
if (total_bytes + size) > MAX_RESULTS_LEN {
|
||||
log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
|
||||
break 'outer;
|
||||
}
|
||||
total_bytes += size;
|
||||
existing.push(range);
|
||||
}
|
||||
}
|
||||
|
||||
for (buffer, ranges) in results.iter_mut() {
|
||||
if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
|
||||
ranges.sort_unstable_by(|a, b| {
|
||||
a.start
|
||||
.cmp(&b.start, snapshot)
|
||||
.then(b.end.cmp(&b.end, snapshot))
|
||||
});
|
||||
|
||||
let mut index = 1;
|
||||
while index < ranges.len() {
|
||||
if ranges[index - 1]
|
||||
.end
|
||||
.cmp(&ranges[index].start, snapshot)
|
||||
.is_gt()
|
||||
{
|
||||
let removed = ranges.remove(index);
|
||||
ranges[index - 1].end = removed.end;
|
||||
} else {
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Passed as `min_bytes` when `run_query` selects an excerpt around a search match.
const MIN_EXCERPT_LEN: usize = 16;
|
||||
/// Passed as `max_bytes` when `run_query` selects an excerpt around a search match.
const MAX_EXCERPT_LEN: usize = 768;
|
||||
/// Cap on the combined byte size of all excerpts kept by `run_retrieval_searches`.
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
|
||||
|
||||
async fn run_query(
|
||||
glob: &str,
|
||||
regex: &str,
|
||||
results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
|
||||
path_style: PathStyle,
|
||||
exclude_matcher: PathMatcher,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let include_matcher = PathMatcher::new(vec![glob], path_style)?;
|
||||
|
||||
let query = SearchQuery::regex(
|
||||
regex,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
include_matcher,
|
||||
exclude_matcher,
|
||||
true,
|
||||
None,
|
||||
)?;
|
||||
|
||||
let results = project.update(cx, |project, cx| project.search(query, cx))?;
|
||||
futures::pin_mut!(results);
|
||||
|
||||
while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
|
||||
if results_tx.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
if ranges.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let results_tx = results_tx.clone();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let mut excerpts = Vec::with_capacity(ranges.len());
|
||||
|
||||
for range in ranges {
|
||||
let offset_range = range.to_offset(&snapshot);
|
||||
let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
|
||||
|
||||
let excerpt = EditPredictionExcerpt::select_from_buffer(
|
||||
query_point,
|
||||
&snapshot,
|
||||
&EditPredictionExcerptOptions {
|
||||
max_bytes: MAX_EXCERPT_LEN,
|
||||
min_bytes: MIN_EXCERPT_LEN,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
if let Some(excerpt) = excerpt
|
||||
&& !excerpt.line_range.is_empty()
|
||||
{
|
||||
excerpts.push((
|
||||
snapshot.anchor_after(excerpt.range.start)
|
||||
..snapshot.anchor_before(excerpt.range.end),
|
||||
excerpt.range.len(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts));
|
||||
|
||||
if let Err(err) = send_result
|
||||
&& !err.is_disconnected()
|
||||
{
|
||||
log::error!("{err}");
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
1024
crates/zeta2/src/udiff.rs
Normal file
1024
crates/zeta2/src/udiff.rs
Normal file
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -27,10 +27,11 @@ log.workspace = true
|
|||
multi_buffer.workspace = true
|
||||
ordered-float.workspace = true
|
||||
project.workspace = true
|
||||
regex-syntax = "0.8.8"
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
telemetry.workspace = true
|
||||
text.workspace = true
|
||||
regex-syntax = "0.8.8"
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
util.workspace = true
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ struct RetrievalRun {
|
|||
started_at: Instant,
|
||||
search_results_generated_at: Option<Instant>,
|
||||
search_results_executed_at: Option<Instant>,
|
||||
search_results_filtered_at: Option<Instant>,
|
||||
finished_at: Option<Instant>,
|
||||
}
|
||||
|
||||
|
|
@ -117,17 +116,12 @@ impl Zeta2ContextView {
|
|||
self.handle_search_queries_executed(info, window, cx);
|
||||
}
|
||||
}
|
||||
ZetaDebugInfo::SearchResultsFiltered(info) => {
|
||||
if info.project == self.project {
|
||||
self.handle_search_results_filtered(info, window, cx);
|
||||
}
|
||||
}
|
||||
ZetaDebugInfo::ContextRetrievalFinished(info) => {
|
||||
if info.project == self.project {
|
||||
self.handle_context_retrieval_finished(info, window, cx);
|
||||
}
|
||||
}
|
||||
ZetaDebugInfo::EditPredicted(_) => {}
|
||||
ZetaDebugInfo::EditPredictionRequested(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -159,7 +153,6 @@ impl Zeta2ContextView {
|
|||
started_at: info.timestamp,
|
||||
search_results_generated_at: None,
|
||||
search_results_executed_at: None,
|
||||
search_results_filtered_at: None,
|
||||
finished_at: None,
|
||||
});
|
||||
|
||||
|
|
@ -218,18 +211,18 @@ impl Zeta2ContextView {
|
|||
|
||||
run.search_results_generated_at = Some(info.timestamp);
|
||||
run.search_queries = info
|
||||
.queries
|
||||
.regex_by_glob
|
||||
.into_iter()
|
||||
.map(|query| {
|
||||
.map(|(glob, regex)| {
|
||||
let mut regex_parser = regex_syntax::ast::parse::Parser::new();
|
||||
|
||||
GlobQueries {
|
||||
glob: query.glob,
|
||||
alternations: match regex_parser.parse(&query.regex) {
|
||||
glob,
|
||||
alternations: match regex_parser.parse(®ex) {
|
||||
Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => {
|
||||
alt.asts.iter().map(|ast| ast.to_string()).collect()
|
||||
}
|
||||
_ => vec![query.regex],
|
||||
_ => vec![regex],
|
||||
},
|
||||
}
|
||||
})
|
||||
|
|
@ -256,20 +249,6 @@ impl Zeta2ContextView {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_search_results_filtered(
|
||||
&mut self,
|
||||
info: ZetaContextRetrievalDebugInfo,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(run) = self.runs.back_mut() else {
|
||||
return;
|
||||
};
|
||||
|
||||
run.search_results_filtered_at = Some(info.timestamp);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_go_back(
|
||||
&mut self,
|
||||
_: &Zeta2ContextGoBack,
|
||||
|
|
@ -398,19 +377,10 @@ impl Zeta2ContextView {
|
|||
};
|
||||
div = div.child(format!("Ran search: {:>5} ms", (t2 - t1).as_millis()));
|
||||
|
||||
let Some(t3) = run.search_results_filtered_at else {
|
||||
return pending_message(div, "Filtering results...");
|
||||
};
|
||||
div =
|
||||
div.child(format!("Filtered results: {:>5} ms", (t3 - t2).as_millis()));
|
||||
|
||||
let Some(t4) = run.finished_at else {
|
||||
return pending_message(div, "Building excerpts");
|
||||
};
|
||||
div = div
|
||||
.child(format!("Build excerpts: {:>5} µs", (t4 - t3).as_micros()))
|
||||
.child(format!("Total: {:>5} ms", (t4 - t0).as_millis()));
|
||||
div
|
||||
div.child(format!(
|
||||
"Total: {:>5} ms",
|
||||
(run.finished_at.unwrap_or(t0) - t0).as_millis()
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
|
|||
use chrono::TimeDelta;
|
||||
use client::{Client, UserStore};
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
self, DeclarationScoreComponents, PredictEditsRequest, PredictEditsResponse, PromptFormat,
|
||||
DeclarationScoreComponents, PredictEditsRequest, PromptFormat,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
|
||||
|
|
@ -23,7 +23,7 @@ use ui_input::InputField;
|
|||
use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
|
||||
use workspace::{Item, SplitDirection, Workspace};
|
||||
use zeta2::{
|
||||
ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, LlmContextOptions, Zeta, Zeta2FeatureFlag,
|
||||
AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, Zeta, Zeta2FeatureFlag,
|
||||
ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
|
||||
};
|
||||
|
||||
|
|
@ -123,6 +123,7 @@ struct LastPrediction {
|
|||
context_editor: Entity<Editor>,
|
||||
prompt_editor: Entity<Editor>,
|
||||
retrieval_time: TimeDelta,
|
||||
request_time: Option<TimeDelta>,
|
||||
buffer: WeakEntity<Buffer>,
|
||||
position: language::Anchor,
|
||||
state: LastPredictionState,
|
||||
|
|
@ -143,7 +144,7 @@ enum LastPredictionState {
|
|||
model_response_editor: Entity<Editor>,
|
||||
feedback_editor: Entity<Editor>,
|
||||
feedback: Option<Feedback>,
|
||||
response: predict_edits_v3::PredictEditsResponse,
|
||||
request_id: String,
|
||||
},
|
||||
Failed {
|
||||
message: String,
|
||||
|
|
@ -217,7 +218,7 @@ impl Zeta2Inspector {
|
|||
});
|
||||
|
||||
match &options.context {
|
||||
ContextMode::Llm(_) => {
|
||||
ContextMode::Agentic(_) => {
|
||||
self.context_mode = ContextModeState::Llm;
|
||||
}
|
||||
ContextMode::Syntax(_) => {
|
||||
|
|
@ -307,9 +308,11 @@ impl Zeta2Inspector {
|
|||
};
|
||||
|
||||
let context = match zeta_options.context {
|
||||
ContextMode::Llm(_context_options) => ContextMode::Llm(LlmContextOptions {
|
||||
excerpt: excerpt_options,
|
||||
}),
|
||||
ContextMode::Agentic(_context_options) => {
|
||||
ContextMode::Agentic(AgenticContextOptions {
|
||||
excerpt: excerpt_options,
|
||||
})
|
||||
}
|
||||
ContextMode::Syntax(context_options) => {
|
||||
let max_retrieved_declarations = match &this.context_mode {
|
||||
ContextModeState::Llm => {
|
||||
|
|
@ -368,7 +371,7 @@ impl Zeta2Inspector {
|
|||
let language_registry = self.project.read(cx).languages().clone();
|
||||
async move |this, cx| {
|
||||
let mut languages = HashMap::default();
|
||||
let ZetaDebugInfo::EditPredicted(prediction) = prediction else {
|
||||
let ZetaDebugInfo::EditPredictionRequested(prediction) = prediction else {
|
||||
return;
|
||||
};
|
||||
for ext in prediction
|
||||
|
|
@ -396,6 +399,8 @@ impl Zeta2Inspector {
|
|||
.await
|
||||
.log_err();
|
||||
|
||||
let json_language = language_registry.language_for_name("Json").await.log_err();
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
let context_editor = cx.new(|cx| {
|
||||
let mut excerpt_score_components = HashMap::default();
|
||||
|
|
@ -492,25 +497,15 @@ impl Zeta2Inspector {
|
|||
|
||||
let task = cx.spawn_in(window, {
|
||||
let markdown_language = markdown_language.clone();
|
||||
let json_language = json_language.clone();
|
||||
async move |this, cx| {
|
||||
let response = response_rx.await;
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
if let Some(prediction) = this.last_prediction.as_mut() {
|
||||
prediction.state = match response {
|
||||
Ok(Ok(response)) => {
|
||||
if let Some(debug_info) = &response.debug_info {
|
||||
prediction.prompt_editor.update(
|
||||
cx,
|
||||
|prompt_editor, cx| {
|
||||
prompt_editor.set_text(
|
||||
debug_info.prompt.as_str(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
Ok((Ok(response), request_time)) => {
|
||||
prediction.request_time = Some(request_time);
|
||||
|
||||
let feedback_editor = cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| {
|
||||
|
|
@ -577,16 +572,11 @@ impl Zeta2Inspector {
|
|||
model_response_editor: cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| {
|
||||
let mut buffer = Buffer::local(
|
||||
response
|
||||
.debug_info
|
||||
.as_ref()
|
||||
.map(|p| p.model_response.as_str())
|
||||
.unwrap_or(
|
||||
"(Debug info not available)",
|
||||
),
|
||||
serde_json::to_string_pretty(&response)
|
||||
.unwrap_or_default(),
|
||||
cx,
|
||||
);
|
||||
buffer.set_language(markdown_language, cx);
|
||||
buffer.set_language(json_language, cx);
|
||||
buffer
|
||||
});
|
||||
let buffer = cx.new(|cx| {
|
||||
|
|
@ -607,10 +597,11 @@ impl Zeta2Inspector {
|
|||
}),
|
||||
feedback_editor,
|
||||
feedback: None,
|
||||
response,
|
||||
request_id: response.id.clone(),
|
||||
}
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
Ok((Err(err), request_time)) => {
|
||||
prediction.request_time = Some(request_time);
|
||||
LastPredictionState::Failed { message: err }
|
||||
}
|
||||
Err(oneshot::Canceled) => LastPredictionState::Failed {
|
||||
|
|
@ -644,6 +635,7 @@ impl Zeta2Inspector {
|
|||
editor
|
||||
}),
|
||||
retrieval_time,
|
||||
request_time: None,
|
||||
buffer,
|
||||
position,
|
||||
state: LastPredictionState::Requested,
|
||||
|
|
@ -700,7 +692,7 @@ impl Zeta2Inspector {
|
|||
feedback: feedback_state,
|
||||
feedback_editor,
|
||||
model_response_editor,
|
||||
response,
|
||||
request_id,
|
||||
..
|
||||
} = &mut last_prediction.state
|
||||
else {
|
||||
|
|
@ -734,11 +726,10 @@ impl Zeta2Inspector {
|
|||
|
||||
telemetry::event!(
|
||||
"Zeta2 Prediction Rated",
|
||||
id = response.request_id,
|
||||
id = request_id,
|
||||
kind = kind,
|
||||
text = text,
|
||||
request = last_prediction.request,
|
||||
response = response,
|
||||
project_snapshot = project_snapshot,
|
||||
);
|
||||
})
|
||||
|
|
@ -834,11 +825,11 @@ impl Zeta2Inspector {
|
|||
let current_options =
|
||||
this.zeta.read(cx).options().clone();
|
||||
match current_options.context.clone() {
|
||||
ContextMode::Llm(_) => {}
|
||||
ContextMode::Agentic(_) => {}
|
||||
ContextMode::Syntax(context_options) => {
|
||||
let options = ZetaOptions {
|
||||
context: ContextMode::Llm(
|
||||
LlmContextOptions {
|
||||
context: ContextMode::Agentic(
|
||||
AgenticContextOptions {
|
||||
excerpt: context_options.excerpt,
|
||||
},
|
||||
),
|
||||
|
|
@ -865,7 +856,7 @@ impl Zeta2Inspector {
|
|||
let current_options =
|
||||
this.zeta.read(cx).options().clone();
|
||||
match current_options.context.clone() {
|
||||
ContextMode::Llm(context_options) => {
|
||||
ContextMode::Agentic(context_options) => {
|
||||
let options = ZetaOptions {
|
||||
context: ContextMode::Syntax(
|
||||
EditPredictionContextOptions {
|
||||
|
|
@ -976,25 +967,6 @@ impl Zeta2Inspector {
|
|||
return None;
|
||||
};
|
||||
|
||||
let (prompt_planning_time, inference_time, parsing_time) =
|
||||
if let LastPredictionState::Success {
|
||||
response:
|
||||
PredictEditsResponse {
|
||||
debug_info: Some(debug_info),
|
||||
..
|
||||
},
|
||||
..
|
||||
} = &prediction.state
|
||||
{
|
||||
(
|
||||
Some(debug_info.prompt_planning_time),
|
||||
Some(debug_info.inference_time),
|
||||
Some(debug_info.parsing_time),
|
||||
)
|
||||
} else {
|
||||
(None, None, None)
|
||||
};
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.p_4()
|
||||
|
|
@ -1005,12 +977,7 @@ impl Zeta2Inspector {
|
|||
"Context retrieval",
|
||||
Some(prediction.retrieval_time),
|
||||
))
|
||||
.child(Self::render_duration(
|
||||
"Prompt planning",
|
||||
prompt_planning_time,
|
||||
))
|
||||
.child(Self::render_duration("Inference", inference_time))
|
||||
.child(Self::render_duration("Parsing", parsing_time)),
|
||||
.child(Self::render_duration("Request", prediction.request_time)),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,14 @@ use std::{
|
|||
|
||||
use anyhow::Result;
|
||||
use clap::Args;
|
||||
use cloud_llm_client::udiff::DiffLine;
|
||||
use collections::HashSet;
|
||||
use gpui::AsyncApp;
|
||||
use zeta2::udiff::DiffLine;
|
||||
|
||||
use crate::{
|
||||
example::{Example, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
paths::CACHE_DIR,
|
||||
predict::{PredictionDetails, zeta2_predict},
|
||||
};
|
||||
|
||||
|
|
@ -54,10 +55,8 @@ pub async fn run_evaluate_one(
|
|||
app_state: Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<EvaluationResult> {
|
||||
let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
|
||||
.join("../../target/zeta-prediction-cache");
|
||||
let example = NamedExample::load(&example_path).unwrap();
|
||||
let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
|
||||
let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
|
||||
|
||||
let predictions = if !re_run && example_cache_path.exists() {
|
||||
let file_contents = fs::read_to_string(&example_cache_path)?;
|
||||
|
|
@ -74,7 +73,7 @@ pub async fn run_evaluate_one(
|
|||
};
|
||||
|
||||
if !example_cache_path.exists() {
|
||||
fs::create_dir_all(&cache_dir).unwrap();
|
||||
fs::create_dir_all(&*CACHE_DIR).unwrap();
|
||||
fs::write(
|
||||
example_cache_path,
|
||||
serde_json::to_string(&predictions).unwrap(),
|
||||
|
|
|
|||
|
|
@ -1,28 +1,31 @@
|
|||
use std::{
|
||||
borrow::Cow,
|
||||
cell::RefCell,
|
||||
env,
|
||||
fmt::{self, Display},
|
||||
fs,
|
||||
io::Write,
|
||||
mem,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use clap::ValueEnum;
|
||||
use collections::{HashMap, HashSet};
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use gpui::{AsyncApp, Entity, http_client::Url};
|
||||
use language::Buffer;
|
||||
use language::{Anchor, Buffer};
|
||||
use project::{Project, ProjectPath};
|
||||
use pulldown_cmark::CowStr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta2::udiff::OpenedBuffers;
|
||||
|
||||
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
|
||||
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
|
|
@ -215,12 +218,10 @@ impl NamedExample {
|
|||
let (repo_owner, repo_name) = self.repo_name()?;
|
||||
let file_name = self.file_name();
|
||||
|
||||
let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
|
||||
let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
|
||||
fs::create_dir_all(&repos_dir)?;
|
||||
fs::create_dir_all(&worktrees_dir)?;
|
||||
fs::create_dir_all(&*REPOS_DIR)?;
|
||||
fs::create_dir_all(&*WORKTREES_DIR)?;
|
||||
|
||||
let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
|
|
@ -251,7 +252,7 @@ impl NamedExample {
|
|||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = worktrees_dir.join(&file_name);
|
||||
let worktree_path = WORKTREES_DIR.join(&file_name);
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
|
|
@ -309,7 +310,6 @@ impl NamedExample {
|
|||
.collect()
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
|
||||
// git@github.com:owner/repo.git
|
||||
if self.example.repository_url.contains('@') {
|
||||
|
|
@ -344,13 +344,63 @@ impl NamedExample {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn cursor_position(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})?;
|
||||
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = self
|
||||
.example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))?;
|
||||
let mut cursor_excerpt = self.example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let Some((excerpt_offset, _)) = matches.next() else {
|
||||
anyhow::bail!(
|
||||
"Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
|
||||
);
|
||||
};
|
||||
assert!(matches.next().is_none());
|
||||
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn apply_edit_history(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<HashSet<Entity<Buffer>>> {
|
||||
apply_diff(&self.example.edit_history, project, cx).await
|
||||
) -> Result<OpenedBuffers<'_>> {
|
||||
zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -446,404 +496,3 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
|||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn apply_diff(
|
||||
diff: &str,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<HashSet<Entity<Buffer>>> {
|
||||
use cloud_llm_client::udiff::DiffLine;
|
||||
use std::fmt::Write;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct HunkState {
|
||||
context: String,
|
||||
edits: Vec<Edit>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Edit {
|
||||
range: Range<usize>,
|
||||
text: String,
|
||||
}
|
||||
|
||||
let mut old_path = None;
|
||||
let mut new_path = None;
|
||||
let mut hunk = HunkState::default();
|
||||
let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
|
||||
let mut open_buffers = HashSet::default();
|
||||
|
||||
while let Some(diff_line) = diff_lines.next() {
|
||||
match diff_line {
|
||||
DiffLine::OldPath { path } => old_path = Some(path),
|
||||
DiffLine::NewPath { path } => {
|
||||
if old_path.is_none() {
|
||||
anyhow::bail!(
|
||||
"Found a new path header (`+++`) before an (`---`) old path header"
|
||||
);
|
||||
}
|
||||
new_path = Some(path)
|
||||
}
|
||||
DiffLine::Context(ctx) => {
|
||||
writeln!(&mut hunk.context, "{ctx}")?;
|
||||
}
|
||||
DiffLine::Deletion(del) => {
|
||||
let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8();
|
||||
if let Some(last_edit) = hunk.edits.last_mut()
|
||||
&& last_edit.range.end == range.start
|
||||
{
|
||||
last_edit.range.end = range.end;
|
||||
} else {
|
||||
hunk.edits.push(Edit {
|
||||
range,
|
||||
text: String::new(),
|
||||
});
|
||||
}
|
||||
writeln!(&mut hunk.context, "{del}")?;
|
||||
}
|
||||
DiffLine::Addition(add) => {
|
||||
let range = hunk.context.len()..hunk.context.len();
|
||||
if let Some(last_edit) = hunk.edits.last_mut()
|
||||
&& last_edit.range.end == range.start
|
||||
{
|
||||
writeln!(&mut last_edit.text, "{add}").unwrap();
|
||||
} else {
|
||||
hunk.edits.push(Edit {
|
||||
range,
|
||||
text: format!("{add}\n"),
|
||||
});
|
||||
}
|
||||
}
|
||||
DiffLine::HunkHeader(_) | DiffLine::Garbage(_) => {}
|
||||
}
|
||||
|
||||
let at_hunk_end = match diff_lines.peek() {
|
||||
Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if at_hunk_end {
|
||||
let hunk = mem::take(&mut hunk);
|
||||
|
||||
let Some(old_path) = old_path.as_deref() else {
|
||||
anyhow::bail!("Missing old path (`---`) header")
|
||||
};
|
||||
|
||||
let Some(new_path) = new_path.as_deref() else {
|
||||
anyhow::bail!("Missing new path (`+++`) header")
|
||||
};
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(old_path, cx)
|
||||
.context("Failed to find old_path in project")?;
|
||||
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
open_buffers.insert(buffer.clone());
|
||||
|
||||
if old_path != new_path {
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
|
||||
let new_path = ProjectPath {
|
||||
worktree_id: project_file.worktree_id(cx),
|
||||
path: project_file.path.clone(),
|
||||
};
|
||||
project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
|
||||
})?
|
||||
.await?;
|
||||
}
|
||||
|
||||
// TODO is it worth using project search?
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let context_offset = if hunk.context.is_empty() {
|
||||
0
|
||||
} else {
|
||||
let text = buffer.text();
|
||||
if let Some(offset) = text.find(&hunk.context) {
|
||||
if text[offset + 1..].contains(&hunk.context) {
|
||||
anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
|
||||
}
|
||||
offset
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"Failed to match context:\n{}\n\nBuffer:\n{}",
|
||||
hunk.context,
|
||||
text
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
buffer.edit(
|
||||
hunk.edits.into_iter().map(|edit| {
|
||||
(
|
||||
context_offset + edit.range.start..context_offset + edit.range.end,
|
||||
edit.text,
|
||||
)
|
||||
}),
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(open_buffers)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ::fs::FakeFs;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_successful(cx: &mut TestAppContext) {
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
"# };
|
||||
|
||||
let buffer_1_text_final = indoc! {r#"
|
||||
3
|
||||
4
|
||||
5
|
||||
"# };
|
||||
|
||||
let buffer_2_text = indoc! {r#"
|
||||
six
|
||||
seven
|
||||
eight
|
||||
nine
|
||||
ten
|
||||
"# };
|
||||
|
||||
let buffer_2_text_final = indoc! {r#"
|
||||
5
|
||||
six
|
||||
seven
|
||||
7.5
|
||||
eight
|
||||
nine
|
||||
ten
|
||||
11
|
||||
"# };
|
||||
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
"file2": buffer_2_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
five
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
3
|
||||
-four
|
||||
-five
|
||||
+4
|
||||
+5
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
-one
|
||||
-two
|
||||
3
|
||||
4
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
+5
|
||||
six
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
seven
|
||||
+7.5
|
||||
eight
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
ten
|
||||
+11
|
||||
"#};
|
||||
|
||||
let _buffers = apply_diff(diff, &project, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_1 = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
|
||||
project.open_buffer(project_path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
buffer_1.read_with(cx, |buffer, _cx| {
|
||||
assert_eq!(buffer.text(), buffer_1_text_final);
|
||||
});
|
||||
let buffer_2 = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
|
||||
project.open_buffer(project_path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
buffer_2.read_with(cx, |buffer, _cx| {
|
||||
assert_eq!(buffer.text(), buffer_2_text_final);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
"# };
|
||||
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
five
|
||||
"#};
|
||||
|
||||
apply_diff(diff, &project, &mut cx.to_async())
|
||||
.await
|
||||
.expect_err("Non-unique edits should fail");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
|
||||
let start = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
|
||||
four
|
||||
five
|
||||
"# };
|
||||
|
||||
let end = indoc! {r#"
|
||||
one
|
||||
two
|
||||
3
|
||||
four
|
||||
5
|
||||
|
||||
four
|
||||
five
|
||||
"# };
|
||||
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": start,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
-five
|
||||
+5
|
||||
"#};
|
||||
|
||||
let _buffers = apply_diff(diff, &project, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let buffer_1 = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
|
||||
project.open_buffer(project_path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
buffer_1.read_with(cx, |buffer, _cx| {
|
||||
assert_eq!(buffer.text(), end);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
mod evaluate;
|
||||
mod example;
|
||||
mod headless;
|
||||
mod paths;
|
||||
mod predict;
|
||||
mod source_location;
|
||||
mod syntax_retrieval_stats;
|
||||
|
|
@ -10,28 +11,22 @@ use crate::evaluate::{EvaluateArguments, run_evaluate};
|
|||
use crate::example::{ExampleFormat, NamedExample};
|
||||
use crate::predict::{PredictArguments, run_zeta2_predict};
|
||||
use crate::syntax_retrieval_stats::retrieval_stats;
|
||||
use ::serde::Serialize;
|
||||
use ::util::paths::PathStyle;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
|
||||
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use edit_prediction_context::{
|
||||
EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
|
||||
EditPredictionScoreOptions, Line,
|
||||
EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
|
||||
};
|
||||
use futures::StreamExt as _;
|
||||
use futures::channel::mpsc;
|
||||
use gpui::{Application, AsyncApp, Entity, prelude::*};
|
||||
use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language::{Bias, Buffer, BufferSnapshot, Point};
|
||||
use project::{Project, Worktree};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde_json::json;
|
||||
use std::io::{self};
|
||||
use std::time::Duration;
|
||||
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
|
||||
use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
|
||||
use zeta2::ContextMode;
|
||||
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use crate::source_location::SourceLocation;
|
||||
|
|
@ -79,12 +74,6 @@ enum Zeta2Command {
|
|||
#[command(subcommand)]
|
||||
command: Zeta2SyntaxCommand,
|
||||
},
|
||||
Llm {
|
||||
#[clap(flatten)]
|
||||
args: Zeta2Args,
|
||||
#[command(subcommand)]
|
||||
command: Zeta2LlmCommand,
|
||||
},
|
||||
Predict(PredictArguments),
|
||||
Eval(EvaluateArguments),
|
||||
}
|
||||
|
|
@ -107,14 +96,6 @@ enum Zeta2SyntaxCommand {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Zeta2LlmCommand {
|
||||
Context {
|
||||
#[clap(flatten)]
|
||||
context_args: ContextArgs,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
#[group(requires = "worktree")]
|
||||
struct ContextArgs {
|
||||
|
|
@ -388,197 +369,6 @@ async fn zeta2_syntax_context(
|
|||
Ok(output)
|
||||
}
|
||||
|
||||
async fn zeta2_llm_context(
|
||||
zeta2_args: Zeta2Args,
|
||||
context_args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<String> {
|
||||
let LoadedContext {
|
||||
buffer,
|
||||
clipped_cursor,
|
||||
snapshot: cursor_snapshot,
|
||||
project,
|
||||
..
|
||||
} = load_context(&context_args, app_state, cx).await?;
|
||||
|
||||
let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
|
||||
|
||||
cx.update(|cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry
|
||||
.provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
|
||||
.unwrap()
|
||||
.authenticate(cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let edit_history_unified_diff = match context_args.edit_history {
|
||||
Some(events) => events.read_to_string().await?,
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let (debug_tx, mut debug_rx) = mpsc::unbounded();
|
||||
|
||||
let excerpt_options = EditPredictionExcerptOptions {
|
||||
max_bytes: zeta2_args.max_excerpt_bytes,
|
||||
min_bytes: zeta2_args.min_excerpt_bytes,
|
||||
target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
|
||||
};
|
||||
|
||||
let related_excerpts = cx
|
||||
.update(|cx| {
|
||||
zeta2::related_excerpts::find_related_excerpts(
|
||||
buffer,
|
||||
cursor_position,
|
||||
&project,
|
||||
edit_history_unified_diff,
|
||||
&LlmContextOptions {
|
||||
excerpt: excerpt_options.clone(),
|
||||
},
|
||||
Some(debug_tx),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
|
||||
clipped_cursor,
|
||||
&cursor_snapshot,
|
||||
&excerpt_options,
|
||||
None,
|
||||
)
|
||||
.context("line didn't fit")?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Output {
|
||||
excerpts: Vec<OutputExcerpt>,
|
||||
formatted_excerpts: String,
|
||||
meta: OutputMeta,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize)]
|
||||
struct OutputMeta {
|
||||
search_prompt: String,
|
||||
search_queries: Vec<SearchToolQuery>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OutputExcerpt {
|
||||
path: PathBuf,
|
||||
#[serde(flatten)]
|
||||
excerpt: Excerpt,
|
||||
}
|
||||
|
||||
let mut meta = OutputMeta::default();
|
||||
|
||||
while let Some(debug_info) = debug_rx.next().await {
|
||||
match debug_info {
|
||||
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
|
||||
meta.search_prompt = info.search_prompt;
|
||||
}
|
||||
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
|
||||
meta.search_queries = info.queries
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
let mut excerpts = Vec::new();
|
||||
let mut formatted_excerpts = String::new();
|
||||
|
||||
let cursor_insertions = [(
|
||||
predict_edits_v3::Point {
|
||||
line: Line(clipped_cursor.row),
|
||||
column: clipped_cursor.column,
|
||||
},
|
||||
CURSOR_MARKER,
|
||||
)];
|
||||
|
||||
let mut cursor_excerpt_added = false;
|
||||
|
||||
for (buffer, ranges) in related_excerpts {
|
||||
let excerpt_snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
let mut line_ranges = ranges
|
||||
.into_iter()
|
||||
.map(|range| {
|
||||
let point_range = range.to_point(&excerpt_snapshot);
|
||||
Line(point_range.start.row)..Line(point_range.end.row)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let Some(file) = excerpt_snapshot.file() else {
|
||||
continue;
|
||||
};
|
||||
let path = file.full_path(cx);
|
||||
|
||||
let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
|
||||
if is_cursor_file {
|
||||
let insertion_ix = line_ranges
|
||||
.binary_search_by(|probe| {
|
||||
probe
|
||||
.start
|
||||
.cmp(&cursor_excerpt.line_range.start)
|
||||
.then(cursor_excerpt.line_range.end.cmp(&probe.end))
|
||||
})
|
||||
.unwrap_or_else(|ix| ix);
|
||||
line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
|
||||
cursor_excerpt_added = true;
|
||||
}
|
||||
|
||||
let merged_excerpts =
|
||||
zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
|
||||
.into_iter()
|
||||
.map(|excerpt| OutputExcerpt {
|
||||
path: path.clone(),
|
||||
excerpt,
|
||||
});
|
||||
|
||||
let excerpt_start_ix = excerpts.len();
|
||||
excerpts.extend(merged_excerpts);
|
||||
|
||||
write_codeblock(
|
||||
&path,
|
||||
excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
|
||||
if is_cursor_file {
|
||||
&cursor_insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
Line(excerpt_snapshot.max_point().row),
|
||||
true,
|
||||
&mut formatted_excerpts,
|
||||
);
|
||||
}
|
||||
|
||||
if !cursor_excerpt_added {
|
||||
write_codeblock(
|
||||
&cursor_snapshot.file().unwrap().full_path(cx),
|
||||
&[Excerpt {
|
||||
start_line: cursor_excerpt.line_range.start,
|
||||
text: cursor_excerpt.text(&cursor_snapshot).body.into(),
|
||||
}],
|
||||
&cursor_insertions,
|
||||
Line(cursor_snapshot.max_point().row),
|
||||
true,
|
||||
&mut formatted_excerpts,
|
||||
);
|
||||
}
|
||||
|
||||
let output = Output {
|
||||
excerpts,
|
||||
formatted_excerpts,
|
||||
meta,
|
||||
};
|
||||
|
||||
Ok(serde_json::to_string_pretty(&output)?)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn zeta1_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
|
|
@ -670,13 +460,6 @@ fn main() {
|
|||
};
|
||||
println!("{}", result.unwrap());
|
||||
}
|
||||
Zeta2Command::Llm { args, command } => match command {
|
||||
Zeta2LlmCommand::Context { context_args } => {
|
||||
let result =
|
||||
zeta2_llm_context(args, context_args, &app_state, cx).await;
|
||||
println!("{}", result.unwrap());
|
||||
}
|
||||
},
|
||||
},
|
||||
Command::ConvertExample {
|
||||
path,
|
||||
|
|
|
|||
8
crates/zeta_cli/src/paths.rs
Normal file
8
crates/zeta_cli/src/paths.rs
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
use std::{env, path::PathBuf, sync::LazyLock};
|
||||
|
||||
static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
|
||||
pub static CACHE_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| TARGET_DIR.join("zeta-prediction-cache"));
|
||||
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
|
||||
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
|
||||
pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));
|
||||
|
|
@ -1,22 +1,20 @@
|
|||
use crate::example::{ActualExcerpt, NamedExample};
|
||||
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use crate::paths::LOGS_DIR;
|
||||
use ::serde::Serialize;
|
||||
use ::util::paths::PathStyle;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use clap::Args;
|
||||
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::AsyncApp;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use project::{Project, ProjectPath};
|
||||
use project::Project;
|
||||
use serde::Deserialize;
|
||||
use std::cell::Cell;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use util::rel_path::RelPath;
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct PredictArguments {
|
||||
|
|
@ -50,21 +48,12 @@ pub async fn zeta2_predict(
|
|||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<PredictionDetails> {
|
||||
fs::create_dir_all(&*LOGS_DIR)?;
|
||||
let worktree_path = example.setup_worktree().await?;
|
||||
|
||||
if !AUTHENTICATED.get() {
|
||||
AUTHENTICATED.set(true);
|
||||
|
||||
cx.update(|cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry
|
||||
.provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
|
||||
.unwrap()
|
||||
.authenticate(cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
app_state
|
||||
.client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
|
|
@ -83,6 +72,8 @@ pub async fn zeta2_predict(
|
|||
)
|
||||
})?;
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
|
|
@ -94,58 +85,30 @@ pub async fn zeta2_predict(
|
|||
})?
|
||||
.await;
|
||||
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await?;
|
||||
|
||||
let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
|
||||
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let cursor_offset_within_excerpt = example
|
||||
.example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))?;
|
||||
let mut cursor_excerpt = example.example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let Some((excerpt_offset, _)) = matches.next() else {
|
||||
anyhow::bail!(
|
||||
"Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
|
||||
);
|
||||
};
|
||||
assert!(matches.next().is_none());
|
||||
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
|
||||
let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
|
||||
zeta2::Zeta::try_global(cx)
|
||||
.unwrap()
|
||||
.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})?
|
||||
.detach();
|
||||
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await?;
|
||||
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
|
||||
|
||||
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
|
||||
|
||||
let refresh_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(&cursor_buffer, &project, cx);
|
||||
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
|
||||
})?;
|
||||
|
||||
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
|
||||
let mut context_retrieval_started_at = None;
|
||||
let mut context_retrieval_finished_at = None;
|
||||
let mut search_queries_generated_at = None;
|
||||
|
|
@ -159,9 +122,14 @@ pub async fn zeta2_predict(
|
|||
match event {
|
||||
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
|
||||
context_retrieval_started_at = Some(info.timestamp);
|
||||
fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
|
||||
}
|
||||
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
|
||||
search_queries_generated_at = Some(info.timestamp);
|
||||
fs::write(
|
||||
LOGS_DIR.join("search_queries.json"),
|
||||
serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
|
||||
)?;
|
||||
}
|
||||
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
|
||||
search_queries_executed_at = Some(info.timestamp);
|
||||
|
|
@ -173,11 +141,21 @@ pub async fn zeta2_predict(
|
|||
zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
|
||||
})?);
|
||||
}
|
||||
zeta2::ZetaDebugInfo::EditPredicted(request) => {
|
||||
zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
|
||||
prediction_started_at = Some(Instant::now());
|
||||
request.response_rx.await?.map_err(|err| anyhow!(err))?;
|
||||
fs::write(
|
||||
LOGS_DIR.join("prediction_prompt.md"),
|
||||
&request.local_prompt.unwrap_or_default(),
|
||||
)?;
|
||||
|
||||
let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
|
||||
prediction_finished_at = Some(Instant::now());
|
||||
|
||||
fs::write(
|
||||
LOGS_DIR.join("prediction_response.json"),
|
||||
&serde_json::to_string_pretty(&response).unwrap(),
|
||||
)?;
|
||||
|
||||
for included_file in request.request.included_files {
|
||||
let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
|
||||
result
|
||||
|
|
@ -201,7 +179,6 @@ pub async fn zeta2_predict(
|
|||
}
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue