Add experimental LSP-based context retrieval system for edit prediction (#44036)

To do

* [x] Default to no context retrieval. Allow opting in to LSP-based
retrieval via a setting (for users in `zeta2` feature flag)
* [x] Feed this context to models when enabled
* [x] Make the zeta2 context view work well with LSP retrieval
* [x] Add a UI for the setting (for feature-flagged users)
* [x] Ensure Zeta CLI `context` command is usable

---

* [ ] Filter out LSP definitions that are too large / entire files (e.g.
modules)
* [ ] Introduce timeouts
* [ ] Test with other LSPs
* [ ] Figure out hangs

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
Max Brunsfeld 2025-12-04 12:48:39 -08:00 committed by GitHub
parent cd8679e81a
commit 76167109db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 2478 additions and 1337 deletions

28
Cargo.lock generated
View file

@ -5342,6 +5342,32 @@ dependencies = [
"zlog",
]
[[package]]
name = "edit_prediction_context2"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"env_logger 0.11.8",
"futures 0.3.31",
"gpui",
"indoc",
"language",
"log",
"lsp",
"parking_lot",
"pretty_assertions",
"project",
"serde",
"serde_json",
"settings",
"smallvec",
"text",
"tree-sitter",
"util",
"zlog",
]
[[package]]
name = "editor"
version = "0.1.0"
@ -21693,6 +21719,7 @@ dependencies = [
"db",
"edit_prediction",
"edit_prediction_context",
"edit_prediction_context2",
"editor",
"feature_flags",
"fs",
@ -21742,7 +21769,6 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"edit_prediction_context",
"editor",

View file

@ -56,6 +56,7 @@ members = [
"crates/edit_prediction",
"crates/edit_prediction_button",
"crates/edit_prediction_context",
"crates/edit_prediction_context2",
"crates/zeta2_tools",
"crates/editor",
"crates/eval",
@ -316,6 +317,7 @@ image_viewer = { path = "crates/image_viewer" }
edit_prediction = { path = "crates/edit_prediction" }
edit_prediction_button = { path = "crates/edit_prediction_button" }
edit_prediction_context = { path = "crates/edit_prediction_context" }
edit_prediction_context2 = { path = "crates/edit_prediction_context2" }
zeta2_tools = { path = "crates/zeta2_tools" }
inspector_ui = { path = "crates/inspector_ui" }
install_cli = { path = "crates/install_cli" }

View file

@ -1105,9 +1105,33 @@ impl EditPredictionButton {
.separator();
}
let menu = self.build_language_settings_menu(menu, window, cx);
let menu = self.add_provider_switching_section(menu, provider, cx);
menu = self.build_language_settings_menu(menu, window, cx);
if cx.has_flag::<Zeta2FeatureFlag>() {
let settings = all_language_settings(None, cx);
let context_retrieval = settings.edit_predictions.use_context;
menu = menu.separator().header("Context Retrieval").item(
ContextMenuEntry::new("Enable Context Retrieval")
.toggleable(IconPosition::Start, context_retrieval)
.action(workspace::ToggleEditPrediction.boxed_clone())
.handler({
let fs = self.fs.clone();
move |_, cx| {
update_settings_file(fs.clone(), cx, move |settings, _| {
settings
.project
.all_languages
.features
.get_or_insert_default()
.experimental_edit_prediction_context_retrieval =
Some(!context_retrieval)
});
}
}),
);
}
menu = self.add_provider_switching_section(menu, provider, cx);
menu
})
}

View file

