zeta2: Build edit prediction prompt and process model output in client (#41870)

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
This commit is contained in:
Max Brunsfeld 2025-11-06 15:36:58 -08:00 committed by GitHub
parent fb87972f44
commit 784fdcaee3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 2198 additions and 2392 deletions

6
Cargo.lock generated
View file

@ -39,6 +39,7 @@ dependencies = [
"util",
"uuid",
"watch",
"zlog",
]
[[package]]
@ -3198,7 +3199,9 @@ dependencies = [
"indoc",
"ordered-float 2.10.1",
"rustc-hash 2.1.1",
"schemars 1.0.4",
"serde",
"serde_json",
"strum 0.27.2",
]
@ -21675,10 +21678,10 @@ dependencies = [
"language_model",
"log",
"lsp",
"open_ai",
"pretty_assertions",
"project",
"release_channel",
"schemars 1.0.4",
"serde",
"serde_json",
"settings",
@ -21687,6 +21690,7 @@ dependencies = [
"uuid",
"workspace",
"worktree",
"zlog",
]
[[package]]

View file

@ -56,3 +56,4 @@ rand.workspace = true
tempfile.workspace = true
util.workspace = true
settings.workspace = true
zlog.workspace = true

View file

@ -1,5 +1,4 @@
pub mod predict_edits_v3;
pub mod udiff;
use std::str::FromStr;
use std::sync::Arc;

View file

@ -1,7 +1,7 @@
use chrono::Duration;
use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
fmt::{Display, Write as _},
ops::{Add, Range, Sub},
path::{Path, PathBuf},
sync::Arc,
@ -11,7 +11,14 @@ use uuid::Uuid;
use crate::PredictEditsGitInfo;
// TODO: snippet ordering within file / relative to excerpt
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
pub excerpt: String,
pub excerpt_path: Arc<Path>,
pub excerpt_line_range: Range<Line>,
pub cursor_file_max_row: Line,
pub events: Vec<Event>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
@ -125,15 +132,15 @@ impl Display for Event {
write!(
f,
"// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
old_path.display(),
new_path.display()
DiffPathFmt(old_path),
DiffPathFmt(new_path)
)
} else {
write!(
f,
"--- a/{}\n+++ b/{}\n{diff}",
old_path.display(),
new_path.display()
DiffPathFmt(old_path),
DiffPathFmt(new_path)
)
}
}
@ -141,6 +148,24 @@ impl Display for Event {
}
}
/// always format the Path as a unix path with `/` as the path sep in Diffs
pub struct DiffPathFmt<'a>(pub &'a Path);
impl<'a> std::fmt::Display for DiffPathFmt<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut is_first = true;
for component in self.0.components() {
if !is_first {
f.write_char('/')?;
} else {
is_first = false;
}
write!(f, "{}", component.as_os_str().display())?;
}
Ok(())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
pub text: String,

View file

@ -1,294 +0,0 @@
use std::{borrow::Cow, fmt::Display};
#[derive(Debug, PartialEq)]
pub enum DiffLine<'a> {
OldPath { path: Cow<'a, str> },
NewPath { path: Cow<'a, str> },
HunkHeader(Option<HunkLocation>),
Context(&'a str),
Deletion(&'a str),
Addition(&'a str),
Garbage(&'a str),
}
#[derive(Debug, PartialEq)]
pub struct HunkLocation {
start_line_old: u32,
count_old: u32,
start_line_new: u32,
count_new: u32,
}
impl<'a> DiffLine<'a> {
pub fn parse(line: &'a str) -> Self {
Self::try_parse(line).unwrap_or(Self::Garbage(line))
}
fn try_parse(line: &'a str) -> Option<Self> {
if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
let path = parse_header_path("a/", header);
Some(Self::OldPath { path })
} else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) {
Some(Self::NewPath {
path: parse_header_path("b/", header),
})
} else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) {
if header.starts_with("...") {
return Some(Self::HunkHeader(None));
}
let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?;
let mut parts = header.split_ascii_whitespace();
let count_old = parts.next()?;
let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?;
Some(Self::HunkHeader(Some(HunkLocation {
start_line_old: start_line_old.parse::<u32>().ok()?.saturating_sub(1),
count_old: count_old.parse().ok()?,
start_line_new: start_line_new.parse::<u32>().ok()?.saturating_sub(1),
count_new: count_new.parse().ok()?,
})))
} else if let Some(deleted_header) = line.strip_prefix("-") {
Some(Self::Deletion(deleted_header))
} else if line.is_empty() {
Some(Self::Context(""))
} else if let Some(context) = line.strip_prefix(" ") {
Some(Self::Context(context))
} else {
Some(Self::Addition(line.strip_prefix("+")?))
}
}
}
impl<'a> Display for DiffLine<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DiffLine::OldPath { path } => write!(f, "--- {path}"),
DiffLine::NewPath { path } => write!(f, "+++ {path}"),
DiffLine::HunkHeader(Some(hunk_location)) => {
write!(
f,
"@@ -{},{} +{},{} @@",
hunk_location.start_line_old + 1,
hunk_location.count_old,
hunk_location.start_line_new + 1,
hunk_location.count_new
)
}
DiffLine::HunkHeader(None) => write!(f, "@@ ... @@"),
DiffLine::Context(content) => write!(f, " {content}"),
DiffLine::Deletion(content) => write!(f, "-{content}"),
DiffLine::Addition(content) => write!(f, "+{content}"),
DiffLine::Garbage(line) => write!(f, "{line}"),
}
}
}
fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
if !header.contains(['"', '\\']) {
let path = header.split_ascii_whitespace().next().unwrap_or(header);
return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path));
}
let mut path = String::with_capacity(header.len());
let mut in_quote = false;
let mut chars = header.chars().peekable();
let mut strip_prefix = Some(strip_prefix);
while let Some(char) = chars.next() {
if char == '"' {
in_quote = !in_quote;
} else if char == '\\' {
let Some(&next_char) = chars.peek() else {
break;
};
chars.next();
path.push(next_char);
} else if char.is_ascii_whitespace() && !in_quote {
break;
} else {
path.push(char);
}
if let Some(prefix) = strip_prefix
&& path == prefix
{
strip_prefix.take();
path.clear();
}
}
Cow::Owned(path)
}
fn eat_required_whitespace(header: &str) -> Option<&str> {
let trimmed = header.trim_ascii_start();
if trimmed.len() == header.len() {
None
} else {
Some(trimmed)
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
#[test]
fn parse_lines_simple() {
let input = indoc! {"
diff --git a/text.txt b/text.txt
index 86c770d..a1fd855 100644
--- a/file.txt
+++ b/file.txt
@@ -1,2 +1,3 @@
context
-deleted
+inserted
garbage
--- b/file.txt
+++ a/file.txt
"};
let lines = input.lines().map(DiffLine::parse).collect::<Vec<_>>();
pretty_assertions::assert_eq!(
lines,
&[
DiffLine::Garbage("diff --git a/text.txt b/text.txt"),
DiffLine::Garbage("index 86c770d..a1fd855 100644"),
DiffLine::OldPath {
path: "file.txt".into()
},
DiffLine::NewPath {
path: "file.txt".into()
},
DiffLine::HunkHeader(Some(HunkLocation {
start_line_old: 0,
count_old: 2,
start_line_new: 0,
count_new: 3
})),
DiffLine::Context("context"),
DiffLine::Deletion("deleted"),
DiffLine::Addition("inserted"),
DiffLine::Garbage("garbage"),
DiffLine::Context(""),
DiffLine::OldPath {
path: "b/file.txt".into()
},
DiffLine::NewPath {
path: "a/file.txt".into()
},
]
);
}
#[test]
fn file_header_extra_space() {
let options = ["--- file", "--- file", "---\tfile"];
for option in options {
pretty_assertions::assert_eq!(
DiffLine::parse(option),
DiffLine::OldPath {
path: "file".into()
},
"{option}",
);
}
}
#[test]
fn hunk_header_extra_space() {
let options = [
"@@ -1,2 +1,3 @@",
"@@ -1,2 +1,3 @@",
"@@\t-1,2\t+1,3\t@@",
"@@ -1,2 +1,3 @@",
"@@ -1,2 +1,3 @@",
"@@ -1,2 +1,3 @@",
"@@ -1,2 +1,3 @@ garbage",
];
for option in options {
pretty_assertions::assert_eq!(
DiffLine::parse(option),
DiffLine::HunkHeader(Some(HunkLocation {
start_line_old: 0,
count_old: 2,
start_line_new: 0,
count_new: 3
})),
"{option}",
);
}
}
#[test]
fn hunk_header_without_location() {
pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None));
}
#[test]
fn test_parse_path() {
assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt");
assert_eq!(
parse_header_path("a/", "foo/bar/baz.txt"),
"foo/bar/baz.txt"
);
assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt");
assert_eq!(
parse_header_path("a/", "a/foo/bar/baz.txt"),
"foo/bar/baz.txt"
);
// Extra
assert_eq!(
parse_header_path("a/", "a/foo/bar/baz.txt 2025"),
"foo/bar/baz.txt"
);
assert_eq!(
parse_header_path("a/", "a/foo/bar/baz.txt\t2025"),
"foo/bar/baz.txt"
);
assert_eq!(
parse_header_path("a/", "a/foo/bar/baz.txt \""),
"foo/bar/baz.txt"
);
// Quoted
assert_eq!(
parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""),
"foo/bar/baz quox.txt"
);
assert_eq!(
parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""),
"foo/bar/baz quox.txt"
);
assert_eq!(
parse_header_path("a/", "\"foo/bar/baz quox.txt\""),
"foo/bar/baz quox.txt"
);
assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷");
assert_eq!(
parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"),
"foo/bar/baz quox.txt"
);
// unescaped quotes are dropped
assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar");
// Escaped
assert_eq!(
parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""),
"foo/\"bar\"/baz.txt"
);
assert_eq!(
parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""),
"C:\\Projects\\My App\\old file.txt"
);
}
}

View file

@ -17,5 +17,7 @@ cloud_llm_client.workspace = true
indoc.workspace = true
ordered-float.workspace = true
rustc-hash.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
strum.workspace = true

View file

@ -1,8 +1,9 @@
//! Zeta2 prompt planning and generation code shared with cloud.
pub mod retrieval_prompt;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::predict_edits_v3::{
self, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration,
self, DiffPathFmt, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration,
};
use indoc::indoc;
use ordered_float::OrderedFloat;
@ -212,7 +213,7 @@ pub fn write_codeblock<'a>(
include_line_numbers: bool,
output: &'a mut String,
) {
writeln!(output, "`````{}", path.display()).unwrap();
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
@ -275,7 +276,7 @@ pub fn write_excerpts<'a>(
}
}
fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
pub fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
if events.is_empty() {
return;
};

View file

@ -0,0 +1,92 @@
use anyhow::Result;
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
use indoc::indoc;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, sync::LazyLock};
use crate::{push_events, write_codeblock};
pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
let mut prompt = SEARCH_INSTRUCTIONS.to_string();
if !request.events.is_empty() {
writeln!(&mut prompt, "## User Edits\n")?;
push_events(&mut prompt, &request.events);
}
writeln!(&mut prompt, "## Excerpt around the cursor\n")?;
write_codeblock(
&request.excerpt_path,
&[Excerpt {
start_line: request.excerpt_line_range.start,
text: request.excerpt.into(),
}],
&[],
request.cursor_file_max_row,
true,
&mut prompt,
);
writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
Ok(prompt)
}
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
/// An array of queries to run for gathering context relevant to the next prediction
#[schemars(length(max = 5))]
pub queries: Box<[SearchToolQuery]>,
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
/// A glob pattern to match file paths in the codebase
pub glob: String,
/// A regular expression to match content within the files matched by the glob pattern
pub regex: String,
}
pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
let schema = schemars::schema_for!(SearchToolInput);
let description = schema
.get("description")
.and_then(|description| description.as_str())
.unwrap()
.to_string();
(schema.into(), description)
});
pub const TOOL_NAME: &str = "search";
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
## Task
You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
that will serve as context for predicting the next required edit.
**Your task:**
- Analyze the user's recent edits and current cursor context
- Use the `search` tool to find code that may be relevant for predicting the next edit
- Focus on finding:
- Code patterns that might need similar changes based on the recent edits
- Functions, variables, types, and constants referenced in the current cursor context
- Related implementations, usages, or dependencies that may require consistent updates
**Important constraints:**
- This conversation has exactly 2 turns
- You must make ALL search queries in your first response via the `search` tool
- All queries will be executed in parallel and results returned together
- In the second turn, you will select the most relevant results via the `select` tool.
"#};
const TOOL_USE_REMINDER: &str = indoc! {"
--
Use the `search` tool now
"};

View file

@ -34,7 +34,7 @@ struct CurrentCompletion {
snapshot: BufferSnapshot,
/// The edits that should be applied to transform the original text into the predicted text.
/// Each edit is a range in the buffer and the text to replace it with.
edits: Arc<[(Range<Anchor>, String)]>,
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
/// Preview of how the buffer will look after applying the edits.
edit_preview: EditPreview,
}
@ -42,7 +42,7 @@ struct CurrentCompletion {
impl CurrentCompletion {
/// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
/// Returns None if the user's edits conflict with the predicted edits.
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
}
@ -281,8 +281,8 @@ impl EditPredictionProvider for CodestralCompletionProvider {
return Ok(());
}
let edits: Arc<[(Range<Anchor>, String)]> =
vec![(cursor_position..cursor_position, completion_text)].into();
let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
vec![(cursor_position..cursor_position, completion_text.into())].into();
let edit_preview = buffer
.read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
.await;

View file

@ -1,4 +1,4 @@
use std::ops::Range;
use std::{ops::Range, sync::Arc};
use client::EditPredictionUsage;
use gpui::{App, Context, Entity, SharedString};
@ -19,7 +19,7 @@ pub enum EditPrediction {
/// Edits within the buffer that requested the prediction
Local {
id: Option<SharedString>,
edits: Vec<(Range<language::Anchor>, String)>,
edits: Vec<(Range<language::Anchor>, Arc<str>)>,
edit_preview: Option<language::EditPreview>,
},
/// Jump to a different file from the one that requested the prediction
@ -248,8 +248,8 @@ where
pub fn interpolate_edits(
old_snapshot: &BufferSnapshot,
new_snapshot: &BufferSnapshot,
current_edits: &[(Range<Anchor>, String)],
) -> Option<Vec<(Range<Anchor>, String)>> {
current_edits: &[(Range<Anchor>, Arc<str>)],
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
let mut edits = Vec::new();
let mut model_edits = current_edits.iter().peekable();
@ -274,7 +274,7 @@ pub fn interpolate_edits(
if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
if !model_suffix.is_empty() {
let anchor = old_snapshot.anchor_after(user_edit.old.end);
edits.push((anchor..anchor, model_suffix.to_string()));
edits.push((anchor..anchor, model_suffix.into()));
}
model_edits.next();

View file

@ -2,7 +2,7 @@ use edit_prediction::EditPredictionProvider;
use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
use indoc::indoc;
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
use std::ops::Range;
use std::{ops::Range, sync::Arc};
use text::{Point, ToOffset};
use crate::{
@ -24,7 +24,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
assert_editor_active_edit_completion(&mut cx, |_, edits| {
assert_eq!(edits.len(), 1);
assert_eq!(edits[0].1.as_str(), "-273.15");
assert_eq!(edits[0].1.as_ref(), "-273.15");
});
accept_completion(&mut cx);
@ -46,7 +46,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) {
assert_editor_active_edit_completion(&mut cx, |_, edits| {
assert_eq!(edits.len(), 1);
assert_eq!(edits[0].1.as_str(), "3.14159");
assert_eq!(edits[0].1.as_ref(), "3.14159");
});
accept_completion(&mut cx);
@ -330,7 +330,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA
fn assert_editor_active_edit_completion(
cx: &mut EditorTestContext,
assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, String)>),
assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, Arc<str>)>),
) {
cx.editor(|editor, _, cx| {
let completion_state = editor

View file

@ -616,7 +616,7 @@ pub(crate) enum EditDisplayMode {
enum EditPrediction {
Edit {
edits: Vec<(Range<Anchor>, String)>,
edits: Vec<(Range<Anchor>, Arc<str>)>,
edit_preview: Option<EditPreview>,
display_mode: EditDisplayMode,
snapshot: BufferSnapshot,
@ -7960,7 +7960,7 @@ impl Editor {
let inlay = Inlay::edit_prediction(
post_inc(&mut self.next_inlay_id),
range.start,
new_text.as_str(),
new_text.as_ref(),
);
inlay_ids.push(inlay.id);
inlays.push(inlay);
@ -8982,7 +8982,7 @@ impl Editor {
newest_selection_head: Option<DisplayPoint>,
editor_width: Pixels,
style: &EditorStyle,
edits: &Vec<(Range<Anchor>, String)>,
edits: &Vec<(Range<Anchor>, Arc<str>)>,
edit_preview: &Option<language::EditPreview>,
snapshot: &language::BufferSnapshot,
window: &mut Window,
@ -24382,25 +24382,20 @@ impl InvalidationRegion for SnippetState {
fn edit_prediction_edit_text(
current_snapshot: &BufferSnapshot,
edits: &[(Range<Anchor>, String)],
edits: &[(Range<Anchor>, impl AsRef<str>)],
edit_preview: &EditPreview,
include_deletions: bool,
cx: &App,
) -> HighlightedText {
let edits = edits
.iter()
.map(|(anchor, text)| {
(
anchor.start.text_anchor..anchor.end.text_anchor,
text.clone(),
)
})
.map(|(anchor, text)| (anchor.start.text_anchor..anchor.end.text_anchor, text))
.collect::<Vec<_>>();
edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx)
}
fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, String)], cx: &App) -> HighlightedText {
fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, Arc<str>)], cx: &App) -> HighlightedText {
// Fallback for providers that don't provide edit_preview (like Copilot/Supermaven)
// Just show the raw edit text with basic styling
let mut text = String::new();
@ -24793,7 +24788,7 @@ impl Focusable for BreakpointPromptEditor {
}
fn all_edits_insertions_or_deletions(
edits: &Vec<(Range<Anchor>, String)>,
edits: &Vec<(Range<Anchor>, Arc<str>)>,
snapshot: &MultiBufferSnapshot,
) -> bool {
let mut all_insertions = true;

View file

@ -22915,7 +22915,7 @@ async fn assert_highlighted_edits(
let text_anchor_edits = edits
.clone()
.into_iter()
.map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit))
.map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit.into()))
.collect::<Vec<_>>();
let edit_preview = window

View file

@ -720,7 +720,7 @@ impl EditPreview {
pub fn highlight_edits(
&self,
current_snapshot: &BufferSnapshot,
edits: &[(Range<Anchor>, String)],
edits: &[(Range<Anchor>, impl AsRef<str>)],
include_deletions: bool,
cx: &App,
) -> HighlightedText {
@ -747,7 +747,8 @@ impl EditPreview {
.end
.bias_right(&self.old_snapshot)
.to_offset(&self.applied_edits_snapshot);
let edit_start_in_preview_snapshot = edit_new_end_in_preview_snapshot - edit_text.len();
let edit_start_in_preview_snapshot =
edit_new_end_in_preview_snapshot - edit_text.as_ref().len();
let unchanged_range_in_preview_snapshot =
offset_in_preview_snapshot..edit_start_in_preview_snapshot;
@ -772,7 +773,7 @@ impl EditPreview {
);
}
if !edit_text.is_empty() {
if !edit_text.as_ref().is_empty() {
highlighted_text.add_text_from_buffer_range(
edit_start_in_preview_snapshot..edit_new_end_in_preview_snapshot,
&self.applied_edits_snapshot,
@ -796,7 +797,7 @@ impl EditPreview {
highlighted_text.build()
}
fn compute_visible_range(&self, edits: &[(Range<Anchor>, String)]) -> Option<Range<usize>> {
fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<usize>> {
let (first, _) = edits.first()?;
let (last, _) = edits.last()?;
@ -1130,7 +1131,7 @@ impl Buffer {
pub fn preview_edits(
&self,
edits: Arc<[(Range<Anchor>, String)]>,
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
cx: &App,
) -> Task<EditPreview> {
let registry = self.language_registry();

View file

@ -3120,15 +3120,13 @@ async fn test_preview_edits(cx: &mut TestAppContext) {
.map(|(range, text)| {
(
buffer.anchor_before(range.start)..buffer.anchor_after(range.end),
text.to_string(),
text.into(),
)
})
.collect::<Vec<_>>()
.collect::<Arc<[_]>>()
});
let edit_preview = buffer
.read_with(cx, |buffer, cx| {
buffer.preview_edits(edits.clone().into(), cx)
})
.read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
.await;
let highlighted_edits = cx.read(|cx| {
edit_preview.highlight_edits(&buffer.read(cx).snapshot(), &edits, include_deletions, cx)

View file

@ -293,7 +293,7 @@ pub struct FunctionDefinition {
pub parameters: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
Assistant {
@ -366,25 +366,42 @@ pub struct ImageUrl {
pub detail: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
pub id: String,
#[serde(flatten)]
pub content: ToolCallContent,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
Function { function: FunctionContent },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
pub name: String,
pub arguments: String,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
pub index: u32,
pub message: RequestMessage,
pub finish_reason: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
pub role: Option<Role>,
@ -410,7 +427,7 @@ pub struct FunctionChunk {
pub arguments: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u64,
pub completion_tokens: u64,

View file

@ -7,6 +7,7 @@ use language::{Anchor, Buffer, BufferSnapshot};
use std::{
ops::{AddAssign, Range},
path::Path,
sync::Arc,
time::Duration,
};
use text::{ToOffset, ToPoint};
@ -51,7 +52,7 @@ fn completion_from_diff(
) -> EditPrediction {
let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
let mut edits: Vec<(Range<language::Anchor>, String)> = Vec::new();
let mut edits: Vec<(Range<language::Anchor>, Arc<str>)> = Vec::new();
let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
@ -70,7 +71,10 @@ fn completion_from_diff(
if k != 0 {
let offset = snapshot.anchor_after(offset);
// the range from the current position to item is an inlay.
let edit = (offset..offset, completion_graphemes[i..i + k].join(""));
let edit = (
offset..offset,
completion_graphemes[i..i + k].join("").into(),
);
edits.push(edit);
}
i += k + 1;
@ -90,7 +94,7 @@ fn completion_from_diff(
// there is leftover completion text, so drop it as an inlay.
let edit_range = offset..offset;
let edit_text = completion_graphemes[i..].join("");
edits.push((edit_range, edit_text));
edits.push((edit_range, edit_text.into()));
}
EditPrediction::Local {

View file

@ -133,7 +133,7 @@ pub struct EditPrediction {
path: Arc<Path>,
excerpt_range: Range<usize>,
cursor_offset: usize,
edits: Arc<[(Range<Anchor>, String)]>,
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
snapshot: BufferSnapshot,
edit_preview: EditPreview,
input_outline: Arc<str>,
@ -150,7 +150,7 @@ impl EditPrediction {
.duration_since(self.buffer_snapshotted_at)
}
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
}
@ -711,7 +711,7 @@ impl Zeta {
cx.spawn(async move |cx| {
let output_excerpt: Arc<str> = output_excerpt.into();
let edits: Arc<[(Range<Anchor>, String)]> = cx
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
.background_spawn({
let output_excerpt = output_excerpt.clone();
let editable_range = editable_range.clone();
@ -725,7 +725,7 @@ impl Zeta {
let edits = edits.clone();
move |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[(Range<Anchor>, String)]> =
let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
.into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
@ -759,7 +759,7 @@ impl Zeta {
output_excerpt: Arc<str>,
editable_range: Range<usize>,
snapshot: &BufferSnapshot,
) -> Result<Vec<(Range<Anchor>, String)>> {
) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
let content = output_excerpt.replace(CURSOR_MARKER, "");
let start_markers = content
@ -817,7 +817,7 @@ impl Zeta {
new_text: &str,
offset: usize,
snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, String)> {
) -> Vec<(Range<Anchor>, Arc<str>)> {
text_diff(&old_text, new_text)
.into_iter()
.map(|(mut old_range, new_text)| {
@ -836,7 +836,7 @@ impl Zeta {
);
old_range.end = old_range.end.saturating_sub(suffix_len);
let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
let range = if old_range.is_empty() {
let anchor = snapshot.anchor_after(old_range.start);
anchor..anchor
@ -1183,7 +1183,7 @@ impl CurrentEditPrediction {
if old_edits.len() == 1 && new_edits.len() == 1 {
let (old_range, old_text) = &old_edits[0];
let (new_range, new_text) = &new_edits[0];
new_range == old_range && new_text.starts_with(old_text)
new_range == old_range && new_text.starts_with(old_text.as_ref())
} else {
true
}
@ -1599,13 +1599,8 @@ mod tests {
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
to_completion_edits(
[(2..5, "REM".to_string()), (9..11, "".to_string())],
&buffer,
cx,
)
.into()
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
});
let edit_preview = cx
@ -1635,7 +1630,7 @@ mod tests {
&buffer,
cx
),
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
@ -1645,7 +1640,7 @@ mod tests {
&buffer,
cx
),
vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
vec![(2..2, "REM".into()), (6..8, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.undo(cx));
@ -1655,7 +1650,7 @@ mod tests {
&buffer,
cx
),
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
@ -1665,7 +1660,7 @@ mod tests {
&buffer,
cx
),
vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
vec![(3..3, "EM".into()), (7..9, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
@ -1675,7 +1670,7 @@ mod tests {
&buffer,
cx
),
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
@ -1685,7 +1680,7 @@ mod tests {
&buffer,
cx
),
vec![(9..11, "".to_string())]
vec![(9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
@ -1695,7 +1690,7 @@ mod tests {
&buffer,
cx
),
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
@ -1705,7 +1700,7 @@ mod tests {
&buffer,
cx
),
vec![(4..4, "M".to_string())]
vec![(4..4, "M".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
@ -2211,10 +2206,10 @@ mod tests {
}
fn to_completion_edits(
iterator: impl IntoIterator<Item = (Range<usize>, String)>,
iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
buffer: &Entity<Buffer>,
cx: &App,
) -> Vec<(Range<Anchor>, String)> {
) -> Vec<(Range<Anchor>, Arc<str>)> {
let buffer = buffer.read(cx);
iterator
.into_iter()
@ -2228,10 +2223,10 @@ mod tests {
}
fn from_completion_edits(
editor_edits: &[(Range<Anchor>, String)],
editor_edits: &[(Range<Anchor>, Arc<str>)],
buffer: &Entity<Buffer>,
cx: &App,
) -> Vec<(Range<usize>, String)> {
) -> Vec<(Range<usize>, Arc<str>)> {
let buffer = buffer.read(cx);
editor_edits
.iter()

View file

@ -28,9 +28,9 @@ indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
open_ai.workspace = true
project.workspace = true
release_channel.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
@ -50,3 +50,4 @@ language_model = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View file

@ -1,17 +1,11 @@
use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
use std::{ops::Range, sync::Arc};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3;
use gpui::{App, AsyncApp, Entity};
use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
};
use project::Project;
use util::ResultExt;
use gpui::{AsyncApp, Entity};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
use uuid::Uuid;
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
pub struct EditPredictionId(pub Uuid);
impl Into<Uuid> for EditPredictionId {
fn into(self) -> Uuid {
@ -34,8 +28,7 @@ impl std::fmt::Display for EditPredictionId {
#[derive(Clone)]
pub struct EditPrediction {
pub id: EditPredictionId,
pub path: Arc<Path>,
pub edits: Arc<[(Range<Anchor>, String)]>,
pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
@ -43,90 +36,43 @@ pub struct EditPrediction {
}
impl EditPrediction {
pub async fn from_response(
response: predict_edits_v3::PredictEditsResponse,
active_buffer_old_snapshot: &TextBufferSnapshot,
active_buffer: &Entity<Buffer>,
project: &Entity<Project>,
pub async fn new(
id: EditPredictionId,
edited_buffer: &Entity<Buffer>,
edited_buffer_snapshot: &BufferSnapshot,
edits: Vec<(Range<Anchor>, Arc<str>)>,
cx: &mut AsyncApp,
) -> Option<Self> {
// TODO only allow cloud to return one path
let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
return None;
};
let (edits, snapshot, edit_preview_task) = edited_buffer
.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[_]> =
interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
let is_same_path = active_buffer
.read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
.ok()?;
let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
active_buffer
.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
let edits: Arc<[_]> =
interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
Some((
active_buffer.clone(),
edits.clone(),
new_snapshot,
buffer.preview_edits(edits, cx),
))
})
.ok()??
} else {
let buffer_handle = project
.update(cx, |project, cx| {
let project_path = project
.find_project_path(&path, cx)
.context("Failed to find project path for zeta edit")?;
anyhow::Ok(project.open_buffer(project_path, cx))
})
.ok()?
.log_err()?
.await
.context("Failed to open buffer for zeta edit")
.log_err()?;
buffer_handle
.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot();
let edits = edits_from_response(&response.edits, &snapshot);
if edits.is_empty() {
return None;
}
Some((
buffer_handle.clone(),
edits.clone(),
snapshot,
buffer.preview_edits(edits, cx),
))
})
.ok()??
};
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
})
.ok()??;
let edit_preview = edit_preview_task.await;
Some(EditPrediction {
id: EditPredictionId(response.request_id),
path,
id,
edits,
snapshot,
edit_preview,
buffer,
buffer: edited_buffer.clone(),
})
}
pub fn interpolate(
&self,
new_snapshot: &TextBufferSnapshot,
) -> Option<Vec<(Range<Anchor>, String)>> {
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
}
pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
buffer_path_eq(buffer, &self.path, cx)
pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
self.snapshot.remote_id() == buffer.remote_id()
}
}
@ -134,21 +80,16 @@ impl std::fmt::Debug for EditPrediction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EditPrediction")
.field("id", &self.id)
.field("path", &self.path)
.field("edits", &self.edits)
.finish()
}
}
pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
}
pub fn interpolate_edits(
old_snapshot: &TextBufferSnapshot,
new_snapshot: &TextBufferSnapshot,
current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
let mut edits = Vec::new();
let mut model_edits = current_edits.iter().peekable();
@ -173,7 +114,7 @@ pub fn interpolate_edits(
if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
if !model_suffix.is_empty() {
let anchor = old_snapshot.anchor_after(user_edit.old.end);
edits.push((anchor..anchor, model_suffix.to_string()));
edits.push((anchor..anchor, model_suffix.into()));
}
model_edits.next();
@ -190,135 +131,17 @@ pub fn interpolate_edits(
if edits.is_empty() { None } else { Some(edits) }
}
pub fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> {
language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0)
}
fn edits_from_response(
edits: &[predict_edits_v3::Edit],
snapshot: &TextBufferSnapshot,
) -> Arc<[(Range<Anchor>, String)]> {
edits
.iter()
.flat_map(|edit| {
let point_range = line_range_to_point_range(edit.range.clone());
let offset = point_range.to_offset(snapshot).start;
let old_text = snapshot.text_for_range(point_range);
excerpt_edits_from_response(
old_text.collect::<Cow<str>>(),
&edit.content,
offset,
&snapshot,
)
})
.collect::<Vec<_>>()
.into()
}
fn excerpt_edits_from_response(
old_text: Cow<str>,
new_text: &str,
offset: usize,
snapshot: &TextBufferSnapshot,
) -> impl Iterator<Item = (Range<Anchor>, String)> {
text_diff(&old_text, new_text)
.into_iter()
.map(move |(mut old_range, new_text)| {
old_range.start += offset;
old_range.end += offset;
let prefix_len = common_prefix(
snapshot.chars_for_range(old_range.clone()),
new_text.chars(),
);
old_range.start += prefix_len;
let suffix_len = common_prefix(
snapshot.reversed_chars_for_range(old_range.clone()),
new_text[prefix_len..].chars().rev(),
);
old_range.end = old_range.end.saturating_sub(suffix_len);
let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
let range = if old_range.is_empty() {
let anchor = snapshot.anchor_after(old_range.start);
anchor..anchor
} else {
snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
};
(range, new_text)
})
}
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
a.zip(b)
.take_while(|(a, b)| a == b)
.map(|(a, _)| a.len_utf8())
.sum()
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::*;
use cloud_llm_client::predict_edits_v3;
use edit_prediction_context::Line;
use gpui::{App, Entity, TestAppContext, prelude::*};
use indoc::indoc;
use language::{Buffer, ToOffset as _};
#[gpui::test]
async fn test_compute_edits(cx: &mut TestAppContext) {
let old = indoc! {r#"
fn main() {
let args =
println!("{}", args[1])
}
"#};
let new = indoc! {r#"
fn main() {
let args = std::env::args();
println!("{}", args[1]);
}
"#};
let buffer = cx.new(|cx| Buffer::local(old, cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
// TODO cover more cases when multi-file is supported
let big_edits = vec![predict_edits_v3::Edit {
path: PathBuf::from("test.txt").into(),
range: Line(0)..Line(old.lines().count() as u32),
content: new.into(),
}];
let edits = edits_from_response(&big_edits, &snapshot);
assert_eq!(edits.len(), 2);
assert_eq!(
edits[0].0.to_point(&snapshot).start,
language::Point::new(1, 14)
);
assert_eq!(edits[0].1, " std::env::args();");
assert_eq!(
edits[1].0.to_point(&snapshot).start,
language::Point::new(2, 27)
);
assert_eq!(edits[1].1, ";");
}
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
to_prediction_edits(
[(2..5, "REM".to_string()), (9..11, "".to_string())],
&buffer,
cx,
)
.into()
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
});
let edit_preview = cx
@ -329,7 +152,6 @@ mod tests {
id: EditPredictionId(Uuid::new_v4()),
edits,
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
path: Path::new("test.txt").into(),
buffer: buffer.clone(),
edit_preview,
};
@ -341,7 +163,7 @@ mod tests {
&buffer,
cx
),
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
@ -351,7 +173,7 @@ mod tests {
&buffer,
cx
),
vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
vec![(2..2, "REM".into()), (6..8, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.undo(cx));
@ -361,7 +183,7 @@ mod tests {
&buffer,
cx
),
vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
@ -371,7 +193,7 @@ mod tests {
&buffer,
cx
),
vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
vec![(3..3, "EM".into()), (7..9, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
@ -381,7 +203,7 @@ mod tests {
&buffer,
cx
),
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
@ -391,7 +213,7 @@ mod tests {
&buffer,
cx
),
vec![(9..11, "".to_string())]
vec![(9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
@ -401,7 +223,7 @@ mod tests {
&buffer,
cx
),
vec![(4..4, "M".to_string()), (8..10, "".to_string())]
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
@ -411,7 +233,7 @@ mod tests {
&buffer,
cx
),
vec![(4..4, "M".to_string())]
vec![(4..4, "M".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
@ -420,10 +242,10 @@ mod tests {
}
fn to_prediction_edits(
iterator: impl IntoIterator<Item = (Range<usize>, String)>,
iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
buffer: &Entity<Buffer>,
cx: &App,
) -> Vec<(Range<Anchor>, String)> {
) -> Vec<(Range<Anchor>, Arc<str>)> {
let buffer = buffer.read(cx);
iterator
.into_iter()
@ -437,10 +259,10 @@ mod tests {
}
fn from_prediction_edits(
editor_edits: &[(Range<Anchor>, String)],
editor_edits: &[(Range<Anchor>, Arc<str>)],
buffer: &Entity<Buffer>,
cx: &App,
) -> Vec<(Range<usize>, String)> {
) -> Vec<(Range<usize>, Arc<str>)> {
let buffer = buffer.read(cx);
editor_edits
.iter()

View file

@ -1,717 +0,0 @@
use std::{
cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
};
use crate::{
ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
};
use anyhow::{Result, anyhow};
use cloud_zeta2_prompt::write_codeblock;
use collections::HashMap;
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
use futures::{
StreamExt,
channel::mpsc::{self, UnboundedSender},
stream::BoxStream,
};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::indoc;
use language::{
Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUse, MessageContent, Role,
};
use project::{
Project, WorktreeSettings,
search::{SearchQuery, SearchResult},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::{
ResultExt as _,
paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;
const SEARCH_PROMPT: &str = indoc! {r#"
## Task
You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
that will serve as context for predicting the next required edit.
**Your task:**
- Analyze the user's recent edits and current cursor context
- Use the `search` tool to find code that may be relevant for predicting the next edit
- Focus on finding:
- Code patterns that might need similar changes based on the recent edits
- Functions, variables, types, and constants referenced in the current cursor context
- Related implementations, usages, or dependencies that may require consistent updates
**Important constraints:**
- This conversation has exactly 2 turns
- You must make ALL search queries in your first response via the `search` tool
- All queries will be executed in parallel and results returned together
- In the second turn, you will select the most relevant results via the `select` tool.
## User Edits
{edits}
## Current cursor context
`````{current_file_path}
{cursor_excerpt}
`````
--
Use the `search` tool now
"#};
const SEARCH_TOOL_NAME: &str = "search";
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
/// An array of queries to run for gathering context relevant to the next prediction
#[schemars(length(max = 5))]
pub queries: Box<[SearchToolQuery]>,
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
/// A glob pattern to match file paths in the codebase
pub glob: String,
/// A regular expression to match content within the files matched by the glob pattern
pub regex: String,
}
const RESULTS_MESSAGE: &str = indoc! {"
Here are the results of your queries combined and grouped by file:
"};
const SELECT_TOOL_NAME: &str = "select";
const SELECT_PROMPT: &str = indoc! {"
Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
Include up to 200 lines in total.
"};
/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
/// The line ranges to select from search results.
ranges: Vec<SelectLineRange>,
}
/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
/// The file path containing the lines to select
/// Exactly as it appears in the search result codeblocks.
path: PathBuf,
/// The starting line number (1-based)
#[schemars(range(min = 1))]
start_line: u32,
/// The ending line number (1-based, inclusive)
#[schemars(range(min = 1))]
end_line: u32,
}
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
pub excerpt: EditPredictionExcerptOptions,
}
pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
pub fn find_related_excerpts(
buffer: Entity<language::Buffer>,
cursor_position: Anchor,
project: &Entity<Project>,
mut edit_history_unified_diff: String,
options: &LlmContextOptions,
debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
cx: &App,
) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
let language_model_registry = LanguageModelRegistry::global(cx);
let Some(model) = language_model_registry
.read(cx)
.available_models(cx)
.find(|model| {
model.provider_id() == MODEL_PROVIDER_ID
&& model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
// model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
// model.provider_id() == LanguageModelProviderId::new("ollama")
// && model.id() == LanguageModelId("gpt-oss:20b".into())
})
else {
return Task::ready(Err(anyhow!("could not find context model")));
};
if edit_history_unified_diff.is_empty() {
edit_history_unified_diff.push_str("(No user edits yet)");
}
// TODO [zeta2] include breadcrumbs?
let snapshot = buffer.read(cx).snapshot();
let cursor_point = cursor_position.to_point(&snapshot);
let Some(cursor_excerpt) =
EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
else {
return Task::ready(Ok(HashMap::default()));
};
let current_file_path = snapshot
.file()
.map(|f| f.full_path(cx).display().to_string())
.unwrap_or_else(|| "untitled".to_string());
let prompt = SEARCH_PROMPT
.replace("{edits}", &edit_history_unified_diff)
.replace("{current_file_path}", &current_file_path)
.replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
ZetaContextRetrievalStartedDebugInfo {
project: project.clone(),
timestamp: Instant::now(),
search_prompt: prompt.clone(),
},
))
.ok();
}
let path_style = project.read(cx).path_style(cx);
let exclude_matcher = {
let global_settings = WorktreeSettings::get_global(cx);
let exclude_patterns = global_settings
.file_scan_exclusions
.sources()
.iter()
.chain(global_settings.private_files.sources().iter());
match PathMatcher::new(exclude_patterns, path_style) {
Ok(matcher) => matcher,
Err(err) => {
return Task::ready(Err(anyhow!(err)));
}
}
};
let project = project.clone();
cx.spawn(async move |cx| {
let initial_prompt_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
};
let mut search_stream = request_tool_call::<SearchToolInput>(
vec![initial_prompt_message.clone()],
SEARCH_TOOL_NAME,
&model,
cx,
)
.await?;
let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
select_request_messages.push(initial_prompt_message);
let mut regex_by_glob: HashMap<String, String> = HashMap::default();
let mut search_calls = Vec::new();
while let Some(event) = search_stream.next().await {
match event? {
LanguageModelCompletionEvent::ToolUse(tool_use) => {
if !tool_use.is_input_complete {
continue;
}
if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
let input =
serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
for query in input.queries {
let regex = regex_by_glob.entry(query.glob).or_default();
if !regex.is_empty() {
regex.push('|');
}
regex.push_str(&query.regex);
}
search_calls.push(tool_use);
} else {
log::warn!(
"context gathering model tried to use unknown tool: {}",
tool_use.name
);
}
}
LanguageModelCompletionEvent::Text(txt) => {
if let Some(LanguageModelRequestMessage {
role: Role::Assistant,
content,
..
}) = select_request_messages.last_mut()
{
if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
existing_text.push_str(&txt);
} else {
content.push(MessageContent::Text(txt));
}
} else {
select_request_messages.push(LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::Text(txt)],
cache: false,
});
}
}
LanguageModelCompletionEvent::Thinking { text, signature } => {
if let Some(LanguageModelRequestMessage {
role: Role::Assistant,
content,
..
}) = select_request_messages.last_mut()
{
if let Some(MessageContent::Thinking {
text: existing_text,
signature: existing_signature,
}) = content.last_mut()
{
existing_text.push_str(&text);
*existing_signature = signature;
} else {
content.push(MessageContent::Thinking { text, signature });
}
} else {
select_request_messages.push(LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::Thinking { text, signature }],
cache: false,
});
}
}
LanguageModelCompletionEvent::RedactedThinking { data } => {
if let Some(LanguageModelRequestMessage {
role: Role::Assistant,
content,
..
}) = select_request_messages.last_mut()
{
if let Some(MessageContent::RedactedThinking(existing_data)) =
content.last_mut()
{
existing_data.push_str(&data);
} else {
content.push(MessageContent::RedactedThinking(data));
}
} else {
select_request_messages.push(LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::RedactedThinking(data)],
cache: false,
});
}
}
ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
log::error!("{ev:?}");
}
ev => {
log::trace!("context search event: {ev:?}")
}
}
}
let search_tool_use = if search_calls.is_empty() {
log::warn!("context model ran 0 searches");
return anyhow::Ok(Default::default());
} else if search_calls.len() == 1 {
search_calls.swap_remove(0)
} else {
// In theory, the model could perform multiple search calls
// Dealing with them separately is not worth it when it doesn't happen in practice.
// If it were to happen, here we would combine them into one.
// The second request doesn't need to know it was actually two different calls ;)
let input = serde_json::to_value(&SearchToolInput {
queries: regex_by_glob
.iter()
.map(|(glob, regex)| SearchToolQuery {
glob: glob.clone(),
regex: regex.clone(),
})
.collect(),
})
.unwrap_or_default();
LanguageModelToolUse {
id: search_calls.swap_remove(0).id,
name: SELECT_TOOL_NAME.into(),
raw_input: serde_json::to_string(&input).unwrap_or_default(),
input,
is_input_complete: true,
}
};
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
ZetaSearchQueryDebugInfo {
project: project.clone(),
timestamp: Instant::now(),
queries: regex_by_glob
.iter()
.map(|(glob, regex)| SearchToolQuery {
glob: glob.clone(),
regex: regex.clone(),
})
.collect(),
},
))
.ok();
}
let (results_tx, mut results_rx) = mpsc::unbounded();
for (glob, regex) in regex_by_glob {
let exclude_matcher = exclude_matcher.clone();
let results_tx = results_tx.clone();
let project = project.clone();
cx.spawn(async move |cx| {
run_query(
&glob,
&regex,
results_tx.clone(),
path_style,
exclude_matcher,
&project,
cx,
)
.await
.log_err();
})
.detach()
}
drop(results_tx);
struct ResultBuffer {
buffer: Entity<Buffer>,
snapshot: TextBufferSnapshot,
}
let (result_buffers_by_path, merged_result) = cx
.background_spawn(async move {
let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
HashMap::default();
while let Some((buffer, matched)) = results_rx.next().await {
match excerpts_by_buffer.entry(buffer) {
Entry::Occupied(mut entry) => {
let entry = entry.get_mut();
entry.full_path = matched.full_path;
entry.snapshot = matched.snapshot;
entry.line_ranges.extend(matched.line_ranges);
}
Entry::Vacant(entry) => {
entry.insert(matched);
}
}
}
let mut result_buffers_by_path = HashMap::default();
let mut merged_result = RESULTS_MESSAGE.to_string();
for (buffer, mut matched) in excerpts_by_buffer {
matched
.line_ranges
.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
write_codeblock(
&matched.full_path,
merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
&[],
Line(matched.snapshot.max_point().row),
true,
&mut merged_result,
);
result_buffers_by_path.insert(
matched.full_path,
ResultBuffer {
buffer,
snapshot: matched.snapshot.text,
},
);
}
(result_buffers_by_path, merged_result)
})
.await;
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
ZetaContextRetrievalDebugInfo {
project: project.clone(),
timestamp: Instant::now(),
},
))
.ok();
}
let tool_result = LanguageModelToolResult {
tool_use_id: search_tool_use.id.clone(),
tool_name: SEARCH_TOOL_NAME.into(),
is_error: false,
content: merged_result.into(),
output: None,
};
select_request_messages.extend([
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(search_tool_use)],
cache: false,
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: false,
},
]);
if result_buffers_by_path.is_empty() {
log::trace!("context gathering queries produced no results");
return anyhow::Ok(HashMap::default());
}
select_request_messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![SELECT_PROMPT.into()],
cache: false,
});
let mut select_stream = request_tool_call::<SelectToolInput>(
select_request_messages,
SELECT_TOOL_NAME,
&model,
cx,
)
.await?;
cx.background_spawn(async move {
let mut selected_ranges = Vec::new();
while let Some(event) = select_stream.next().await {
match event? {
LanguageModelCompletionEvent::ToolUse(tool_use) => {
if !tool_use.is_input_complete {
continue;
}
if tool_use.name.as_ref() == SELECT_TOOL_NAME {
let call =
serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
selected_ranges.extend(call.ranges);
} else {
log::warn!(
"context gathering model tried to use unknown tool: {}",
tool_use.name
);
}
}
ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
log::error!("{ev:?}");
}
ev => {
log::trace!("context select event: {ev:?}")
}
}
}
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
ZetaContextRetrievalDebugInfo {
project: project.clone(),
timestamp: Instant::now(),
},
))
.ok();
}
if selected_ranges.is_empty() {
log::trace!("context gathering selected no ranges")
}
selected_ranges.sort_unstable_by(|a, b| {
a.start_line
.cmp(&b.start_line)
.then(b.end_line.cmp(&a.end_line))
});
let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
for selected_range in selected_ranges {
if let Some(ResultBuffer { buffer, snapshot }) =
result_buffers_by_path.get(&selected_range.path)
{
let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
let end_point =
snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
let range =
snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
related_excerpts_by_buffer
.entry(buffer.clone())
.or_default()
.push(range);
} else {
log::warn!(
"selected path that wasn't included in search results: {}",
selected_range.path.display()
);
}
}
anyhow::Ok(related_excerpts_by_buffer)
})
.await
})
}
async fn request_tool_call<T: JsonSchema>(
messages: Vec<LanguageModelRequestMessage>,
tool_name: &'static str,
model: &Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
{
let schema = schemars::schema_for!(T);
let request = LanguageModelRequest {
messages,
tools: vec![LanguageModelRequestTool {
name: tool_name.into(),
description: schema
.get("description")
.and_then(|description| description.as_str())
.unwrap()
.to_string(),
input_schema: serde_json::to_value(schema).unwrap(),
}],
..Default::default()
};
Ok(model.stream_completion(request, cx).await?)
}
const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
struct MatchedBuffer {
snapshot: BufferSnapshot,
line_ranges: Vec<Range<Line>>,
full_path: PathBuf,
}
async fn run_query(
glob: &str,
regex: &str,
results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
path_style: PathStyle,
exclude_matcher: PathMatcher,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<()> {
let include_matcher = PathMatcher::new(vec![glob], path_style)?;
let query = SearchQuery::regex(
regex,
false,
true,
false,
true,
include_matcher,
exclude_matcher,
true,
None,
)?;
let results = project.update(cx, |project, cx| project.search(query, cx))?;
futures::pin_mut!(results);
let mut total_bytes = 0;
while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
if ranges.is_empty() {
continue;
}
let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
})?
else {
continue;
};
let results_tx = results_tx.clone();
cx.background_spawn(async move {
let mut line_ranges = Vec::with_capacity(ranges.len());
for range in ranges {
let offset_range = range.to_offset(&snapshot);
let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
break;
}
let excerpt = EditPredictionExcerpt::select_from_buffer(
query_point,
&snapshot,
&EditPredictionExcerptOptions {
max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
min_bytes: MIN_EXCERPT_LEN,
target_before_cursor_over_total_bytes: 0.5,
},
None,
);
if let Some(excerpt) = excerpt {
total_bytes += excerpt.range.len();
if !excerpt.line_range.is_empty() {
line_ranges.push(excerpt.line_range);
}
}
}
results_tx
.unbounded_send((
buffer,
MatchedBuffer {
snapshot,
line_ranges,
full_path,
},
))
.log_err();
})
.detach();
}
anyhow::Ok(())
}

View file

@ -0,0 +1,194 @@
use std::ops::Range;
use anyhow::Result;
use collections::HashMap;
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
use futures::{
StreamExt,
channel::mpsc::{self, UnboundedSender},
};
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _};
use project::{
Project, WorktreeSettings,
search::{SearchQuery, SearchResult},
};
use util::{
ResultExt as _,
paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;
pub async fn run_retrieval_searches(
project: Entity<Project>,
regex_by_glob: HashMap<String, String>,
cx: &mut AsyncApp,
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
let global_settings = WorktreeSettings::get_global(cx);
let exclude_patterns = global_settings
.file_scan_exclusions
.sources()
.iter()
.chain(global_settings.private_files.sources().iter());
let path_style = project.path_style(cx);
anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
})??;
let (results_tx, mut results_rx) = mpsc::unbounded();
for (glob, regex) in regex_by_glob {
let exclude_matcher = exclude_matcher.clone();
let results_tx = results_tx.clone();
let project = project.clone();
cx.spawn(async move |cx| {
run_query(
&glob,
&regex,
results_tx.clone(),
path_style,
exclude_matcher,
&project,
cx,
)
.await
.log_err();
})
.detach()
}
drop(results_tx);
cx.background_spawn(async move {
let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
let mut snapshots = HashMap::default();
let mut total_bytes = 0;
'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
snapshots.insert(buffer.entity_id(), snapshot);
let existing = results.entry(buffer).or_default();
existing.reserve(excerpts.len());
for (range, size) in excerpts {
// Blunt trimming of the results until we have a proper algorithmic filtering step
if (total_bytes + size) > MAX_RESULTS_LEN {
log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
break 'outer;
}
total_bytes += size;
existing.push(range);
}
}
for (buffer, ranges) in results.iter_mut() {
if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
ranges.sort_unstable_by(|a, b| {
a.start
.cmp(&b.start, snapshot)
.then(b.end.cmp(&b.end, snapshot))
});
let mut index = 1;
while index < ranges.len() {
if ranges[index - 1]
.end
.cmp(&ranges[index].start, snapshot)
.is_gt()
{
let removed = ranges.remove(index);
ranges[index - 1].end = removed.end;
} else {
index += 1;
}
}
}
}
Ok(results)
})
.await
}
const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
async fn run_query(
glob: &str,
regex: &str,
results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
path_style: PathStyle,
exclude_matcher: PathMatcher,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<()> {
let include_matcher = PathMatcher::new(vec![glob], path_style)?;
let query = SearchQuery::regex(
regex,
false,
true,
false,
true,
include_matcher,
exclude_matcher,
true,
None,
)?;
let results = project.update(cx, |project, cx| project.search(query, cx))?;
futures::pin_mut!(results);
while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
if results_tx.is_closed() {
break;
}
if ranges.is_empty() {
continue;
}
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let results_tx = results_tx.clone();
cx.background_spawn(async move {
let mut excerpts = Vec::with_capacity(ranges.len());
for range in ranges {
let offset_range = range.to_offset(&snapshot);
let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
let excerpt = EditPredictionExcerpt::select_from_buffer(
query_point,
&snapshot,
&EditPredictionExcerptOptions {
max_bytes: MAX_EXCERPT_LEN,
min_bytes: MIN_EXCERPT_LEN,
target_before_cursor_over_total_bytes: 0.5,
},
None,
);
if let Some(excerpt) = excerpt
&& !excerpt.line_range.is_empty()
{
excerpts.push((
snapshot.anchor_after(excerpt.range.start)
..snapshot.anchor_before(excerpt.range.end),
excerpt.range.len(),
));
}
}
let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts));
if let Err(err) = send_result
&& !err.is_disconnected()
{
log::error!("{err}");
}
})
.detach();
}
anyhow::Ok(())
}

1024
crates/zeta2/src/udiff.rs Normal file

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -27,10 +27,11 @@ log.workspace = true
multi_buffer.workspace = true
ordered-float.workspace = true
project.workspace = true
regex-syntax = "0.8.8"
serde.workspace = true
serde_json.workspace = true
telemetry.workspace = true
text.workspace = true
regex-syntax = "0.8.8"
ui.workspace = true
ui_input.workspace = true
util.workspace = true

View file

@ -45,7 +45,6 @@ struct RetrievalRun {
started_at: Instant,
search_results_generated_at: Option<Instant>,
search_results_executed_at: Option<Instant>,
search_results_filtered_at: Option<Instant>,
finished_at: Option<Instant>,
}
@ -117,17 +116,12 @@ impl Zeta2ContextView {
self.handle_search_queries_executed(info, window, cx);
}
}
ZetaDebugInfo::SearchResultsFiltered(info) => {
if info.project == self.project {
self.handle_search_results_filtered(info, window, cx);
}
}
ZetaDebugInfo::ContextRetrievalFinished(info) => {
if info.project == self.project {
self.handle_context_retrieval_finished(info, window, cx);
}
}
ZetaDebugInfo::EditPredicted(_) => {}
ZetaDebugInfo::EditPredictionRequested(_) => {}
}
}
@ -159,7 +153,6 @@ impl Zeta2ContextView {
started_at: info.timestamp,
search_results_generated_at: None,
search_results_executed_at: None,
search_results_filtered_at: None,
finished_at: None,
});
@ -218,18 +211,18 @@ impl Zeta2ContextView {
run.search_results_generated_at = Some(info.timestamp);
run.search_queries = info
.queries
.regex_by_glob
.into_iter()
.map(|query| {
.map(|(glob, regex)| {
let mut regex_parser = regex_syntax::ast::parse::Parser::new();
GlobQueries {
glob: query.glob,
alternations: match regex_parser.parse(&query.regex) {
glob,
alternations: match regex_parser.parse(&regex) {
Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => {
alt.asts.iter().map(|ast| ast.to_string()).collect()
}
_ => vec![query.regex],
_ => vec![regex],
},
}
})
@ -256,20 +249,6 @@ impl Zeta2ContextView {
cx.notify();
}
fn handle_search_results_filtered(
&mut self,
info: ZetaContextRetrievalDebugInfo,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(run) = self.runs.back_mut() else {
return;
};
run.search_results_filtered_at = Some(info.timestamp);
cx.notify();
}
fn handle_go_back(
&mut self,
_: &Zeta2ContextGoBack,
@ -398,19 +377,10 @@ impl Zeta2ContextView {
};
div = div.child(format!("Ran search: {:>5} ms", (t2 - t1).as_millis()));
let Some(t3) = run.search_results_filtered_at else {
return pending_message(div, "Filtering results...");
};
div =
div.child(format!("Filtered results: {:>5} ms", (t3 - t2).as_millis()));
let Some(t4) = run.finished_at else {
return pending_message(div, "Building excerpts");
};
div = div
.child(format!("Build excerpts: {:>5} µs", (t4 - t3).as_micros()))
.child(format!("Total: {:>5} ms", (t4 - t0).as_millis()));
div
div.child(format!(
"Total: {:>5} ms",
(run.finished_at.unwrap_or(t0) - t0).as_millis()
))
}),
)
}

View file

@ -5,7 +5,7 @@ use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
use chrono::TimeDelta;
use client::{Client, UserStore};
use cloud_llm_client::predict_edits_v3::{
self, DeclarationScoreComponents, PredictEditsRequest, PredictEditsResponse, PromptFormat,
DeclarationScoreComponents, PredictEditsRequest, PromptFormat,
};
use collections::HashMap;
use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
@ -23,7 +23,7 @@ use ui_input::InputField;
use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
use workspace::{Item, SplitDirection, Workspace};
use zeta2::{
ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, LlmContextOptions, Zeta, Zeta2FeatureFlag,
AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, Zeta, Zeta2FeatureFlag,
ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
};
@ -123,6 +123,7 @@ struct LastPrediction {
context_editor: Entity<Editor>,
prompt_editor: Entity<Editor>,
retrieval_time: TimeDelta,
request_time: Option<TimeDelta>,
buffer: WeakEntity<Buffer>,
position: language::Anchor,
state: LastPredictionState,
@ -143,7 +144,7 @@ enum LastPredictionState {
model_response_editor: Entity<Editor>,
feedback_editor: Entity<Editor>,
feedback: Option<Feedback>,
response: predict_edits_v3::PredictEditsResponse,
request_id: String,
},
Failed {
message: String,
@ -217,7 +218,7 @@ impl Zeta2Inspector {
});
match &options.context {
ContextMode::Llm(_) => {
ContextMode::Agentic(_) => {
self.context_mode = ContextModeState::Llm;
}
ContextMode::Syntax(_) => {
@ -307,9 +308,11 @@ impl Zeta2Inspector {
};
let context = match zeta_options.context {
ContextMode::Llm(_context_options) => ContextMode::Llm(LlmContextOptions {
excerpt: excerpt_options,
}),
ContextMode::Agentic(_context_options) => {
ContextMode::Agentic(AgenticContextOptions {
excerpt: excerpt_options,
})
}
ContextMode::Syntax(context_options) => {
let max_retrieved_declarations = match &this.context_mode {
ContextModeState::Llm => {
@ -368,7 +371,7 @@ impl Zeta2Inspector {
let language_registry = self.project.read(cx).languages().clone();
async move |this, cx| {
let mut languages = HashMap::default();
let ZetaDebugInfo::EditPredicted(prediction) = prediction else {
let ZetaDebugInfo::EditPredictionRequested(prediction) = prediction else {
return;
};
for ext in prediction
@ -396,6 +399,8 @@ impl Zeta2Inspector {
.await
.log_err();
let json_language = language_registry.language_for_name("Json").await.log_err();
this.update_in(cx, |this, window, cx| {
let context_editor = cx.new(|cx| {
let mut excerpt_score_components = HashMap::default();
@ -492,25 +497,15 @@ impl Zeta2Inspector {
let task = cx.spawn_in(window, {
let markdown_language = markdown_language.clone();
let json_language = json_language.clone();
async move |this, cx| {
let response = response_rx.await;
this.update_in(cx, |this, window, cx| {
if let Some(prediction) = this.last_prediction.as_mut() {
prediction.state = match response {
Ok(Ok(response)) => {
if let Some(debug_info) = &response.debug_info {
prediction.prompt_editor.update(
cx,
|prompt_editor, cx| {
prompt_editor.set_text(
debug_info.prompt.as_str(),
window,
cx,
);
},
);
}
Ok((Ok(response), request_time)) => {
prediction.request_time = Some(request_time);
let feedback_editor = cx.new(|cx| {
let buffer = cx.new(|cx| {
@ -577,16 +572,11 @@ impl Zeta2Inspector {
model_response_editor: cx.new(|cx| {
let buffer = cx.new(|cx| {
let mut buffer = Buffer::local(
response
.debug_info
.as_ref()
.map(|p| p.model_response.as_str())
.unwrap_or(
"(Debug info not available)",
),
serde_json::to_string_pretty(&response)
.unwrap_or_default(),
cx,
);
buffer.set_language(markdown_language, cx);
buffer.set_language(json_language, cx);
buffer
});
let buffer = cx.new(|cx| {
@ -607,10 +597,11 @@ impl Zeta2Inspector {
}),
feedback_editor,
feedback: None,
response,
request_id: response.id.clone(),
}
}
Ok(Err(err)) => {
Ok((Err(err), request_time)) => {
prediction.request_time = Some(request_time);
LastPredictionState::Failed { message: err }
}
Err(oneshot::Canceled) => LastPredictionState::Failed {
@ -644,6 +635,7 @@ impl Zeta2Inspector {
editor
}),
retrieval_time,
request_time: None,
buffer,
position,
state: LastPredictionState::Requested,
@ -700,7 +692,7 @@ impl Zeta2Inspector {
feedback: feedback_state,
feedback_editor,
model_response_editor,
response,
request_id,
..
} = &mut last_prediction.state
else {
@ -734,11 +726,10 @@ impl Zeta2Inspector {
telemetry::event!(
"Zeta2 Prediction Rated",
id = response.request_id,
id = request_id,
kind = kind,
text = text,
request = last_prediction.request,
response = response,
project_snapshot = project_snapshot,
);
})
@ -834,11 +825,11 @@ impl Zeta2Inspector {
let current_options =
this.zeta.read(cx).options().clone();
match current_options.context.clone() {
ContextMode::Llm(_) => {}
ContextMode::Agentic(_) => {}
ContextMode::Syntax(context_options) => {
let options = ZetaOptions {
context: ContextMode::Llm(
LlmContextOptions {
context: ContextMode::Agentic(
AgenticContextOptions {
excerpt: context_options.excerpt,
},
),
@ -865,7 +856,7 @@ impl Zeta2Inspector {
let current_options =
this.zeta.read(cx).options().clone();
match current_options.context.clone() {
ContextMode::Llm(context_options) => {
ContextMode::Agentic(context_options) => {
let options = ZetaOptions {
context: ContextMode::Syntax(
EditPredictionContextOptions {
@ -976,25 +967,6 @@ impl Zeta2Inspector {
return None;
};
let (prompt_planning_time, inference_time, parsing_time) =
if let LastPredictionState::Success {
response:
PredictEditsResponse {
debug_info: Some(debug_info),
..
},
..
} = &prediction.state
{
(
Some(debug_info.prompt_planning_time),
Some(debug_info.inference_time),
Some(debug_info.parsing_time),
)
} else {
(None, None, None)
};
Some(
v_flex()
.p_4()
@ -1005,12 +977,7 @@ impl Zeta2Inspector {
"Context retrieval",
Some(prediction.retrieval_time),
))
.child(Self::render_duration(
"Prompt planning",
prompt_planning_time,
))
.child(Self::render_duration("Inference", inference_time))
.child(Self::render_duration("Parsing", parsing_time)),
.child(Self::render_duration("Request", prediction.request_time)),
)
}

View file

@ -7,13 +7,14 @@ use std::{
use anyhow::Result;
use clap::Args;
use cloud_llm_client::udiff::DiffLine;
use collections::HashSet;
use gpui::AsyncApp;
use zeta2::udiff::DiffLine;
use crate::{
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::CACHE_DIR,
predict::{PredictionDetails, zeta2_predict},
};
@ -54,10 +55,8 @@ pub async fn run_evaluate_one(
app_state: Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
.join("../../target/zeta-prediction-cache");
let example = NamedExample::load(&example_path).unwrap();
let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
let predictions = if !re_run && example_cache_path.exists() {
let file_contents = fs::read_to_string(&example_cache_path)?;
@ -74,7 +73,7 @@ pub async fn run_evaluate_one(
};
if !example_cache_path.exists() {
fs::create_dir_all(&cache_dir).unwrap();
fs::create_dir_all(&*CACHE_DIR).unwrap();
fs::write(
example_cache_path,
serde_json::to_string(&predictions).unwrap(),

View file

@ -1,28 +1,31 @@
use std::{
borrow::Cow,
cell::RefCell,
env,
fmt::{self, Display},
fs,
io::Write,
mem,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use collections::{HashMap, HashSet};
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use gpui::{AsyncApp, Entity, http_client::Url};
use language::Buffer;
use language::{Anchor, Buffer};
use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
use zeta2::udiff::OpenedBuffers;
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
@ -215,12 +218,10 @@ impl NamedExample {
let (repo_owner, repo_name) = self.repo_name()?;
let file_name = self.file_name();
let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
fs::create_dir_all(&repos_dir)?;
fs::create_dir_all(&worktrees_dir)?;
fs::create_dir_all(&*REPOS_DIR)?;
fs::create_dir_all(&*WORKTREES_DIR)?;
let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
@ -251,7 +252,7 @@ impl NamedExample {
};
// Create the worktree for this example if needed.
let worktree_path = worktrees_dir.join(&file_name);
let worktree_path = WORKTREES_DIR.join(&file_name);
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
@ -309,7 +310,6 @@ impl NamedExample {
.collect()
}
#[allow(unused)]
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
// git@github.com:owner/repo.git
if self.example.repository_url.contains('@') {
@ -344,13 +344,63 @@ impl NamedExample {
}
}
pub async fn cursor_position(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<(Entity<Buffer>, Anchor)> {
let worktree = project.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})?;
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})?
.await?;
let cursor_offset_within_excerpt = self
.example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))?;
let mut cursor_excerpt = self.example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let Some((excerpt_offset, _)) = matches.next() else {
anyhow::bail!(
"Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
);
};
assert!(matches.next().is_none());
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
Ok((cursor_buffer, cursor_anchor))
}
#[must_use]
pub async fn apply_edit_history(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<HashSet<Entity<Buffer>>> {
apply_diff(&self.example.edit_history, project, cx).await
) -> Result<OpenedBuffers<'_>> {
zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
@ -446,404 +496,3 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
.lock_owned()
.await
}
#[must_use]
pub async fn apply_diff(
diff: &str,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<HashSet<Entity<Buffer>>> {
use cloud_llm_client::udiff::DiffLine;
use std::fmt::Write;
#[derive(Debug, Default)]
struct HunkState {
context: String,
edits: Vec<Edit>,
}
#[derive(Debug)]
struct Edit {
range: Range<usize>,
text: String,
}
let mut old_path = None;
let mut new_path = None;
let mut hunk = HunkState::default();
let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
let mut open_buffers = HashSet::default();
while let Some(diff_line) = diff_lines.next() {
match diff_line {
DiffLine::OldPath { path } => old_path = Some(path),
DiffLine::NewPath { path } => {
if old_path.is_none() {
anyhow::bail!(
"Found a new path header (`+++`) before an (`---`) old path header"
);
}
new_path = Some(path)
}
DiffLine::Context(ctx) => {
writeln!(&mut hunk.context, "{ctx}")?;
}
DiffLine::Deletion(del) => {
let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8();
if let Some(last_edit) = hunk.edits.last_mut()
&& last_edit.range.end == range.start
{
last_edit.range.end = range.end;
} else {
hunk.edits.push(Edit {
range,
text: String::new(),
});
}
writeln!(&mut hunk.context, "{del}")?;
}
DiffLine::Addition(add) => {
let range = hunk.context.len()..hunk.context.len();
if let Some(last_edit) = hunk.edits.last_mut()
&& last_edit.range.end == range.start
{
writeln!(&mut last_edit.text, "{add}").unwrap();
} else {
hunk.edits.push(Edit {
range,
text: format!("{add}\n"),
});
}
}
DiffLine::HunkHeader(_) | DiffLine::Garbage(_) => {}
}
let at_hunk_end = match diff_lines.peek() {
Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true,
_ => false,
};
if at_hunk_end {
let hunk = mem::take(&mut hunk);
let Some(old_path) = old_path.as_deref() else {
anyhow::bail!("Missing old path (`---`) header")
};
let Some(new_path) = new_path.as_deref() else {
anyhow::bail!("Missing new path (`+++`) header")
};
let buffer = project
.update(cx, |project, cx| {
let project_path = project
.find_project_path(old_path, cx)
.context("Failed to find old_path in project")?;
anyhow::Ok(project.open_buffer(project_path, cx))
})??
.await?;
open_buffers.insert(buffer.clone());
if old_path != new_path {
project
.update(cx, |project, cx| {
let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
let new_path = ProjectPath {
worktree_id: project_file.worktree_id(cx),
path: project_file.path.clone(),
};
project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
})?
.await?;
}
// TODO is it worth using project search?
buffer.update(cx, |buffer, cx| {
let context_offset = if hunk.context.is_empty() {
0
} else {
let text = buffer.text();
if let Some(offset) = text.find(&hunk.context) {
if text[offset + 1..].contains(&hunk.context) {
anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
}
offset
} else {
anyhow::bail!(
"Failed to match context:\n{}\n\nBuffer:\n{}",
hunk.context,
text
);
}
};
buffer.edit(
hunk.edits.into_iter().map(|edit| {
(
context_offset + edit.range.start..context_offset + edit.range.end,
edit.text,
)
}),
None,
cx,
);
anyhow::Ok(())
})??;
}
}
anyhow::Ok(open_buffers)
}
#[cfg(test)]
mod tests {
use super::*;
use ::fs::FakeFs;
use gpui::TestAppContext;
use indoc::indoc;
use pretty_assertions::assert_eq;
use project::Project;
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_apply_diff_successful(cx: &mut TestAppContext) {
let buffer_1_text = indoc! {r#"
one
two
three
four
five
"# };
let buffer_1_text_final = indoc! {r#"
3
4
5
"# };
let buffer_2_text = indoc! {r#"
six
seven
eight
nine
ten
"# };
let buffer_2_text_final = indoc! {r#"
5
six
seven
7.5
eight
nine
ten
11
"# };
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
});
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
"file2": buffer_2_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one
two
-three
+3
four
five
--- a/root/file1
+++ b/root/file1
3
-four
-five
+4
+5
--- a/root/file1
+++ b/root/file1
-one
-two
3
4
--- a/root/file2
+++ b/root/file2
+5
six
--- a/root/file2
+++ b/root/file2
seven
+7.5
eight
--- a/root/file2
+++ b/root/file2
ten
+11
"#};
let _buffers = apply_diff(diff, &project, &mut cx.to_async())
.await
.unwrap();
let buffer_1 = project
.update(cx, |project, cx| {
let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
project.open_buffer(project_path, cx)
})
.await
.unwrap();
buffer_1.read_with(cx, |buffer, _cx| {
assert_eq!(buffer.text(), buffer_1_text_final);
});
let buffer_2 = project
.update(cx, |project, cx| {
let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
project.open_buffer(project_path, cx)
})
.await
.unwrap();
buffer_2.read_with(cx, |buffer, _cx| {
assert_eq!(buffer.text(), buffer_2_text_final);
});
}
#[gpui::test]
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
let buffer_1_text = indoc! {r#"
one
two
three
four
five
one
two
three
four
five
"# };
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
});
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one
two
-three
+3
four
five
"#};
apply_diff(diff, &project, &mut cx.to_async())
.await
.expect_err("Non-unique edits should fail");
}
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
let start = indoc! {r#"
one
two
three
four
five
four
five
"# };
let end = indoc! {r#"
one
two
3
four
5
four
five
"# };
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
});
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
path!("/root"),
json!({
"file1": start,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one
two
-three
+3
four
-five
+5
"#};
let _buffers = apply_diff(diff, &project, &mut cx.to_async())
.await
.unwrap();
let buffer_1 = project
.update(cx, |project, cx| {
let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
project.open_buffer(project_path, cx)
})
.await
.unwrap();
buffer_1.read_with(cx, |buffer, _cx| {
assert_eq!(buffer.text(), end);
});
}
}

View file

@ -1,6 +1,7 @@
mod evaluate;
mod example;
mod headless;
mod paths;
mod predict;
mod source_location;
mod syntax_retrieval_stats;
@ -10,28 +11,22 @@ use crate::evaluate::{EvaluateArguments, run_evaluate};
use crate::example::{ExampleFormat, NamedExample};
use crate::predict::{PredictArguments, run_zeta2_predict};
use crate::syntax_retrieval_stats::retrieval_stats;
use ::serde::Serialize;
use ::util::paths::PathStyle;
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand};
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use cloud_llm_client::predict_edits_v3;
use edit_prediction_context::{
EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
EditPredictionScoreOptions, Line,
EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
};
use futures::StreamExt as _;
use futures::channel::mpsc;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
use language_model::LanguageModelRegistry;
use language::{Bias, Buffer, BufferSnapshot, Point};
use project::{Project, Worktree};
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::io::{self};
use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
use zeta2::ContextMode;
use crate::headless::ZetaCliAppState;
use crate::source_location::SourceLocation;
@ -79,12 +74,6 @@ enum Zeta2Command {
#[command(subcommand)]
command: Zeta2SyntaxCommand,
},
Llm {
#[clap(flatten)]
args: Zeta2Args,
#[command(subcommand)]
command: Zeta2LlmCommand,
},
Predict(PredictArguments),
Eval(EvaluateArguments),
}
@ -107,14 +96,6 @@ enum Zeta2SyntaxCommand {
},
}
#[derive(Subcommand, Debug)]
enum Zeta2LlmCommand {
Context {
#[clap(flatten)]
context_args: ContextArgs,
},
}
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
@ -388,197 +369,6 @@ async fn zeta2_syntax_context(
Ok(output)
}
async fn zeta2_llm_context(
zeta2_args: Zeta2Args,
context_args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<String> {
let LoadedContext {
buffer,
clipped_cursor,
snapshot: cursor_snapshot,
project,
..
} = load_context(&context_args, app_state, cx).await?;
let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
cx.update(|cx| {
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry
.provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
.unwrap()
.authenticate(cx)
})
})?
.await?;
let edit_history_unified_diff = match context_args.edit_history {
Some(events) => events.read_to_string().await?,
None => String::new(),
};
let (debug_tx, mut debug_rx) = mpsc::unbounded();
let excerpt_options = EditPredictionExcerptOptions {
max_bytes: zeta2_args.max_excerpt_bytes,
min_bytes: zeta2_args.min_excerpt_bytes,
target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
};
let related_excerpts = cx
.update(|cx| {
zeta2::related_excerpts::find_related_excerpts(
buffer,
cursor_position,
&project,
edit_history_unified_diff,
&LlmContextOptions {
excerpt: excerpt_options.clone(),
},
Some(debug_tx),
cx,
)
})?
.await?;
let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
clipped_cursor,
&cursor_snapshot,
&excerpt_options,
None,
)
.context("line didn't fit")?;
#[derive(Serialize)]
struct Output {
excerpts: Vec<OutputExcerpt>,
formatted_excerpts: String,
meta: OutputMeta,
}
#[derive(Default, Serialize)]
struct OutputMeta {
search_prompt: String,
search_queries: Vec<SearchToolQuery>,
}
#[derive(Serialize)]
struct OutputExcerpt {
path: PathBuf,
#[serde(flatten)]
excerpt: Excerpt,
}
let mut meta = OutputMeta::default();
while let Some(debug_info) = debug_rx.next().await {
match debug_info {
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
meta.search_prompt = info.search_prompt;
}
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
meta.search_queries = info.queries
}
_ => {}
}
}
cx.update(|cx| {
let mut excerpts = Vec::new();
let mut formatted_excerpts = String::new();
let cursor_insertions = [(
predict_edits_v3::Point {
line: Line(clipped_cursor.row),
column: clipped_cursor.column,
},
CURSOR_MARKER,
)];
let mut cursor_excerpt_added = false;
for (buffer, ranges) in related_excerpts {
let excerpt_snapshot = buffer.read(cx).snapshot();
let mut line_ranges = ranges
.into_iter()
.map(|range| {
let point_range = range.to_point(&excerpt_snapshot);
Line(point_range.start.row)..Line(point_range.end.row)
})
.collect::<Vec<_>>();
let Some(file) = excerpt_snapshot.file() else {
continue;
};
let path = file.full_path(cx);
let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
if is_cursor_file {
let insertion_ix = line_ranges
.binary_search_by(|probe| {
probe
.start
.cmp(&cursor_excerpt.line_range.start)
.then(cursor_excerpt.line_range.end.cmp(&probe.end))
})
.unwrap_or_else(|ix| ix);
line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
cursor_excerpt_added = true;
}
let merged_excerpts =
zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
.into_iter()
.map(|excerpt| OutputExcerpt {
path: path.clone(),
excerpt,
});
let excerpt_start_ix = excerpts.len();
excerpts.extend(merged_excerpts);
write_codeblock(
&path,
excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
if is_cursor_file {
&cursor_insertions
} else {
&[]
},
Line(excerpt_snapshot.max_point().row),
true,
&mut formatted_excerpts,
);
}
if !cursor_excerpt_added {
write_codeblock(
&cursor_snapshot.file().unwrap().full_path(cx),
&[Excerpt {
start_line: cursor_excerpt.line_range.start,
text: cursor_excerpt.text(&cursor_snapshot).body.into(),
}],
&cursor_insertions,
Line(cursor_snapshot.max_point().row),
true,
&mut formatted_excerpts,
);
}
let output = Output {
excerpts,
formatted_excerpts,
meta,
};
Ok(serde_json::to_string_pretty(&output)?)
})
.unwrap()
}
async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
@ -670,13 +460,6 @@ fn main() {
};
println!("{}", result.unwrap());
}
Zeta2Command::Llm { args, command } => match command {
Zeta2LlmCommand::Context { context_args } => {
let result =
zeta2_llm_context(args, context_args, &app_state, cx).await;
println!("{}", result.unwrap());
}
},
},
Command::ConvertExample {
path,

View file

@ -0,0 +1,8 @@
use std::{env, path::PathBuf, sync::LazyLock};
static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
pub static CACHE_DIR: LazyLock<PathBuf> =
LazyLock::new(|| TARGET_DIR.join("zeta-prediction-cache"));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));

View file

@ -1,22 +1,20 @@
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::LOGS_DIR;
use ::serde::Serialize;
use ::util::paths::PathStyle;
use anyhow::{Context as _, Result, anyhow};
use clap::Args;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use futures::StreamExt as _;
use gpui::AsyncApp;
use language_model::LanguageModelRegistry;
use project::{Project, ProjectPath};
use project::Project;
use serde::Deserialize;
use std::cell::Cell;
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use util::rel_path::RelPath;
#[derive(Debug, Args)]
pub struct PredictArguments {
@ -50,21 +48,12 @@ pub async fn zeta2_predict(
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
fs::create_dir_all(&*LOGS_DIR)?;
let worktree_path = example.setup_worktree().await?;
if !AUTHENTICATED.get() {
AUTHENTICATED.set(true);
cx.update(|cx| {
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry
.provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
.unwrap()
.authenticate(cx)
})
})?
.await?;
app_state
.client
.sign_in_with_optional_connect(true, cx)
@ -83,6 +72,8 @@ pub async fn zeta2_predict(
)
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
@ -94,58 +85,30 @@ pub async fn zeta2_predict(
})?
.await;
let _edited_buffers = example.apply_edit_history(&project, cx).await?;
let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})?
.await?;
let cursor_offset_within_excerpt = example
.example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))?;
let mut cursor_excerpt = example.example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let Some((excerpt_offset, _)) = matches.next() else {
anyhow::bail!(
"Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
);
};
assert!(matches.next().is_none());
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
zeta2::Zeta::try_global(cx)
.unwrap()
.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})?
.detach();
let _edited_buffers = example.apply_edit_history(&project, cx).await?;
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
let refresh_task = zeta.update(cx, |zeta, cx| {
zeta.register_buffer(&cursor_buffer, &project, cx);
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?;
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
let mut context_retrieval_started_at = None;
let mut context_retrieval_finished_at = None;
let mut search_queries_generated_at = None;
@ -159,9 +122,14 @@ pub async fn zeta2_predict(
match event {
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
context_retrieval_started_at = Some(info.timestamp);
fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
}
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
fs::write(
LOGS_DIR.join("search_queries.json"),
serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
)?;
}
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
search_queries_executed_at = Some(info.timestamp);
@ -173,11 +141,21 @@ pub async fn zeta2_predict(
zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
})?);
}
zeta2::ZetaDebugInfo::EditPredicted(request) => {
zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
prediction_started_at = Some(Instant::now());
request.response_rx.await?.map_err(|err| anyhow!(err))?;
fs::write(
LOGS_DIR.join("prediction_prompt.md"),
&request.local_prompt.unwrap_or_default(),
)?;
let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
prediction_finished_at = Some(Instant::now());
fs::write(
LOGS_DIR.join("prediction_response.json"),
&serde_json::to_string_pretty(&response).unwrap(),
)?;
for included_file in request.request.included_files {
let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
result
@ -201,7 +179,6 @@ pub async fn zeta2_predict(
}
break;
}
_ => {}
}
}