@ -0,0 +1,42 @@
[package]
name = "edit_prediction_context2"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/edit_prediction_context2.rs"
[dependencies]
parking_lot.workspace = true
anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
lsp.workspace = true
project.workspace = true
log.workspace = true
serde.workspace = true
smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
[dev-dependencies]
env_logger.workspace = true
indoc.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = {workspace= true, features = ["test-support"]}
serde_json.workspace = true
settings = {workspace= true, features = ["test-support"]}
text = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,324 @@
use crate::RelatedExcerpt;
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
use std::ops::Range;
#[cfg(not(test))]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
#[cfg(test)]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 24;
pub fn assemble_excerpts(
buffer: &BufferSnapshot,
mut input_ranges: Vec<Range<Point>>,
) -> Vec<RelatedExcerpt> {
merge_ranges(&mut input_ranges);
let mut outline_ranges = Vec::new();
let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
let mut outline_ix = 0;
for input_range in &mut input_ranges {
*input_range = clip_range_to_lines(input_range, false, buffer);
while let Some(outline_item) = outline_items.get(outline_ix) {
let item_range = clip_range_to_lines(&outline_item.range, false, buffer);
if item_range.start > input_range.start {
break;
}
if item_range.end > input_range.start {
let body_range = outline_item
.body_range(buffer)
.map(|body| clip_range_to_lines(&body, true, buffer))
.filter(|body_range| {
body_range.to_offset(buffer).len() > MAX_OUTLINE_ITEM_BODY_SIZE
});
add_outline_item(
item_range.clone(),
body_range.clone(),
buffer,
&mut outline_ranges,
);
if let Some(body_range) = body_range
&& input_range.start < body_range.start
{
let mut child_outline_ix = outline_ix + 1;
while let Some(next_outline_item) = outline_items.get(child_outline_ix) {
if next_outline_item.range.end > body_range.end {
break;
}
if next_outline_item.depth == outline_item.depth + 1 {
let next_item_range =
clip_range_to_lines(&next_outline_item.range, false, buffer);
add_outline_item(
next_item_range,
next_outline_item
.body_range(buffer)
.map(|body| clip_range_to_lines(&body, true, buffer)),
buffer,
&mut outline_ranges,
);
child_outline_ix += 1;
}
}
}
}
outline_ix += 1;
}
}
input_ranges.extend_from_slice(&outline_ranges);
merge_ranges(&mut input_ranges);
input_ranges
.into_iter()
.map(|range| {
let offset_range = range.to_offset(buffer);
RelatedExcerpt {
point_range: range,
anchor_range: buffer.anchor_before(offset_range.start)
..buffer.anchor_after(offset_range.end),
text: buffer.as_rope().slice(offset_range),
}
})
.collect()
}
fn clip_range_to_lines(
range: &Range<Point>,
inward: bool,
buffer: &BufferSnapshot,
) -> Range<Point> {
let mut range = range.clone();
if inward {
if range.start.column > 0 {
range.start.column = buffer.line_len(range.start.row);
}
range.end.column = 0;
} else {
range.start.column = 0;
if range.end.column > 0 {
range.end.column = buffer.line_len(range.end.row);
}
}
range
}
fn add_outline_item(
mut item_range: Range<Point>,
body_range: Option<Range<Point>>,
buffer: &BufferSnapshot,
outline_ranges: &mut Vec<Range<Point>>,
) {
if let Some(mut body_range) = body_range {
if body_range.start.column > 0 {
body_range.start.column = buffer.line_len(body_range.start.row);
}
body_range.end.column = 0;
let head_range = item_range.start..body_range.start;
if head_range.start < head_range.end {
outline_ranges.push(head_range);
}
let tail_range = body_range.end..item_range.end;
if tail_range.start < tail_range.end {
outline_ranges.push(tail_range);
}
} else {
item_range.start.column = 0;
item_range.end.column = buffer.line_len(item_range.end.row);
outline_ranges.push(item_range);
}
}
pub fn merge_ranges(ranges: &mut Vec<Range<Point>>) {
ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
let mut index = 1;
while index < ranges.len() {
let mut prev_range_end = ranges[index - 1].end;
if prev_range_end.column > 0 {
prev_range_end += Point::new(1, 0);
}
if (prev_range_end + Point::new(1, 0))
.cmp(&ranges[index].start)
.is_ge()
{
let removed = ranges.remove(index);
if removed.end.cmp(&ranges[index - 1].end).is_gt() {
ranges[index - 1].end = removed.end;
}
} else {
index += 1;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{TestAppContext, prelude::*};
use indoc::indoc;
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
use pretty_assertions::assert_eq;
use std::{fmt::Write as _, sync::Arc};
use util::test::marked_text_ranges;
#[gpui::test]
fn test_rust(cx: &mut TestAppContext) {
let table = [
(
indoc! {r#"
struct User {
first_name: String,
«last_name»: String,
age: u32,
email: String,
create_at: Instant,
}
impl User {
pub fn first_name(&self) -> String {
self.first_name.clone()
}
pub fn full_name(&self) -> String {
« format!("{} {}", self.first_name, self.last_name)
» }
}
"#},
indoc! {r#"
struct User {
first_name: String,
last_name: String,
}
impl User {
pub fn full_name(&self) -> String {
format!("{} {}", self.first_name, self.last_name)
}
}
"#},
),
(
indoc! {r#"
struct «User» {
first_name: String,
last_name: String,
age: u32,
}
impl User {
// methods
}
"#
},
indoc! {r#"
struct User {
first_name: String,
last_name: String,
age: u32,
}
"#},
),
(
indoc! {r#"
trait «FooProvider» {
const NAME: &'static str;
fn provide_foo(&self, id: usize) -> Foo;
fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
ids.iter()
.map(|id| self.provide_foo(*id))
.collect()
}
fn sync(&self);
}
"#
},
indoc! {r#"
trait FooProvider {
const NAME: &'static str;
fn provide_foo(&self, id: usize) -> Foo;
fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
}
fn sync(&self);
}
"#},
),
];
for (input, expected_output) in table {
let (input, ranges) = marked_text_ranges(&input, false);
let buffer =
cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
buffer.read_with(cx, |buffer, _cx| {
let ranges: Vec<Range<Point>> = ranges
.into_iter()
.map(|range| range.to_point(&buffer))
.collect();
let excerpts = assemble_excerpts(&buffer.snapshot(), ranges);
let output = format_excerpts(buffer, &excerpts);
assert_eq!(output, expected_output);
});
}
}
fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
let mut output = String::new();
let file_line_count = buffer.max_point().row;
let mut current_row = 0;
for excerpt in excerpts {
if excerpt.text.is_empty() {
continue;
}
if current_row < excerpt.point_range.start.row {
writeln!(&mut output, "").unwrap();
}
current_row = excerpt.point_range.start.row;
for line in excerpt.text.to_string().lines() {
output.push_str(line);
output.push('\n');
current_row += 1;
}
}
if current_row < file_line_count {
writeln!(&mut output, "").unwrap();
}
output
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(language::tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
}

View file

@ -0,0 +1,465 @@
use crate::assemble_excerpts::assemble_excerpts;
use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
collections::hash_map,
ops::Range,
sync::Arc,
time::{Duration, Instant},
};
use util::{RangeExt as _, ResultExt};
mod assemble_excerpts;
#[cfg(test)]
mod edit_prediction_context_tests;
#[cfg(test)]
mod fake_definition_lsp;
pub struct RelatedExcerptStore {
project: WeakEntity<Project>,
related_files: Vec<RelatedFile>,
cache: HashMap<Identifier, Arc<CacheEntry>>,
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
}
pub enum RelatedExcerptStoreEvent {
StartedRefresh,
FinishedRefresh {
cache_hit_count: usize,
cache_miss_count: usize,
mean_definition_latency: Duration,
max_definition_latency: Duration,
},
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
pub name: String,
pub range: Range<Anchor>,
}
enum DefinitionTask {
CacheHit(Arc<CacheEntry>),
CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
#[derive(Debug)]
struct CacheEntry {
definitions: SmallVec<[CachedDefinition; 1]>,
}
#[derive(Clone, Debug)]
struct CachedDefinition {
path: ProjectPath,
buffer: Entity<Buffer>,
anchor_range: Range<Anchor>,
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
#[serde(serialize_with = "serialize_project_path")]
pub path: ProjectPath,
#[serde(skip)]
pub buffer: WeakEntity<Buffer>,
pub excerpts: Vec<RelatedExcerpt>,
pub max_row: u32,
}
impl RelatedFile {
pub fn merge_excerpts(&mut self) {
self.excerpts.sort_unstable_by(|a, b| {
a.point_range
.start
.cmp(&b.point_range.start)
.then(b.point_range.end.cmp(&a.point_range.end))
});
let mut index = 1;
while index < self.excerpts.len() {
if self.excerpts[index - 1]
.point_range
.end
.cmp(&self.excerpts[index].point_range.start)
.is_ge()
{
let removed = self.excerpts.remove(index);
if removed
.point_range
.end
.cmp(&self.excerpts[index - 1].point_range.end)
.is_gt()
{
self.excerpts[index - 1].point_range.end = removed.point_range.end;
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
}
} else {
index += 1;
}
}
}
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
#[serde(skip)]
pub anchor_range: Range<Anchor>,
#[serde(serialize_with = "serialize_point_range")]
pub point_range: Range<Point>,
#[serde(serialize_with = "serialize_rope")]
pub text: Rope,
}
fn serialize_project_path<S: Serializer>(
project_path: &ProjectPath,
serializer: S,
) -> Result<S::Ok, S::Error> {
project_path.path.serialize(serializer)
}
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
rope.to_string().serialize(serializer)
}
fn serialize_point_range<S: Serializer>(
range: &Range<Point>,
serializer: S,
) -> Result<S::Ok, S::Error> {
[
[range.start.row, range.start.column],
[range.end.row, range.end.column],
]
.serialize(serializer)
}
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
impl RelatedExcerptStore {
pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
cx.spawn(async move |this, cx| {
let executor = cx.background_executor().clone();
while let Some((mut buffer, mut position)) = update_rx.next().await {
let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
loop {
futures::select_biased! {
next = update_rx.next() => {
if let Some((new_buffer, new_position)) = next {
buffer = new_buffer;
position = new_position;
timer = executor.timer(DEBOUNCE_DURATION).fuse();
} else {
return anyhow::Ok(());
}
}
_ = timer => break,
}
}
Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
RelatedExcerptStore {
project: project.downgrade(),
update_tx,
related_files: Vec::new(),
cache: Default::default(),
}
}
pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
self.update_tx.unbounded_send((buffer, position)).ok();
}
pub fn related_files(&self) -> &[RelatedFile] {
&self.related_files
}
async fn fetch_excerpts(
this: WeakEntity<Self>,
buffer: Entity<Buffer>,
position: Anchor,
cx: &mut AsyncApp,
) -> Result<()> {
let (project, snapshot) = this.read_with(cx, |this, cx| {
(this.project.upgrade(), buffer.read(cx).snapshot())
})?;
let Some(project) = project else {
return Ok(());
};
let file = snapshot.file().cloned();
if let Some(file) = &file {
log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
}
this.update(cx, |_, cx| {
cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
})?;
let identifiers = cx
.background_spawn(async move { identifiers_for_position(&snapshot, position) })
.await;
let async_cx = cx.clone();
let start_time = Instant::now();
let futures = this.update(cx, |this, cx| {
identifiers
.into_iter()
.filter_map(|identifier| {
let task = if let Some(entry) = this.cache.get(&identifier) {
DefinitionTask::CacheHit(entry.clone())
} else {
DefinitionTask::CacheMiss(
this.project
.update(cx, |project, cx| {
project.definitions(&buffer, identifier.range.start, cx)
})
.ok()?,
)
};
let cx = async_cx.clone();
let project = project.clone();
Some(async move {
match task {
DefinitionTask::CacheHit(cache_entry) => {
Some((identifier, cache_entry, None))
}
DefinitionTask::CacheMiss(task) => {
let locations = task.await.log_err()??;
let duration = start_time.elapsed();
cx.update(|cx| {
(
identifier,
Arc::new(CacheEntry {
definitions: locations
.into_iter()
.filter_map(|location| {
process_definition(location, &project, cx)
})
.collect(),
}),
Some(duration),
)
})
.ok()
}
}
})
})
.collect::<Vec<_>>()
})?;
let mut cache_hit_count = 0;
let mut cache_miss_count = 0;
let mut mean_definition_latency = Duration::ZERO;
let mut max_definition_latency = Duration::ZERO;
let mut new_cache = HashMap::default();
new_cache.reserve(futures.len());
for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
new_cache.insert(identifier, entry);
if let Some(duration) = duration {
cache_miss_count += 1;
mean_definition_latency += duration;
max_definition_latency = max_definition_latency.max(duration);
} else {
cache_hit_count += 1;
}
}
mean_definition_latency /= cache_miss_count.max(1) as u32;
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
if let Some(file) = &file {
log::debug!(
"finished retrieving context buffer:{}, latency:{:?}",
file.path().as_unix_str(),
start_time.elapsed()
);
}
this.update(cx, |this, cx| {
this.cache = new_cache;
this.related_files = related_files;
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
mean_definition_latency,
max_definition_latency,
});
})?;
anyhow::Ok(())
}
}
async fn rebuild_related_files(
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
let mut snapshots = HashMap::default();
for entry in new_entries.values() {
for definition in &entry.definitions {
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
definition
.buffer
.read_with(cx, |buffer, _| buffer.parsing_idle())?
.await;
e.insert(
definition
.buffer
.read_with(cx, |buffer, _| buffer.snapshot())?,
);
}
}
}
Ok(cx
.background_spawn(async move {
let mut files = Vec::<RelatedFile>::new();
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
let mut paths_by_buffer = HashMap::default();
for entry in new_entries.values() {
for definition in &entry.definitions {
let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
continue;
};
paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
ranges_by_buffer
.entry(definition.buffer.clone())
.or_default()
.push(definition.anchor_range.to_point(snapshot));
}
}
for (buffer, ranges) in ranges_by_buffer {
let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
continue;
};
let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
continue;
};
let excerpts = assemble_excerpts(snapshot, ranges);
files.push(RelatedFile {
path: project_path.clone(),
buffer: buffer.downgrade(),
excerpts,
max_row: snapshot.max_point().row,
});
}
files.sort_by_key(|file| file.path.clone());
(new_entries, files)
})
.await)
}
fn process_definition(
location: LocationLink,
project: &Entity<Project>,
cx: &mut App,
) -> Option<CachedDefinition> {
let buffer = location.target.buffer.read(cx);
let anchor_range = location.target.range;
let file = buffer.file()?;
let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
if worktree.read(cx).is_single_file() {
return None;
}
Some(CachedDefinition {
path: ProjectPath {
worktree_id: file.worktree_id(cx),
path: file.path().clone(),
},
buffer: location.target.buffer,
anchor_range,
})
}
/// Gets all of the identifiers that are present in the given line, and its containing
/// outline items.
fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec<Identifier> {
let offset = position.to_offset(buffer);
let point = buffer.offset_to_point(offset);
let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point());
let mut ranges = vec![line_range.to_offset(&buffer)];
// Include the range of the outline item itself, but not its body.
let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
for item in outline_items {
if let Some(body_range) = item.body_range(&buffer) {
ranges.push(item.range.start..body_range.start.to_offset(&buffer));
} else {
ranges.push(item.range.clone());
}
}
ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
ranges.dedup_by(|a, b| {
if a.start <= b.end {
b.start = b.start.min(a.start);
b.end = b.end.max(a.end);
true
} else {
false
}
});
let mut identifiers = Vec::new();
let outer_range =
ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
let mut captures = buffer
.syntax
.captures(outer_range.clone(), &buffer.text, |grammar| {
grammar
.highlights_config
.as_ref()
.map(|config| &config.query)
});
for range in ranges {
captures.set_byte_range(range.start..outer_range.end);
let mut last_range = None;
while let Some(capture) = captures.peek() {
let node_range = capture.node.byte_range();
if node_range.start > range.end {
break;
}
let config = captures.grammars()[capture.grammar_index]
.highlights_config
.as_ref();
if let Some(config) = config
&& config.identifier_capture_indices.contains(&capture.index)
&& range.contains_inclusive(&node_range)
&& Some(&node_range) != last_range.as_ref()
{
let name = buffer.text_for_range(node_range.clone()).collect();
identifiers.push(Identifier {
range: buffer.anchor_after(node_range.start)
..buffer.anchor_before(node_range.end),
name,
});
last_range = Some(node_range);
}
captures.advance();
}
}
identifiers
}

View file

@ -0,0 +1,360 @@
use super::*;
use futures::channel::mpsc::UnboundedReceiver;
use gpui::TestAppContext;
use indoc::indoc;
use language::{Language, LanguageConfig, LanguageMatcher, Point, ToPoint as _, tree_sitter_rust};
use lsp::FakeLanguageServer;
use project::{FakeFs, LocationLink, Project};
use serde_json::json;
use settings::SettingsStore;
use std::sync::Arc;
use util::path;
#[gpui::test]
async fn test_edit_prediction_context(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root"), test_project_1()).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let mut servers = setup_fake_lsp(&project, cx);
let (buffer, _handle) = project
.update(cx, |project, cx| {
project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
})
.await
.unwrap();
let _server = servers.next().await.unwrap();
cx.run_until_parked();
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
related_excerpt_store.update(cx, |store, cx| {
let position = {
let buffer = buffer.read(cx);
let offset = buffer.text().find("todo").unwrap();
buffer.anchor_before(offset)
};
store.refresh(buffer.clone(), position, cx);
});
cx.executor().advance_clock(DEBOUNCE_DURATION);
related_excerpt_store.update(cx, |store, _| {
let excerpts = store.related_files();
assert_related_files(
&excerpts,
&[
(
"src/company.rs",
&[indoc! {"
pub struct Company {
owner: Arc<Person>,
address: Address,
}"}],
),
(
"src/main.rs",
&[
indoc! {"
pub struct Session {
company: Arc<Company>,
}
impl Session {
pub fn set_company(&mut self, company: Arc<Company>) {"},
indoc! {"
}
}"},
],
),
(
"src/person.rs",
&[
indoc! {"
impl Person {
pub fn get_first_name(&self) -> &str {
&self.first_name
}"},
"}",
],
),
],
);
});
}
#[gpui::test]
async fn test_fake_definition_lsp(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root"), test_project_1()).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let mut servers = setup_fake_lsp(&project, cx);
let (buffer, _handle) = project
.update(cx, |project, cx| {
project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
})
.await
.unwrap();
let _server = servers.next().await.unwrap();
cx.run_until_parked();
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
let definitions = project
.update(cx, |project, cx| {
let offset = buffer_text.find("Address {").unwrap();
project.definitions(&buffer, offset, cx)
})
.await
.unwrap()
.unwrap();
assert_definitions(&definitions, &["pub struct Address {"], cx);
let definitions = project
.update(cx, |project, cx| {
let offset = buffer_text.find("State::CA").unwrap();
project.definitions(&buffer, offset, cx)
})
.await
.unwrap()
.unwrap();
assert_definitions(&definitions, &["pub enum State {"], cx);
let definitions = project
.update(cx, |project, cx| {
let offset = buffer_text.find("to_string()").unwrap();
project.definitions(&buffer, offset, cx)
})
.await
.unwrap()
.unwrap();
assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx);
}
fn init_test(cx: &mut TestAppContext) {
let settings_store = cx.update(|cx| SettingsStore::test(cx));
cx.set_global(settings_store);
env_logger::try_init().ok();
}
fn setup_fake_lsp(
project: &Entity<Project>,
cx: &mut TestAppContext,
) -> UnboundedReceiver<FakeLanguageServer> {
let (language_registry, fs) = project.read_with(cx, |project, _| {
(project.languages().clone(), project.fs().clone())
});
let language = rust_lang();
language_registry.add(language.clone());
fake_definition_lsp::register_fake_definition_server(&language_registry, language, fs)
}
fn test_project_1() -> serde_json::Value {
let person_rs = indoc! {r#"
pub struct Person {
first_name: String,
last_name: String,
email: String,
age: u32,
}
impl Person {
pub fn get_first_name(&self) -> &str {
&self.first_name
}
pub fn get_last_name(&self) -> &str {
&self.last_name
}
pub fn get_email(&self) -> &str {
&self.email
}
pub fn get_age(&self) -> u32 {
self.age
}
}
"#};
let address_rs = indoc! {r#"
pub struct Address {
street: String,
city: String,
state: State,
zip: u32,
}
pub enum State {
CA,
OR,
WA,
TX,
// ...
}
impl Address {
pub fn get_street(&self) -> &str {
&self.street
}
pub fn get_city(&self) -> &str {
&self.city
}
pub fn get_state(&self) -> State {
self.state
}
pub fn get_zip(&self) -> u32 {
self.zip
}
}
"#};
let company_rs = indoc! {r#"
use super::person::Person;
use super::address::Address;
pub struct Company {
owner: Arc<Person>,
address: Address,
}
impl Company {
pub fn get_owner(&self) -> &Person {
&self.owner
}
pub fn get_address(&self) -> &Address {
&self.address
}
pub fn to_string(&self) -> String {
format!("{} ({})", self.owner.first_name, self.address.city)
}
}
"#};
let main_rs = indoc! {r#"
use std::sync::Arc;
use super::person::Person;
use super::address::Address;
use super::company::Company;
pub struct Session {
company: Arc<Company>,
}
impl Session {
pub fn set_company(&mut self, company: Arc<Company>) {
self.company = company;
if company.owner != self.company.owner {
log("new owner", company.owner.get_first_name()); todo();
}
}
}
fn main() {
let company = Company {
owner: Arc::new(Person {
first_name: "John".to_string(),
last_name: "Doe".to_string(),
email: "john@example.com".to_string(),
age: 30,
}),
address: Address {
street: "123 Main St".to_string(),
city: "Anytown".to_string(),
state: State::CA,
zip: 12345,
},
};
println!("Company: {}", company.to_string());
}
"#};
json!({
"src": {
"person.rs": person_rs,
"address.rs": address_rs,
"company.rs": company_rs,
"main.rs": main_rs,
},
})
}
fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &[&str])]) {
let actual_files = actual_files
.iter()
.map(|file| {
let excerpts = file
.excerpts
.iter()
.map(|excerpt| excerpt.text.to_string())
.collect::<Vec<_>>();
(file.path.path.as_unix_str(), excerpts)
})
.collect::<Vec<_>>();
let expected_excerpts = expected_files
.iter()
.map(|(path, texts)| {
(
*path,
texts
.iter()
.map(|line| line.to_string())
.collect::<Vec<_>>(),
)
})
.collect::<Vec<_>>();
pretty_assertions::assert_eq!(actual_files, expected_excerpts)
}
fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) {
let actual_first_lines = definitions
.iter()
.map(|definition| {
definition.target.buffer.read_with(cx, |buffer, _| {
let mut start = definition.target.range.start.to_point(&buffer);
start.column = 0;
let end = Point::new(start.row, buffer.line_len(start.row));
buffer
.text_for_range(start..end)
.collect::<String>()
.trim()
.to_string()
})
})
.collect::<Vec<String>>();
assert_eq!(actual_first_lines, first_lines);
}
pub(crate) fn rust_lang() -> Arc<Language> {
Arc::new(
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
first_line_pattern: None,
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
.unwrap()
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap(),
)
}

View file

@ -0,0 +1,329 @@
use collections::HashMap;
use futures::channel::mpsc::UnboundedReceiver;
use language::{Language, LanguageRegistry};
use lsp::{
FakeLanguageServer, LanguageServerBinary, TextDocumentSyncCapability, TextDocumentSyncKind, Uri,
};
use parking_lot::Mutex;
use project::Fs;
use std::{ops::Range, path::PathBuf, sync::Arc};
use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree};
/// Registers a fake language server that implements go-to-definition using tree-sitter,
/// making the assumption that all names are unique, and all variables' types are
/// explicitly declared.
pub fn register_fake_definition_server(
language_registry: &Arc<LanguageRegistry>,
language: Arc<Language>,
fs: Arc<dyn Fs>,
) -> UnboundedReceiver<FakeLanguageServer> {
let index = Arc::new(Mutex::new(DefinitionIndex::new(language.clone())));
language_registry.register_fake_lsp(
language.name(),
language::FakeLspAdapter {
name: "fake-definition-lsp",
initialization_options: None,
prettier_plugins: Vec::new(),
disk_based_diagnostics_progress_token: None,
disk_based_diagnostics_sources: Vec::new(),
language_server_binary: LanguageServerBinary {
path: PathBuf::from("fake-definition-lsp"),
arguments: Vec::new(),
env: None,
},
capabilities: lsp::ServerCapabilities {
definition_provider: Some(lsp::OneOf::Left(true)),
text_document_sync: Some(TextDocumentSyncCapability::Kind(
TextDocumentSyncKind::FULL,
)),
..Default::default()
},
label_for_completion: None,
initializer: Some(Box::new({
move |server| {
server.handle_notification::<lsp::notification::DidOpenTextDocument, _>({
let index = index.clone();
move |params, _cx| {
index
.lock()
.open_buffer(params.text_document.uri, &params.text_document.text);
}
});
server.handle_notification::<lsp::notification::DidCloseTextDocument, _>({
let index = index.clone();
let fs = fs.clone();
move |params, cx| {
let uri = params.text_document.uri;
let path = uri.to_file_path().ok();
index.lock().mark_buffer_closed(&uri);
if let Some(path) = path {
let index = index.clone();
let fs = fs.clone();
cx.spawn(async move |_cx| {
if let Ok(content) = fs.load(&path).await {
index.lock().index_file(uri, &content);
}
})
.detach();
}
}
});
server.handle_notification::<lsp::notification::DidChangeWatchedFiles, _>({
let index = index.clone();
let fs = fs.clone();
move |params, cx| {
let index = index.clone();
let fs = fs.clone();
cx.spawn(async move |_cx| {
for event in params.changes {
if index.lock().is_buffer_open(&event.uri) {
continue;
}
match event.typ {
lsp::FileChangeType::DELETED => {
index.lock().remove_definitions_for_file(&event.uri);
}
lsp::FileChangeType::CREATED
| lsp::FileChangeType::CHANGED => {
if let Some(path) = event.uri.to_file_path().ok() {
if let Ok(content) = fs.load(&path).await {
index.lock().index_file(event.uri, &content);
}
}
}
_ => {}
}
}
})
.detach();
}
});
server.handle_notification::<lsp::notification::DidChangeTextDocument, _>({
let index = index.clone();
move |params, _cx| {
if let Some(change) = params.content_changes.into_iter().last() {
index
.lock()
.index_file(params.text_document.uri, &change.text);
}
}
});
server.handle_notification::<lsp::notification::DidChangeWorkspaceFolders, _>(
{
let index = index.clone();
let fs = fs.clone();
move |params, cx| {
let index = index.clone();
let fs = fs.clone();
let files = fs.as_fake().files();
cx.spawn(async move |_cx| {
for folder in params.event.added {
let Ok(path) = folder.uri.to_file_path() else {
continue;
};
for file in &files {
if let Some(uri) = Uri::from_file_path(&file).ok()
&& file.starts_with(&path)
&& let Ok(content) = fs.load(&file).await
{
index.lock().index_file(uri, &content);
}
}
}
})
.detach();
}
},
);
server.set_request_handler::<lsp::request::GotoDefinition, _, _>({
let index = index.clone();
move |params, _cx| {
let result = index.lock().get_definitions(
params.text_document_position_params.text_document.uri,
params.text_document_position_params.position,
);
async move { Ok(result) }
}
});
}
})),
},
)
}
struct DefinitionIndex {
language: Arc<Language>,
definitions: HashMap<String, Vec<lsp::Location>>,
files: HashMap<Uri, FileEntry>,
}
#[derive(Debug)]
struct FileEntry {
contents: String,
is_open_in_buffer: bool,
}
impl DefinitionIndex {
fn new(language: Arc<Language>) -> Self {
Self {
language,
definitions: HashMap::default(),
files: HashMap::default(),
}
}
fn remove_definitions_for_file(&mut self, uri: &Uri) {
self.definitions.retain(|_, locations| {
locations.retain(|loc| &loc.uri != uri);
!locations.is_empty()
});
self.files.remove(uri);
}
fn open_buffer(&mut self, uri: Uri, content: &str) {
self.index_file_inner(uri, content, true);
}
fn mark_buffer_closed(&mut self, uri: &Uri) {
if let Some(entry) = self.files.get_mut(uri) {
entry.is_open_in_buffer = false;
}
}
fn is_buffer_open(&self, uri: &Uri) -> bool {
self.files
.get(uri)
.map(|entry| entry.is_open_in_buffer)
.unwrap_or(false)
}
fn index_file(&mut self, uri: Uri, content: &str) {
self.index_file_inner(uri, content, false);
}
fn index_file_inner(&mut self, uri: Uri, content: &str, is_open_in_buffer: bool) -> Option<()> {
self.remove_definitions_for_file(&uri);
let grammar = self.language.grammar()?;
let outline_config = grammar.outline_config.as_ref()?;
let mut parser = Parser::new();
parser.set_language(&grammar.ts_language).ok()?;
let tree = parser.parse(content, None)?;
let declarations = extract_declarations_from_tree(&tree, content, outline_config);
for (name, byte_range) in declarations {
let range = byte_range_to_lsp_range(content, byte_range);
let location = lsp::Location {
uri: uri.clone(),
range,
};
self.definitions
.entry(name)
.or_insert_with(Vec::new)
.push(location);
}
self.files.insert(
uri,
FileEntry {
contents: content.to_string(),
is_open_in_buffer,
},
);
Some(())
}
fn get_definitions(
&mut self,
uri: Uri,
position: lsp::Position,
) -> Option<lsp::GotoDefinitionResponse> {
let entry = self.files.get(&uri)?;
let name = word_at_position(&entry.contents, position)?;
let locations = self.definitions.get(name).cloned()?;
Some(lsp::GotoDefinitionResponse::Array(locations))
}
}
fn extract_declarations_from_tree(
tree: &Tree,
content: &str,
outline_config: &language::OutlineConfig,
) -> Vec<(String, Range<usize>)> {
let mut cursor = QueryCursor::new();
let mut declarations = Vec::new();
let mut matches = cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes());
while let Some(query_match) = matches.next() {
let mut name_range: Option<Range<usize>> = None;
let mut has_item_range = false;
for capture in query_match.captures {
let range = capture.node.byte_range();
if capture.index == outline_config.name_capture_ix {
name_range = Some(range);
} else if capture.index == outline_config.item_capture_ix {
has_item_range = true;
}
}
if let Some(name_range) = name_range
&& has_item_range
{
let name = content[name_range.clone()].to_string();
if declarations.iter().any(|(n, _)| n == &name) {
continue;
}
declarations.push((name, name_range));
}
}
declarations
}
fn byte_range_to_lsp_range(content: &str, byte_range: Range<usize>) -> lsp::Range {
let start = byte_offset_to_position(content, byte_range.start);
let end = byte_offset_to_position(content, byte_range.end);
lsp::Range { start, end }
}
fn byte_offset_to_position(content: &str, offset: usize) -> lsp::Position {
let mut line = 0;
let mut character = 0;
let mut current_offset = 0;
for ch in content.chars() {
if current_offset >= offset {
break;
}
if ch == '\n' {
line += 1;
character = 0;
} else {
character += 1;
}
current_offset += ch.len_utf8();
}
lsp::Position { line, character }
}
fn word_at_position(content: &str, position: lsp::Position) -> Option<&str> {
let mut lines = content.lines();
let line = lines.nth(position.line as usize)?;
let column = position.character as usize;
if column > line.len() {
return None;
}
let start = line[..column]
.rfind(|c: char| !c.is_alphanumeric() && c != '_')
.map(|i| i + 1)
.unwrap_or(0);
let end = line[column..]
.find(|c: char| !c.is_alphanumeric() && c != '_')
.map(|i| i + column)
.unwrap_or(line.len());
Some(&line[start..end]).filter(|word| !word.is_empty())
}

View file

@ -705,7 +705,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
.await
.unwrap();
let mut fake_servers = language_registry.register_fake_language_server(
let mut fake_servers = language_registry.register_fake_lsp_server(
LanguageServerName("gleam".into()),
lsp::ServerCapabilities {
completion_provider: Some(Default::default()),

View file

@ -4022,6 +4022,20 @@ impl BufferSnapshot {
})
}
pub fn outline_items_as_offsets_containing<T: ToOffset>(
&self,
range: Range<T>,
include_extra_context: bool,
theme: Option<&SyntaxTheme>,
) -> Vec<OutlineItem<usize>> {
self.outline_items_containing_internal(
range,
include_extra_context,
theme,
|buffer, range| range.to_offset(buffer),
)
}
fn outline_items_containing_internal<T: ToOffset, U>(
&self,
range: Range<T>,

View file

@ -784,28 +784,48 @@ async fn test_outline(cx: &mut gpui::TestAppContext) {
.unindent();
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None));
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
let outline = snapshot.outline(None);
assert_eq!(
pretty_assertions::assert_eq!(
outline
.items
.iter()
.map(|item| (item.text.as_str(), item.depth))
.map(|item| (
item.text.as_str(),
item.depth,
item.to_point(&snapshot).body_range(&snapshot)
.map(|range| minimize_space(&snapshot.text_for_range(range).collect::<String>()))
))
.collect::<Vec<_>>(),
&[
("struct Person", 0),
("name", 1),
("age", 1),
("mod module", 0),
("enum LoginState", 1),
("LoggedOut", 2),
("LoggingOn", 2),
("LoggedIn", 2),
("person", 3),
("time", 3),
("impl Eq for Person", 0),
("impl Drop for Person", 0),
("fn drop", 1),
("struct Person", 0, Some("name: String, age: usize,".to_string())),
("name", 1, None),
("age", 1, None),
(
"mod module",
0,
Some(
"enum LoginState { LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, } }".to_string()
)
),
(
"enum LoginState",
1,
Some("LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, }".to_string())
),
("LoggedOut", 2, None),
("LoggingOn", 2, None),
("LoggedIn", 2, Some("person: Person, time: Instant,".to_string())),
("person", 3, None),
("time", 3, None),
("impl Eq for Person", 0, None),
(
"impl Drop for Person",
0,
Some("fn drop(&mut self) { println!(\"bye\"); }".to_string())
),
("fn drop", 1, Some("println!(\"bye\");".to_string())),
]
);
@ -840,6 +860,11 @@ async fn test_outline(cx: &mut gpui::TestAppContext) {
]
);
fn minimize_space(text: &str) -> String {
static WHITESPACE: LazyLock<Regex> = LazyLock::new(|| Regex::new("[\\n\\s]+").unwrap());
WHITESPACE.replace_all(text, " ").trim().to_string()
}
async fn search<'a>(
outline: &'a Outline<Anchor>,
query: &'a str,

View file

@ -437,26 +437,14 @@ impl LanguageRegistry {
language_name: impl Into<LanguageName>,
mut adapter: crate::FakeLspAdapter,
) -> futures::channel::mpsc::UnboundedReceiver<lsp::FakeLanguageServer> {
let language_name = language_name.into();
let adapter_name = LanguageServerName(adapter.name.into());
let capabilities = adapter.capabilities.clone();
let initializer = adapter.initializer.take();
let adapter = CachedLspAdapter::new(Arc::new(adapter));
{
let mut state = self.state.write();
state
.lsp_adapters
.entry(language_name)
.or_default()
.push(adapter.clone());
state.all_lsp_adapters.insert(adapter.name(), adapter);
}
self.register_fake_language_server(adapter_name, capabilities, initializer)
self.register_fake_lsp_adapter(language_name, adapter);
self.register_fake_lsp_server(adapter_name, capabilities, initializer)
}
/// Register a fake lsp adapter (without the language server)
/// The returned channel receives a new instance of the language server every time it is started
#[cfg(any(feature = "test-support", test))]
pub fn register_fake_lsp_adapter(
&self,
@ -479,7 +467,7 @@ impl LanguageRegistry {
/// Register a fake language server (without the adapter)
/// The returned channel receives a new instance of the language server every time it is started
#[cfg(any(feature = "test-support", test))]
pub fn register_fake_language_server(
pub fn register_fake_lsp_server(
&self,
lsp_name: LanguageServerName,
capabilities: lsp::ServerCapabilities,

View file

@ -373,6 +373,8 @@ impl InlayHintSettings {
pub struct EditPredictionSettings {
/// The provider that supplies edit predictions.
pub provider: settings::EditPredictionProvider,
/// Whether to use the experimental edit prediction context retrieval system.
pub use_context: bool,
/// A list of globs representing files that edit predictions should be disabled for.
/// This list adds to a pre-existing, sensible default set of globs.
/// Any additional ones you add are combined with them.
@ -622,6 +624,11 @@ impl settings::Settings for AllLanguageSettings {
.features
.as_ref()
.and_then(|f| f.edit_prediction_provider);
let use_edit_prediction_context = all_languages
.features
.as_ref()
.and_then(|f| f.experimental_edit_prediction_context_retrieval)
.unwrap_or_default();
let edit_predictions = all_languages.edit_predictions.clone().unwrap();
let edit_predictions_mode = edit_predictions.mode.unwrap();
@ -668,6 +675,7 @@ impl settings::Settings for AllLanguageSettings {
} else {
EditPredictionProvider::None
},
use_context: use_edit_prediction_context,
disabled_globs: disabled_globs
.iter()
.filter_map(|g| {

View file

@ -1,4 +1,4 @@
use crate::{BufferSnapshot, Point, ToPoint};
use crate::{BufferSnapshot, Point, ToPoint, ToTreeSitterPoint};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{BackgroundExecutor, HighlightStyle};
use std::ops::Range;
@ -48,6 +48,54 @@ impl<T: ToPoint> OutlineItem<T> {
.map(|r| r.start.to_point(buffer)..r.end.to_point(buffer)),
}
}
pub fn body_range(&self, buffer: &BufferSnapshot) -> Option<Range<Point>> {
if let Some(range) = self.body_range.as_ref() {
return Some(range.start.to_point(buffer)..range.end.to_point(buffer));
}
let range = self.range.start.to_point(buffer)..self.range.end.to_point(buffer);
let start_indent = buffer.indent_size_for_line(range.start.row);
let node = buffer.syntax_ancestor(range.clone())?;
let mut cursor = node.walk();
loop {
let node = cursor.node();
if node.start_position() >= range.start.to_ts_point()
&& node.end_position() <= range.end.to_ts_point()
{
break;
}
cursor.goto_first_child_for_point(range.start.to_ts_point());
}
if !cursor.goto_last_child() {
return None;
}
let body_node = loop {
let node = cursor.node();
if node.child_count() > 0 {
break node;
}
if !cursor.goto_previous_sibling() {
return None;
}
};
let mut start_row = body_node.start_position().row as u32;
let mut end_row = body_node.end_position().row as u32;
while start_row < end_row && buffer.indent_size_for_line(start_row) == start_indent {
start_row += 1;
}
while start_row < end_row && buffer.indent_size_for_line(end_row - 1) == start_indent {
end_row -= 1;
}
if start_row < end_row {
return Some(Point::new(start_row, 0)..Point::new(end_row, 0));
}
None
}
}
impl<T> Outline<T> {

View file

@ -1215,6 +1215,19 @@ impl<'a> SyntaxMapMatches<'a> {
true
}
// pub fn set_byte_range(&mut self, range: Range<usize>) {
// for layer in &mut self.layers {
// layer.matches.set_byte_range(range.clone());
// layer.advance();
// }
// self.layers.sort_unstable_by_key(|layer| layer.sort_key());
// self.active_layer_count = self
// .layers
// .iter()
// .position(|layer| !layer.has_next)
// .unwrap_or(self.layers.len());
// }
}
impl SyntaxMapCapturesLayer<'_> {

View file

@ -452,7 +452,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
});
let mut fake_lsp = server_cx.update(|cx| {
headless.read(cx).languages.register_fake_language_server(
headless.read(cx).languages.register_fake_lsp_server(
LanguageServerName("rust-analyzer".into()),
lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions::default()),
@ -476,7 +476,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
..FakeLspAdapter::default()
},
);
headless.read(cx).languages.register_fake_language_server(
headless.read(cx).languages.register_fake_lsp_server(
LanguageServerName("fake-analyzer".into()),
lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions::default()),
@ -669,7 +669,7 @@ async fn test_remote_cancel_language_server_work(
});
let mut fake_lsp = server_cx.update(|cx| {
headless.read(cx).languages.register_fake_language_server(
headless.read(cx).languages.register_fake_lsp_server(
LanguageServerName("rust-analyzer".into()),
Default::default(),
None,

View file

@ -62,6 +62,8 @@ impl merge_from::MergeFrom for AllLanguageSettingsContent {
pub struct FeaturesContent {
/// Determines which edit prediction provider to use.
pub edit_prediction_provider: Option<EditPredictionProvider>,
/// Enables the experimental edit prediction context retrieval system.
pub experimental_edit_prediction_context_retrieval: Option<bool>,
}
/// The provider that supplies edit predictions.

View file

@ -8,10 +8,14 @@ use sum_tree::{Bias, Dimensions};
/// A timestamped position in a buffer
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct Anchor {
/// The timestamp of the operation that inserted the text
/// in which this anchor is located.
pub timestamp: clock::Lamport,
/// The byte offset in the buffer
/// The byte offset into the text inserted in the operation
/// at `timestamp`.
pub offset: usize,
/// Describes which character the anchor is biased towards
/// Whether this anchor stays attached to the character *before* or *after*
/// the offset.
pub bias: Bias,
pub buffer_id: Option<BufferId>,
}

View file

@ -485,6 +485,7 @@ pub struct Table<const COLS: usize = 3> {
interaction_state: Option<WeakEntity<TableInteractionState>>,
col_widths: Option<TableWidths<COLS>>,
map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>,
use_ui_font: bool,
empty_table_callback: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>>,
}
@ -498,6 +499,7 @@ impl<const COLS: usize> Table<COLS> {
rows: TableContents::Vec(Vec::new()),
interaction_state: None,
map_row: None,
use_ui_font: true,
empty_table_callback: None,
col_widths: None,
}
@ -590,6 +592,11 @@ impl<const COLS: usize> Table<COLS> {
self
}
pub fn no_ui_font(mut self) -> Self {
self.use_ui_font = false;
self
}
pub fn map_row(
mut self,
callback: impl Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement + 'static,
@ -618,8 +625,8 @@ fn base_cell_style(width: Option<Length>) -> Div {
.overflow_hidden()
}
fn base_cell_style_text(width: Option<Length>, cx: &App) -> Div {
base_cell_style(width).text_ui(cx)
fn base_cell_style_text(width: Option<Length>, use_ui_font: bool, cx: &App) -> Div {
base_cell_style(width).when(use_ui_font, |el| el.text_ui(cx))
}
pub fn render_table_row<const COLS: usize>(
@ -656,7 +663,12 @@ pub fn render_table_row<const COLS: usize>(
.map(IntoElement::into_any_element)
.into_iter()
.zip(column_widths)
.map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)),
.map(|(cell, width)| {
base_cell_style_text(width, table_context.use_ui_font, cx)
.px_1()
.py_0p5()
.child(cell)
}),
);
let row = if let Some(map_row) = table_context.map_row {
@ -700,7 +712,7 @@ pub fn render_table_header<const COLS: usize>(
.border_color(cx.theme().colors().border)
.children(headers.into_iter().enumerate().zip(column_widths).map(
|((header_idx, h), width)| {
base_cell_style_text(width, cx)
base_cell_style_text(width, table_context.use_ui_font, cx)
.child(h)
.id(ElementId::NamedInteger(
shared_element_id.clone(),
@ -739,6 +751,7 @@ pub struct TableRenderContext<const COLS: usize> {
pub total_row_count: usize,
pub column_widths: Option<[Length; COLS]>,
pub map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>,
pub use_ui_font: bool,
}
impl<const COLS: usize> TableRenderContext<COLS> {
@ -748,6 +761,7 @@ impl<const COLS: usize> TableRenderContext<COLS> {
total_row_count: table.rows.len(),
column_widths: table.col_widths.as_ref().map(|widths| widths.lengths(cx)),
map_row: table.map_row.clone(),
use_ui_font: table.use_ui_font,
}
}
}

View file

@ -30,6 +30,7 @@ credentials_provider.workspace = true
db.workspace = true
edit_prediction.workspace = true
edit_prediction_context.workspace = true
edit_prediction_context2.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true

View file

@ -1,173 +0,0 @@
use cloud_llm_client::predict_edits_v3::Excerpt;
use edit_prediction_context::Line;
use language::{BufferSnapshot, Point};
use std::ops::Range;
pub fn assemble_excerpts(
buffer: &BufferSnapshot,
merged_line_ranges: impl IntoIterator<Item = Range<Line>>,
) -> Vec<Excerpt> {
let mut output = Vec::new();
let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
let mut outline_items = outline_items.into_iter().peekable();
for range in merged_line_ranges {
let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
while let Some(outline_item) = outline_items.peek() {
if outline_item.range.start >= point_range.start {
break;
}
if outline_item.range.end > point_range.start {
let mut point_range = outline_item.source_range_for_text.clone();
point_range.start.column = 0;
point_range.end.column = buffer.line_len(point_range.end.row);
output.push(Excerpt {
start_line: Line(point_range.start.row),
text: buffer
.text_for_range(point_range.clone())
.collect::<String>()
.into(),
})
}
outline_items.next();
}
output.push(Excerpt {
start_line: Line(point_range.start.row),
text: buffer
.text_for_range(point_range.clone())
.collect::<String>()
.into(),
})
}
output
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use cloud_llm_client::predict_edits_v3;
use gpui::{TestAppContext, prelude::*};
use indoc::indoc;
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
use pretty_assertions::assert_eq;
use util::test::marked_text_ranges;
#[gpui::test]
fn test_rust(cx: &mut TestAppContext) {
let table = [
(
indoc! {r#"
struct User {
first_name: String,
« last_name: String,
ageˇ: u32,
» email: String,
create_at: Instant,
}
impl User {
pub fn first_name(&self) -> String {
self.first_name.clone()
}
pub fn full_name(&self) -> String {
« format!("{} {}", self.first_name, self.last_name)
» }
}
"#},
indoc! {r#"
1|struct User {
3| last_name: String,
4| age<|cursor|>: u32,
9|impl User {
14| pub fn full_name(&self) -> String {
15| format!("{} {}", self.first_name, self.last_name)
"#},
),
(
indoc! {r#"
struct User {
first_name: String,
« last_name: String,
age: u32,
}
»"#
},
indoc! {r#"
1|struct User {
3| last_name: String,
4| age: u32,
5|}
"#},
),
];
for (input, expected_output) in table {
let input_without_ranges = input.replace(['«', '»'], "");
let input_without_caret = input.replace('ˇ', "");
let cursor_offset = input_without_ranges.find('ˇ');
let (input, ranges) = marked_text_ranges(&input_without_caret, false);
let buffer =
cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
buffer.read_with(cx, |buffer, _cx| {
let insertions = cursor_offset
.map(|offset| {
let point = buffer.offset_to_point(offset);
vec![(
predict_edits_v3::Point {
line: Line(point.row),
column: point.column,
},
"<|cursor|>",
)]
})
.unwrap_or_default();
let ranges: Vec<Range<Line>> = ranges
.into_iter()
.map(|range| {
let point_range = range.to_point(&buffer);
Line(point_range.start.row)..Line(point_range.end.row)
})
.collect();
let mut output = String::new();
cloud_zeta2_prompt::write_excerpts(
assemble_excerpts(&buffer.snapshot(), ranges).iter(),
&insertions,
Line(buffer.max_point().row),
true,
&mut output,
);
assert_eq!(output, expected_output);
});
}
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(language::tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
}

View file

@ -1,6 +1,7 @@
use anyhow::Result;
use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
use collections::HashMap;
use edit_prediction_context2::{RelatedExcerpt, RelatedFile};
use futures::{
StreamExt,
channel::mpsc::{self, UnboundedSender},
@ -8,7 +9,7 @@ use futures::{
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
use project::{
Project, WorktreeSettings,
Project, ProjectPath, WorktreeSettings,
search::{SearchQuery, SearchResult},
};
use smol::channel;
@ -20,14 +21,14 @@ use util::{
use workspace::item::Settings as _;
#[cfg(feature = "eval-support")]
type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<(u32, u32)>>>;
pub async fn run_retrieval_searches(
queries: Vec<SearchToolQuery>,
project: Entity<Project>,
#[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
cx: &mut AsyncApp,
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
) -> Result<Vec<RelatedFile>> {
#[cfg(feature = "eval-support")]
let cache = if let Some(eval_cache) = eval_cache {
use crate::EvalCacheEntryKind;
@ -54,24 +55,44 @@ pub async fn run_retrieval_searches(
if let Some(cached_results) = eval_cache.read(key) {
let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
.context("Failed to deserialize cached search results")?;
let mut results = HashMap::default();
let mut results = Vec::new();
for (path, ranges) in file_results {
let project_path = project.update(cx, |project, cx| {
project.find_project_path(path, cx).unwrap()
})?;
let buffer = project
.update(cx, |project, cx| {
let project_path = project.find_project_path(path, cx).unwrap();
project.open_buffer(project_path, cx)
project.open_buffer(project_path.clone(), cx)
})?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let mut ranges: Vec<_> = ranges
.into_iter()
.map(|range| {
snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
})
.map(
|Range {
start: (start_row, start_col),
end: (end_row, end_col),
}| {
snapshot.anchor_before(Point::new(start_row, start_col))
..snapshot.anchor_after(Point::new(end_row, end_col))
},
)
.collect();
merge_anchor_ranges(&mut ranges, &snapshot);
results.insert(buffer, ranges);
results.push(RelatedFile {
path: project_path,
buffer: buffer.downgrade(),
excerpts: ranges
.into_iter()
.map(|range| RelatedExcerpt {
point_range: range.to_point(&snapshot),
text: snapshot.as_rope().slice(range.to_offset(&snapshot)),
anchor_range: range,
})
.collect(),
max_row: snapshot.max_point().row,
});
}
return Ok(results);
@ -117,14 +138,29 @@ pub async fn run_retrieval_searches(
#[cfg(feature = "eval-support")]
let cache = cache.clone();
cx.background_spawn(async move {
let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
let mut results: Vec<RelatedFile> = Vec::default();
let mut snapshots = HashMap::default();
let mut total_bytes = 0;
'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
snapshots.insert(buffer.entity_id(), snapshot);
let existing = results.entry(buffer).or_default();
existing.reserve(excerpts.len());
'outer: while let Some((project_path, buffer, snapshot, excerpts)) = results_rx.next().await
{
let existing = results
.iter_mut()
.find(|related_file| related_file.buffer.entity_id() == buffer.entity_id());
let existing = match existing {
Some(existing) => existing,
None => {
results.push(RelatedFile {
path: project_path,
buffer: buffer.downgrade(),
excerpts: Vec::new(),
max_row: snapshot.max_point().row,
});
results.last_mut().unwrap()
}
};
// let existing = results.entry(buffer).or_default();
existing.excerpts.reserve(excerpts.len());
for (range, size) in excerpts {
// Blunt trimming of the results until we have a proper algorithmic filtering step
@ -133,24 +169,34 @@ pub async fn run_retrieval_searches(
break 'outer;
}
total_bytes += size;
existing.push(range);
existing.excerpts.push(RelatedExcerpt {
point_range: range.to_point(&snapshot),
text: snapshot.as_rope().slice(range.to_offset(&snapshot)),
anchor_range: range,
});
}
snapshots.insert(buffer.entity_id(), snapshot);
}
#[cfg(feature = "eval-support")]
if let Some((cache, queries, key)) = cache {
let cached_results: CachedSearchResults = results
.iter()
.filter_map(|(buffer, ranges)| {
let snapshot = snapshots.get(&buffer.entity_id())?;
let path = snapshot.file().map(|f| f.path());
let mut ranges = ranges
.map(|related_file| {
let mut ranges = related_file
.excerpts
.iter()
.map(|range| range.to_offset(&snapshot))
.map(
|RelatedExcerpt {
point_range: Range { start, end },
..
}| {
(start.row, start.column)..(end.row, end.column)
},
)
.collect::<Vec<_>>();
ranges.sort_unstable_by_key(|range| (range.start, range.end));
Some((path?.as_std_path().to_path_buf(), ranges))
(related_file.path.path.as_std_path().to_path_buf(), ranges)
})
.collect();
cache.write(
@ -160,10 +206,8 @@ pub async fn run_retrieval_searches(
);
}
for (buffer, ranges) in results.iter_mut() {
if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
merge_anchor_ranges(ranges, snapshot);
}
for related_file in results.iter_mut() {
related_file.merge_excerpts();
}
Ok(results)
@ -171,6 +215,7 @@ pub async fn run_retrieval_searches(
.await
}
#[cfg(feature = "eval-support")]
pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
ranges.sort_unstable_by(|a, b| {
a.start
@ -201,6 +246,7 @@ const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
struct SearchJob {
buffer: Entity<Buffer>,
snapshot: BufferSnapshot,
project_path: ProjectPath,
ranges: Vec<Range<usize>>,
query_ix: usize,
jobs_tx: channel::Sender<SearchJob>,
@ -208,7 +254,12 @@ struct SearchJob {
async fn run_query(
input_query: SearchToolQuery,
results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
results_tx: UnboundedSender<(
ProjectPath,
Entity<Buffer>,
BufferSnapshot,
Vec<(Range<Anchor>, usize)>,
)>,
path_style: PathStyle,
exclude_matcher: PathMatcher,
project: &Entity<Project>,
@ -257,12 +308,21 @@ async fn run_query(
.read_with(cx, |buffer, _| buffer.parsing_idle())?
.await;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let Some(file) = snapshot.file() else {
continue;
};
let project_path = cx.update(|cx| ProjectPath {
worktree_id: file.worktree_id(cx),
path: file.path().clone(),
})?;
let expanded_ranges: Vec<_> = ranges
.into_iter()
.filter_map(|range| expand_to_parent_range(&range, &snapshot))
.collect();
jobs_tx
.send(SearchJob {
project_path,
buffer,
snapshot,
ranges: expanded_ranges,
@ -301,6 +361,13 @@ async fn run_query(
while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let Some(file) = snapshot.file() else {
continue;
};
let project_path = cx.update(|cx| ProjectPath {
worktree_id: file.worktree_id(cx),
path: file.path().clone(),
})?;
let ranges = ranges
.into_iter()
@ -314,7 +381,8 @@ async fn run_query(
})
.collect();
let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
let send_result =
results_tx.unbounded_send((project_path, buffer.clone(), snapshot.clone(), ranges));
if let Err(err) = send_result
&& !err.is_disconnected()
@ -330,7 +398,12 @@ async fn run_query(
}
async fn process_nested_search_job(
results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
results_tx: &UnboundedSender<(
ProjectPath,
Entity<Buffer>,
BufferSnapshot,
Vec<(Range<Anchor>, usize)>,
)>,
queries: &Vec<SearchQuery>,
content_query: &Option<SearchQuery>,
job: SearchJob,
@ -347,6 +420,7 @@ async fn process_nested_search_job(
}
job.jobs_tx
.send(SearchJob {
project_path: job.project_path,
buffer: job.buffer,
snapshot: job.snapshot,
ranges: subranges,
@ -382,7 +456,8 @@ async fn process_nested_search_job(
})
.collect();
let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
let send_result =
results_tx.unbounded_send((job.project_path, job.buffer, job.snapshot, matches));
if let Err(err) = send_result
&& !err.is_disconnected()
@ -413,230 +488,3 @@ fn expand_to_parent_range<T: ToPoint + ToOffset>(
let node = snapshot.syntax_ancestor(line_range)?;
Some(node.byte_range())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::assemble_excerpts::assemble_excerpts;
use cloud_zeta2_prompt::write_codeblock;
use edit_prediction_context::Line;
use gpui::TestAppContext;
use indoc::indoc;
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
use pretty_assertions::assert_eq;
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
use std::path::Path;
use util::path;
#[gpui::test]
async fn test_retrieval(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"user.rs": indoc!{"
pub struct Organization {
owner: Arc<User>,
}
pub struct User {
first_name: String,
last_name: String,
}
impl Organization {
pub fn owner(&self) -> Arc<User> {
self.owner.clone()
}
}
impl User {
pub fn new(first_name: String, last_name: String) -> Self {
Self {
first_name,
last_name
}
}
pub fn first_name(&self) -> String {
self.first_name.clone()
}
pub fn last_name(&self) -> String {
self.last_name.clone()
}
}
"},
"main.rs": indoc!{r#"
fn main() {
let user = User::new(FIRST_NAME.clone(), "doe".into());
println!("user {:?}", user);
}
"#},
}),
)
.await;
let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
project.update(cx, |project, _cx| {
project.languages().add(rust_lang().into())
});
assert_results(
&project,
SearchToolQuery {
glob: "user.rs".into(),
syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
content: None,
},
indoc! {r#"
`````root/user.rs
impl User {
pub fn first_name(&self) -> String {
self.first_name.clone()
}
`````
"#},
cx,
)
.await;
assert_results(
&project,
SearchToolQuery {
glob: "user.rs".into(),
syntax_node: vec!["impl\\s+User".into()],
content: Some("\\.clone".into()),
},
indoc! {r#"
`````root/user.rs
impl User {
pub fn first_name(&self) -> String {
self.first_name.clone()
pub fn last_name(&self) -> String {
self.last_name.clone()
`````
"#},
cx,
)
.await;
assert_results(
&project,
SearchToolQuery {
glob: "*.rs".into(),
syntax_node: vec![],
content: Some("\\.clone".into()),
},
indoc! {r#"
`````root/main.rs
fn main() {
let user = User::new(FIRST_NAME.clone(), "doe".into());
`````
`````root/user.rs
impl Organization {
pub fn owner(&self) -> Arc<User> {
self.owner.clone()
impl User {
pub fn first_name(&self) -> String {
self.first_name.clone()
pub fn last_name(&self) -> String {
self.last_name.clone()
`````
"#},
cx,
)
.await;
}
async fn assert_results(
project: &Entity<Project>,
query: SearchToolQuery,
expected_output: &str,
cx: &mut TestAppContext,
) {
let results = run_retrieval_searches(
vec![query],
project.clone(),
#[cfg(feature = "eval-support")]
None,
&mut cx.to_async(),
)
.await
.unwrap();
let mut results = results.into_iter().collect::<Vec<_>>();
results.sort_by_key(|results| {
results
.0
.read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
});
let mut output = String::new();
for (buffer, ranges) in results {
buffer.read_with(cx, |buffer, cx| {
let excerpts = ranges.into_iter().map(|range| {
let point_range = range.to_point(buffer);
if point_range.end.column > 0 {
Line(point_range.start.row)..Line(point_range.end.row + 1)
} else {
Line(point_range.start.row)..Line(point_range.end.row)
}
});
write_codeblock(
&buffer.file().unwrap().full_path(cx),
assemble_excerpts(&buffer.snapshot(), excerpts).iter(),
&[],
Line(buffer.max_point().row),
false,
&mut output,
);
});
}
output.pop();
assert_eq!(output, expected_output);
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
fn init_test(cx: &mut TestAppContext) {
cx.update(move |cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
zlog::init_test();
});
}
}

View file

@ -1,6 +1,7 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context2::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
@ -49,6 +50,7 @@ impl SweepAi {
position: language::Anchor,
events: Vec<Arc<Event>>,
recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
diagnostic_search_range: Range<Point>,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
@ -120,6 +122,19 @@ impl SweepAi {
})
.collect::<Vec<_>>();
let retrieval_chunks = related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
file_path: related_file.path.path.as_unix_str().to_string(),
start_line: excerpt.point_range.start.row as usize,
end_line: excerpt.point_range.end.row as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@ -168,7 +183,7 @@ impl SweepAi {
multiple_suggestions: false,
branch: None,
file_chunks,
retrieval_chunks: vec![],
retrieval_chunks,
recent_user_actions: vec![],
use_bytes: true,
// TODO
@ -320,7 +335,7 @@ struct AutocompleteRequest {
pub cursor_position: usize,
pub original_file_contents: String,
pub file_chunks: Vec<FileChunk>,
pub retrieval_chunks: Vec<RetrievalChunk>,
pub retrieval_chunks: Vec<FileChunk>,
pub recent_user_actions: Vec<UserAction>,
pub multiple_suggestions: bool,
pub privacy_mode_enabled: bool,
@ -337,15 +352,6 @@ struct FileChunk {
pub timestamp: Option<u64>,
}
#[derive(Debug, Clone, Serialize)]
struct RetrievalChunk {
pub file_path: String,
pub start_line: usize,
pub end_line: usize,
pub content: String,
pub timestamp: u64,
}
#[derive(Debug, Clone, Serialize)]
struct UserAction {
pub action_type: ActionType,

File diff suppressed because it is too large Load diff

View file

@ -15,7 +15,6 @@ path = "src/zeta2_tools.rs"
anyhow.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
edit_prediction_context.workspace = true
editor.workspace = true

View file

@ -8,26 +8,25 @@ use std::{
use anyhow::Result;
use client::{Client, UserStore};
use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
use editor::{Editor, PathKey};
use futures::StreamExt as _;
use gpui::{
Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle,
Focusable, ParentElement as _, SharedString, Styled as _, Task, TextAlign, Window, actions,
pulsating_between,
Focusable, InteractiveElement as _, IntoElement as _, ParentElement as _, SharedString,
Styled as _, Task, TextAlign, Window, actions, div, pulsating_between,
};
use multi_buffer::MultiBuffer;
use project::Project;
use text::OffsetRangeExt;
use ui::{
ButtonCommon, Clickable, Color, Disableable, FluentBuilder as _, Icon, IconButton, IconName,
IconSize, InteractiveElement, IntoElement, ListHeader, ListItem, StyledTypography, div, h_flex,
v_flex,
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
StyledTypography as _, h_flex, v_flex,
};
use workspace::Item;
use zeta::{
Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
ZetaSearchQueryDebugInfo,
Zeta, ZetaContextRetrievalFinishedDebugInfo, ZetaContextRetrievalStartedDebugInfo,
ZetaDebugInfo,
};
pub struct Zeta2ContextView {
@ -42,10 +41,8 @@ pub struct Zeta2ContextView {
#[derive(Debug)]
struct RetrievalRun {
editor: Entity<Editor>,
search_queries: Vec<SearchToolQuery>,
started_at: Instant,
search_results_generated_at: Option<Instant>,
search_results_executed_at: Option<Instant>,
metadata: Vec<(&'static str, SharedString)>,
finished_at: Option<Instant>,
}
@ -97,22 +94,12 @@ impl Zeta2ContextView {
) {
match event {
ZetaDebugInfo::ContextRetrievalStarted(info) => {
if info.project == self.project {
if info.project_entity_id == self.project.entity_id() {
self.handle_context_retrieval_started(info, window, cx);
}
}
ZetaDebugInfo::SearchQueriesGenerated(info) => {
if info.project == self.project {
self.handle_search_queries_generated(info, window, cx);
}
}
ZetaDebugInfo::SearchQueriesExecuted(info) => {
if info.project == self.project {
self.handle_search_queries_executed(info, window, cx);
}
}
ZetaDebugInfo::ContextRetrievalFinished(info) => {
if info.project == self.project {
if info.project_entity_id == self.project.entity_id() {
self.handle_context_retrieval_finished(info, window, cx);
}
}
@ -129,7 +116,7 @@ impl Zeta2ContextView {
if self
.runs
.back()
.is_some_and(|run| run.search_results_executed_at.is_none())
.is_some_and(|run| run.finished_at.is_none())
{
self.runs.pop_back();
}
@ -144,11 +131,9 @@ impl Zeta2ContextView {
self.runs.push_back(RetrievalRun {
editor,
search_queries: Vec::new(),
started_at: info.timestamp,
search_results_generated_at: None,
search_results_executed_at: None,
finished_at: None,
metadata: Vec::new(),
});
cx.notify();
@ -156,7 +141,7 @@ impl Zeta2ContextView {
fn handle_context_retrieval_finished(
&mut self,
info: ZetaContextRetrievalDebugInfo,
info: ZetaContextRetrievalFinishedDebugInfo,
window: &mut Window,
cx: &mut Context<Self>,
) {
@ -165,67 +150,72 @@ impl Zeta2ContextView {
};
run.finished_at = Some(info.timestamp);
run.metadata = info.metadata;
let project = self.project.clone();
let related_files = self
.zeta
.read(cx)
.context_for_project(&self.project, cx)
.to_vec();
let editor = run.editor.clone();
let multibuffer = run.editor.read(cx).buffer().clone();
multibuffer.update(cx, |multibuffer, cx| {
multibuffer.clear(cx);
let context = self.zeta.read(cx).context_for_project(&self.project);
let mut paths = Vec::new();
for (buffer, ranges) in context {
let path = PathKey::for_buffer(&buffer, cx);
let snapshot = buffer.read(cx).snapshot();
let ranges = ranges
.iter()
.map(|range| range.to_point(&snapshot))
.collect::<Vec<_>>();
paths.push((path, buffer, ranges));
}
for (path, buffer, ranges) in paths {
multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx);
}
});
run.editor.update(cx, |editor, cx| {
editor.move_to_beginning(&Default::default(), window, cx);
});
cx.notify();
}
fn handle_search_queries_generated(
&mut self,
info: ZetaSearchQueryDebugInfo,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(run) = self.runs.back_mut() else {
return;
};
run.search_results_generated_at = Some(info.timestamp);
run.search_queries = info.search_queries;
cx.notify();
}
fn handle_search_queries_executed(
&mut self,
info: ZetaContextRetrievalDebugInfo,
_window: &mut Window,
cx: &mut Context<Self>,
) {
if self.current_ix + 2 == self.runs.len() {
// Switch to latest when the queries are executed
self.current_ix += 1;
}
let Some(run) = self.runs.back_mut() else {
return;
};
cx.spawn_in(window, async move |this, cx| {
let mut paths = Vec::new();
for related_file in related_files {
let (buffer, point_ranges): (_, Vec<_>) =
if let Some(buffer) = related_file.buffer.upgrade() {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
run.search_results_executed_at = Some(info.timestamp);
cx.notify();
(
buffer,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
.collect(),
)
} else {
(
project
.update(cx, |project, cx| {
project.open_buffer(related_file.path.clone(), cx)
})?
.await?,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.point_range.clone())
.collect(),
)
};
cx.update(|_, cx| {
let path = PathKey::for_buffer(&buffer, cx);
paths.push((path, buffer, point_ranges));
})?;
}
multibuffer.update(cx, |multibuffer, cx| {
multibuffer.clear(cx);
for (path, buffer, ranges) in paths {
multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx);
}
})?;
editor.update_in(cx, |editor, window, cx| {
editor.move_to_beginning(&Default::default(), window, cx);
})?;
this.update(cx, |_, cx| cx.notify())
})
.detach();
}
fn handle_go_back(
@ -254,8 +244,11 @@ impl Zeta2ContextView {
}
fn render_informational_footer(&self, cx: &mut Context<'_, Zeta2ContextView>) -> ui::Div {
let is_latest = self.runs.len() == self.current_ix + 1;
let run = &self.runs[self.current_ix];
let new_run_started = self
.runs
.back()
.map_or(false, |latest_run| latest_run.finished_at.is_none());
h_flex()
.p_2()
@ -264,114 +257,65 @@ impl Zeta2ContextView {
.text_xs()
.border_t_1()
.gap_2()
.child(v_flex().h_full().flex_1().child({
let t0 = run.started_at;
let mut table = ui::Table::<2>::new().width(ui::px(300.)).no_ui_font();
for (key, value) in &run.metadata {
table = table.row([key.into_any_element(), value.clone().into_any_element()])
}
table = table.row([
"Total Time".into_any_element(),
format!("{} ms", (run.finished_at.unwrap_or(t0) - t0).as_millis())
.into_any_element(),
]);
table
}))
.child(
v_flex().h_full().flex_1().children(
run.search_queries
.iter()
.enumerate()
.flat_map(|(ix, query)| {
std::iter::once(ListHeader::new(query.glob.clone()).into_any_element())
.chain(query.syntax_node.iter().enumerate().map(
move |(regex_ix, regex)| {
ListItem::new(ix * 100 + regex_ix)
.start_slot(
Icon::new(IconName::MagnifyingGlass)
.color(Color::Muted)
.size(IconSize::Small),
)
.child(regex.clone())
.into_any_element()
},
v_flex().h_full().text_align(TextAlign::Right).child(
h_flex()
.justify_end()
.child(
IconButton::new("go-back", IconName::ChevronLeft)
.disabled(self.current_ix == 0 || self.runs.len() < 2)
.tooltip(ui::Tooltip::for_action_title(
"Go to previous run",
&Zeta2ContextGoBack,
))
.chain(query.content.as_ref().map(move |regex| {
ListItem::new(ix * 100 + query.syntax_node.len())
.start_slot(
Icon::new(IconName::MagnifyingGlass)
.color(Color::Muted)
.size(IconSize::Small),
.on_click(cx.listener(|this, _, window, cx| {
this.handle_go_back(&Zeta2ContextGoBack, window, cx);
})),
)
.child(
div()
.child(format!("{}/{}", self.current_ix + 1, self.runs.len()))
.map(|this| {
if new_run_started {
this.with_animation(
"pulsating-count",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 0.8)),
|label, delta| label.opacity(delta),
)
.child(regex.clone())
.into_any_element()
}))
}),
} else {
this.into_any_element()
}
}),
)
.child(
IconButton::new("go-forward", IconName::ChevronRight)
.disabled(self.current_ix + 1 == self.runs.len())
.tooltip(ui::Tooltip::for_action_title(
"Go to next run",
&Zeta2ContextGoBack,
))
.on_click(cx.listener(|this, _, window, cx| {
this.handle_go_forward(&Zeta2ContextGoForward, window, cx);
})),
),
),
)
.child(
v_flex()
.h_full()
.text_align(TextAlign::Right)
.child(
h_flex()
.justify_end()
.child(
IconButton::new("go-back", IconName::ChevronLeft)
.disabled(self.current_ix == 0 || self.runs.len() < 2)
.tooltip(ui::Tooltip::for_action_title(
"Go to previous run",
&Zeta2ContextGoBack,
))
.on_click(cx.listener(|this, _, window, cx| {
this.handle_go_back(&Zeta2ContextGoBack, window, cx);
})),
)
.child(
div()
.child(format!("{}/{}", self.current_ix + 1, self.runs.len()))
.map(|this| {
if self.runs.back().is_some_and(|back| {
back.search_results_executed_at.is_none()
}) {
this.with_animation(
"pulsating-count",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 0.8)),
|label, delta| label.opacity(delta),
)
.into_any_element()
} else {
this.into_any_element()
}
}),
)
.child(
IconButton::new("go-forward", IconName::ChevronRight)
.disabled(self.current_ix + 1 == self.runs.len())
.tooltip(ui::Tooltip::for_action_title(
"Go to next run",
&Zeta2ContextGoBack,
))
.on_click(cx.listener(|this, _, window, cx| {
this.handle_go_forward(&Zeta2ContextGoForward, window, cx);
})),
),
)
.map(|mut div| {
let pending_message = |div: ui::Div, msg: &'static str| {
if is_latest {
return div.child(msg);
} else {
return div.child("Canceled");
}
};
let t0 = run.started_at;
let Some(t1) = run.search_results_generated_at else {
return pending_message(div, "Planning search...");
};
div = div.child(format!("Planned search: {:>5} ms", (t1 - t0).as_millis()));
let Some(t2) = run.search_results_executed_at else {
return pending_message(div, "Running search...");
};
div = div.child(format!("Ran search: {:>5} ms", (t2 - t1).as_millis()));
div.child(format!(
"Total: {:>5} ms",
(run.finished_at.unwrap_or(t0) - t0).as_millis()
))
}),
)
}
}

View file

@ -108,6 +108,7 @@ pub struct Zeta2Inspector {
pub enum ContextModeState {
Llm,
Lsp,
Syntax {
max_retrieved_declarations: Entity<InputField>,
},
@ -222,6 +223,9 @@ impl Zeta2Inspector {
),
};
}
ContextMode::Lsp(_) => {
self.context_mode = ContextModeState::Lsp;
}
}
cx.notify();
}
@ -302,6 +306,9 @@ impl Zeta2Inspector {
ContextModeState::Syntax {
max_retrieved_declarations,
} => number_input_value(max_retrieved_declarations, cx),
ContextModeState::Lsp => {
zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
}
};
ContextMode::Syntax(EditPredictionContextOptions {
@ -310,6 +317,7 @@ impl Zeta2Inspector {
..context_options
})
}
ContextMode::Lsp(excerpt_options) => ContextMode::Lsp(excerpt_options),
};
this.set_zeta_options(
@ -656,6 +664,7 @@ impl Zeta2Inspector {
ContextModeState::Syntax {
max_retrieved_declarations,
} => Some(max_retrieved_declarations.clone()),
ContextModeState::Lsp => None,
})
.child(self.max_prompt_bytes_input.clone())
.child(self.render_prompt_format_dropdown(window, cx)),
@ -679,6 +688,7 @@ impl Zeta2Inspector {
match &self.context_mode {
ContextModeState::Llm => "LLM-based",
ContextModeState::Syntax { .. } => "Syntax",
ContextModeState::Lsp => "LSP-based",
},
ContextMenu::build(window, cx, move |menu, _window, _cx| {
menu.item(
@ -695,6 +705,7 @@ impl Zeta2Inspector {
this.zeta.read(cx).options().clone();
match current_options.context.clone() {
ContextMode::Agentic(_) => {}
ContextMode::Lsp(_) => {}
ContextMode::Syntax(context_options) => {
let options = ZetaOptions {
context: ContextMode::Agentic(
@ -739,6 +750,7 @@ impl Zeta2Inspector {
this.set_zeta_options(options, cx);
}
ContextMode::Syntax(_) => {}
ContextMode::Lsp(_) => {}
}
})
.ok();

View file

@ -21,15 +21,12 @@ use ::util::paths::PathStyle;
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
use edit_prediction_context::{
EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
};
use edit_prediction_context::EditPredictionExcerptOptions;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, Point};
use metrics::delta_chr_f;
use project::{Project, Worktree};
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::io::{self};
use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
@ -97,7 +94,7 @@ struct ContextArgs {
enum ContextProvider {
Zeta1,
#[default]
Syntax,
Zeta2,
}
#[derive(Clone, Debug, Args)]
@ -204,19 +201,12 @@ enum PredictionProvider {
Sweep,
}
fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
fn zeta2_args_to_options(args: &Zeta2Args) -> zeta::ZetaOptions {
zeta::ZetaOptions {
context: ContextMode::Syntax(EditPredictionContextOptions {
max_retrieved_declarations: args.max_retrieved_definitions,
use_imports: !args.disable_imports_gathering,
excerpt: EditPredictionExcerptOptions {
max_bytes: args.max_excerpt_bytes,
min_bytes: args.min_excerpt_bytes,
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
},
score: EditPredictionScoreOptions {
omit_excerpt_overlaps,
},
context: ContextMode::Lsp(EditPredictionExcerptOptions {
max_bytes: args.max_excerpt_bytes,
min_bytes: args.min_excerpt_bytes,
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
}),
max_diagnostic_bytes: args.max_diagnostic_bytes,
max_prompt_bytes: args.max_prompt_bytes,
@ -295,6 +285,7 @@ struct LoadedContext {
worktree: Entity<Worktree>,
project: Entity<Project>,
buffer: Entity<Buffer>,
lsp_open_handle: Option<OpenLspBufferHandle>,
}
async fn load_context(
@ -330,7 +321,7 @@ async fn load_context(
.await?;
let mut ready_languages = HashSet::default();
let (_lsp_open_handle, buffer) = if *use_language_server {
let (lsp_open_handle, buffer) = if *use_language_server {
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
project.clone(),
worktree.clone(),
@ -377,10 +368,11 @@ async fn load_context(
worktree,
project,
buffer,
lsp_open_handle,
})
}
async fn zeta2_syntax_context(
async fn zeta2_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
@ -390,6 +382,7 @@ async fn zeta2_syntax_context(
project,
buffer,
clipped_cursor,
lsp_open_handle: _handle,
..
} = load_context(&args, app_state, cx).await?;
@ -406,30 +399,26 @@ async fn zeta2_syntax_context(
zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
});
let indexing_done_task = zeta.update(cx, |zeta, cx| {
zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
zeta.set_options(zeta2_args_to_options(&args.zeta2_args));
zeta.register_buffer(&buffer, &project, cx);
zeta.wait_for_initial_indexing(&project, cx)
});
cx.spawn(async move |cx| {
indexing_done_task.await?;
let request = zeta
.update(cx, |zeta, cx| {
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
})?
.await?;
let updates_rx = zeta.update(cx, |zeta, cx| {
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
zeta.set_use_context(true);
zeta.refresh_context_if_needed(&project, &buffer, cursor, cx);
zeta.project_context_updates(&project).unwrap()
})?;
let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
updates_rx.recv().await.ok();
match args.zeta2_args.output_format {
OutputFormat::Prompt => anyhow::Ok(prompt_string),
OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
"request": request,
"prompt": prompt_string,
"section_labels": section_labels,
}))?),
}
let context = zeta.update(cx, |zeta, cx| {
zeta.context_for_project(&project, cx).to_vec()
})?;
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
})
})?
.await?;
@ -482,7 +471,6 @@ fn main() {
None => {
if args.printenv {
::util::shell_env::print_env();
return;
} else {
panic!("Expected a command");
}
@ -494,7 +482,7 @@ fn main() {
arguments.extension,
arguments.limit,
arguments.skip,
zeta2_args_to_options(&arguments.zeta2_args, false),
zeta2_args_to_options(&arguments.zeta2_args),
cx,
)
.await;
@ -507,10 +495,8 @@ fn main() {
zeta1_context(context_args, &app_state, cx).await.unwrap();
serde_json::to_string_pretty(&context.body).unwrap()
}
ContextProvider::Syntax => {
zeta2_syntax_context(context_args, &app_state, cx)
.await
.unwrap()
ContextProvider::Zeta2 => {
zeta2_context(context_args, &app_state, cx).await.unwrap()
}
};
println!("{}", result);

View file

@ -136,8 +136,7 @@ pub async fn perform_predict(
let result = result.clone();
async move {
let mut start_time = None;
let mut search_queries_generated_at = None;
let mut search_queries_executed_at = None;
let mut retrieval_finished_at = None;
while let Some(event) = debug_rx.next().await {
match event {
zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
@ -147,17 +146,17 @@ pub async fn perform_predict(
&info.search_prompt,
)?;
}
zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
fs::write(
example_run_dir.join("search_queries.json"),
serde_json::to_string_pretty(&info.search_queries).unwrap(),
)?;
zeta::ZetaDebugInfo::ContextRetrievalFinished(info) => {
retrieval_finished_at = Some(info.timestamp);
for (key, value) in &info.metadata {
if *key == "search_queries" {
fs::write(
example_run_dir.join("search_queries.json"),
value.as_bytes(),
)?;
}
}
}
zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
search_queries_executed_at = Some(info.timestamp);
}
zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
@ -200,13 +199,8 @@ pub async fn perform_predict(
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
result.planning_search_time =
Some(search_queries_generated_at.unwrap() - start_time.unwrap());
result.running_search_time = Some(
search_queries_executed_at.unwrap()
- search_queries_generated_at.unwrap(),
);
result.retrieval_time =
retrieval_finished_at.unwrap() - start_time.unwrap();
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
@ -219,7 +213,12 @@ pub async fn perform_predict(
});
zeta.update(cx, |zeta, cx| {
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
zeta.refresh_context_with_agentic_retrieval(
project.clone(),
cursor_buffer.clone(),
cursor_anchor,
cx,
)
})?
.await?;
}
@ -321,8 +320,7 @@ pub struct PredictionDetails {
pub diff: String,
pub excerpts: Vec<ActualExcerpt>,
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
pub planning_search_time: Option<Duration>,
pub running_search_time: Option<Duration>,
pub retrieval_time: Duration,
pub prediction_time: Duration,
pub total_time: Duration,
pub run_example_dir: PathBuf,
@ -336,8 +334,7 @@ impl PredictionDetails {
diff: Default::default(),
excerpts: Default::default(),
excerpts_text: Default::default(),
planning_search_time: Default::default(),
running_search_time: Default::default(),
retrieval_time: Default::default(),
prediction_time: Default::default(),
total_time: Default::default(),
run_example_dir,
@ -357,28 +354,20 @@ impl PredictionDetails {
}
pub fn to_markdown(&self) -> String {
let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
format!(
"## Excerpts\n\n\
{}\n\n\
## Prediction\n\n\
{}\n\n\
## Time\n\n\
Planning searches: {}ms\n\
Running searches: {}ms\n\
Making Prediction: {}ms\n\n\
-------------------\n\n\
Total: {}ms\n\
Inference: {}ms ({:.2}%)\n",
Retrieval: {}ms\n\
Prediction: {}ms\n\n\
Total: {}ms\n",
self.excerpts_text,
self.diff,
self.planning_search_time.unwrap_or_default().as_millis(),
self.running_search_time.unwrap_or_default().as_millis(),
self.retrieval_time.as_millis(),
self.prediction_time.as_millis(),
self.total_time.as_millis(),
inference_time.as_millis(),
(inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
)
}
}

View file

@ -2,7 +2,8 @@ use anyhow::{Result, anyhow};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AsyncApp, Entity, Task};
use language::{Buffer, LanguageId, LanguageServerId, ParseStatus};
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
use project::lsp_store::OpenLspBufferHandle;
use project::{Project, ProjectPath, Worktree};
use std::collections::HashSet;
use std::sync::Arc;
@ -40,7 +41,7 @@ pub async fn open_buffer_with_language_server(
path: Arc<RelPath>,
ready_languages: &mut HashSet<LanguageId>,
cx: &mut AsyncApp,
) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
@ -50,6 +51,17 @@ pub async fn open_buffer_with_language_server(
)
})?;
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
let result = language_registry
.load_language_for_file_path(path.as_std_path())
.await;
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
anyhow::bail!(error);
}
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
buffer.language().map(|language| language.id())
})?
@ -57,9 +69,9 @@ pub async fn open_buffer_with_language_server(
return Err(anyhow!("No language for {}", path.display(path_style)));
};
let log_prefix = path.display(path_style);
let log_prefix = format!("{} | ", path.display(path_style));
if !ready_languages.contains(&language_id) {
wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
ready_languages.insert(language_id);
}
@ -95,7 +107,7 @@ pub fn wait_for_lang_server(
log_prefix: String,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
println!("{}⏵ Waiting for language server", log_prefix);
eprintln!("{}⏵ Waiting for language server", log_prefix);
let (mut tx, mut rx) = mpsc::channel(1);
@ -137,7 +149,7 @@ pub fn wait_for_lang_server(
..
} = event
{
println!("{}{message}", log_prefix)
eprintln!("{}{message}", log_prefix)
}
}
}),
@ -162,7 +174,7 @@ pub fn wait_for_lang_server(
cx.spawn(async move |cx| {
if !has_lang_server {
// some buffers never have a language server, so this aborts quickly in that case.
let timeout = cx.background_executor().timer(Duration::from_secs(5));
let timeout = cx.background_executor().timer(Duration::from_secs(500));
futures::select! {
_ = added_rx.next() => {},
_ = timeout.fuse() => {
@ -173,7 +185,7 @@ pub fn wait_for_lang_server(
let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
let result = futures::select! {
_ = rx.next() => {
println!("{}⚑ Language server idle", log_prefix);
eprintln!("{}⚑ Language server idle", log_prefix);
anyhow::Ok(())
},
_ = timeout.fuse() => {