diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 95f57a62797..5bc34320a87 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -24,6 +24,7 @@ use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::{mpsc, oneshot}, future::BoxFuture, + stream::BoxStream, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; @@ -1789,6 +1790,34 @@ impl ProtoClient for Client { self.request_dynamic(envelope, request_type).boxed() } + fn request_stream( + &self, + envelope: proto::Envelope, + request_type: &'static str, + ) -> BoxFuture<'static, Result>>> { + let client_id = self.id(); + let response = self.connection_id().map(|connection_id| { + self.peer + .request_stream_dynamic(connection_id, envelope, request_type) + }); + + async move { + log::debug!( + "rpc stream request start. client_id:{}. name:{}", + client_id, + request_type + ); + let response = response?.await; + log::debug!( + "rpc stream request opened. client_id:{}. name:{}", + client_id, + request_type + ); + response + } + .boxed() + } + fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> { log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type); let connection_id = self.connection_id()?; diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 309b6a84a65..5f2cb0515ce 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -1435,10 +1435,43 @@ impl GitRepository for FakeGitRepository { fn search_commits( &self, _log_source: LogSource, - _search_args: SearchCommitArgs, - _request_tx: Sender, + search_args: SearchCommitArgs, + request_tx: Sender, ) -> BoxFuture<'_, Result<()>> { - async { bail!("search_commits not supported for FakeGitRepository") }.boxed() + async move { + let query = if search_args.case_sensitive { + search_args.query.to_string() + } else { + search_args.query.to_lowercase() + }; + + let matching_shas = self.fs.with_git_state(&self.dot_git_path, false, |state| { + state + .commit_data + .iter() + .filter_map(|(sha, entry)| { + let FakeCommitDataEntry::Success(commit_data) = entry else { + return None; + }; + let message = if search_args.case_sensitive { + commit_data.message.to_string() + } else { + commit_data.message.to_lowercase() + }; + message.contains(&query).then_some(*sha) + }) + .collect::>() + })?; + + for sha in matching_shas { + if request_tx.send(sha).await.is_err() { + break; + } + } + + Ok(()) + } + .boxed() } fn commit_data_reader(&self) -> Result { diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index b30b943d032..f5dd3bef95c 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -20,7 +20,7 @@ use collections::HashMap; pub use conflict_set::{ConflictRegion, ConflictSet, ConflictSetSnapshot, ConflictSetUpdate}; use fs::{Fs, RemoveOptions}; use futures::{ - FutureExt, StreamExt, + FutureExt, SinkExt, Stream, StreamExt, channel::{ mpsc, oneshot::{self, Canceled}, @@ -680,6 +680,7 @@ impl GitStore { client.add_entity_request_handler(Self::handle_edit_ref); client.add_entity_request_handler(Self::handle_repair_worktrees); client.add_entity_request_handler(Self::handle_get_commit_data); + client.add_entity_stream_request_handler(Self::handle_search_commits); } pub fn is_local(&self) -> bool { @@ -2669,6 +2670,63 @@ impl GitStore { Ok(proto::GetCommitDataResponse { commits }) } + async fn handle_search_commits( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result>> { + const CHUNK_SIZE: usize = 100; + + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + let log_source = log_source_from_proto( + envelope + .payload + .log_source + .context("missing search commit log source")?, + )?; + let search_args = SearchCommitArgs { + query: SharedString::from(envelope.payload.query), + case_sensitive: envelope.payload.case_sensitive, + }; + + let (request_tx, request_rx) = async_channel::unbounded(); + repository_handle.update(&mut cx, |repository, cx| { + repository.search_commits(log_source, search_args, request_tx, cx); + }); + + let (mut response_tx, response_rx) = mpsc::unbounded(); + cx.background_spawn(async move { + let mut shas = Vec::new(); + + while let Ok(sha) = request_rx.recv().await { + shas.push(sha.to_string()); + + if shas.len() >= CHUNK_SIZE { + if response_tx + .send(Ok(proto::SearchCommitsResponse { + shas: mem::take(&mut shas), + })) + .await + .is_err() + { + return; + } + } + } + + if !shas.is_empty() { + response_tx + .send(Ok(proto::SearchCommitsResponse { shas })) + .await + .ok(); + } + }) + .detach(); + + Ok(response_rx) + } + async fn handle_edit_ref( this: Entity, envelope: TypedEnvelope, @@ -4974,6 +5032,7 @@ impl Repository { cx: &mut Context, ) { let repository_state = self.repository_state.clone(); + let repository_id = self.id; cx.background_spawn(async move { let repo_state = repository_state.await; @@ -4985,8 +5044,50 @@ impl Repository { .await .log_err(); } - Ok(RepositoryState::Remote(_)) => {} - Err(_) => {} + + Ok(RepositoryState::Remote(RemoteRepositoryState { client, project_id })) => { + let result = client + .request_stream(proto::SearchCommits { + project_id: project_id.to_proto(), + repository_id: repository_id.to_proto(), + log_source: Some(log_source_to_proto(&log_source)), + query: search_args.query.to_string(), + case_sensitive: search_args.case_sensitive, + }) + .await; + + let mut stream = match result { + Ok(stream) => stream, + Err(error) => { + log::error!("failed to search commits remotely: {error:?}"); + return; + } + }; + + while let Some(response) = stream.next().await { + let response = match response { + Ok(response) => response, + Err(error) => { + log::error!( + "failed to receive remote commit search results: {error:?}" + ); + return; + } + }; + + for sha in &response.shas { + let Ok(oid) = Oid::from_str(sha) else { + return; + }; + if request_tx.send(oid).await.is_err() { + return; + } + } + } + } + Err(error) => { + log::error!("failed to get repository state for commit search: {error}"); + } }; }) .detach(); @@ -8119,6 +8220,31 @@ fn deserialize_blame_buffer_response( Some(Blame { entries, messages }) } +fn log_source_to_proto(log_source: &LogSource) -> proto::GitLogSource { + proto::GitLogSource { + source: Some(match log_source { + LogSource::All => proto::git_log_source::Source::All(proto::GitLogSourceAll {}), + LogSource::Branch(branch) => proto::git_log_source::Source::Branch(branch.to_string()), + LogSource::Sha(sha) => proto::git_log_source::Source::Sha(sha.to_string()), + LogSource::Path(path) => proto::git_log_source::Source::Path(path.to_proto()), + }), + } +} + +fn log_source_from_proto(log_source: proto::GitLogSource) -> Result { + match log_source + .source + .context("git log source is missing source")? + { + proto::git_log_source::Source::All(_) => Ok(LogSource::All), + proto::git_log_source::Source::Branch(branch) => Ok(LogSource::Branch(branch.into())), + proto::git_log_source::Source::Sha(sha) => Ok(LogSource::Sha(Oid::from_str(&sha)?)), + proto::git_log_source::Source::Path(path) => { + Ok(LogSource::Path(RepoPath::from_proto(&path)?)) + } + } +} + fn commit_data_to_proto(commit: &CommitData) -> proto::CommitData { proto::CommitData { sha: commit.sha.to_string(), diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index cea288ea2a0..afea6cf34a3 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -693,3 +693,26 @@ message CommitData { message GetCommitDataResponse { repeated CommitData commits = 1; } + +message GitLogSourceAll {} + +message GitLogSource { + oneof source { + GitLogSourceAll all = 1; + string branch = 2; + string sha = 3; + string path = 4; + } +} + +message SearchCommits { + uint64 project_id = 1; + uint64 repository_id = 2; + GitLogSource log_source = 3; + string query = 4; + bool case_sensitive = 5; +} + +message SearchCommitsResponse { + repeated string shas = 1; +} diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 28626e687a8..0c149fb2976 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -480,7 +480,9 @@ message Envelope { GitCreateArchiveCheckpointResponse git_create_archive_checkpoint_response = 445; GitRestoreArchiveCheckpoint git_restore_archive_checkpoint = 446; GetCommitData get_commit_data = 447; - GetCommitDataResponse get_commit_data_response = 448; // current max + GetCommitDataResponse get_commit_data_response = 448; + SearchCommits search_commits = 449; + SearchCommitsResponse search_commits_response = 450; // current max } reserved 87 to 88; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 06a4b2b5cc0..651e11354a9 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -358,6 +358,8 @@ messages!( (GitRepairWorktrees, Background), (GetCommitData, Background), (GetCommitDataResponse, Background), + (SearchCommits, Background), + (SearchCommitsResponse, Background), (GitWorktreesResponse, Background), (GitCreateWorktree, Background), (GitRemoveWorktree, Background), @@ -573,6 +575,7 @@ request_messages!( (GitEditRef, Ack), (GitRepairWorktrees, Ack), (GetCommitData, GetCommitDataResponse), + (SearchCommits, SearchCommitsResponse), (GitCreateWorktree, Ack), (GitRemoveWorktree, Ack), (GitRenameWorktree, Ack), @@ -767,6 +770,7 @@ entity_messages!( GitEditRef, GitRepairWorktrees, GetCommitData, + SearchCommits, GitCreateArchiveCheckpoint, GitRestoreArchiveCheckpoint, GitCreateWorktree, diff --git a/crates/remote/src/remote_client.rs b/crates/remote/src/remote_client.rs index a32d5dc75c7..138238c5fd4 100644 --- a/crates/remote/src/remote_client.rs +++ b/crates/remote/src/remote_client.rs @@ -22,6 +22,7 @@ use futures::{ }, future::{BoxFuture, Shared, WeakShared}, select, select_biased, + stream::BoxStream, }; use gpui::{ App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity, @@ -1320,6 +1321,8 @@ impl RemoteConnectionOptions { #[cfg(test)] mod tests { use super::*; + use gpui::TestAppContext; + use rpc::{ErrorCodeExt, proto::ErrorCode}; #[test] fn test_ssh_display_name_prefers_nickname() { @@ -1341,6 +1344,137 @@ mod tests { assert_eq!(options.display_name(), "1.2.3.4"); } + + #[gpui::test] + async fn test_channel_client_request_stream_terminates_on_error(cx: &mut TestAppContext) { + let (incoming_tx, incoming_rx) = mpsc::unbounded::(); + let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::(); + + let client = + cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "test-client", false)); + + // The client sends RemoteStarted on startup; drain the outgoing channel + // so it doesn't block. + let _drain_outgoing = cx + .executor() + .spawn(async move { while outgoing_rx.next().await.is_some() {} }); + + let mut stream = client + .request_stream_dynamic(proto::Test { id: 0 }.into_envelope(0, None, None), "Test") + .await + .unwrap(); + + let request_id = 0; + + incoming_tx + .unbounded_send(proto::Test { id: 1 }.into_envelope(100, Some(request_id), None)) + .unwrap(); + + let first = stream.next().await.unwrap().unwrap(); + assert_eq!( + proto::Test::from_envelope(first).unwrap(), + proto::Test { id: 1 } + ); + + // Send an Error without a trailing EndStream. The Error alone should + // terminate the stream. + incoming_tx + .unbounded_send( + ErrorCode::Internal + .message("boom".to_string()) + .to_proto() + .into_envelope(101, Some(request_id), None), + ) + .unwrap(); + + let second = stream.next().await.unwrap(); + let error = second.unwrap_err(); + assert!( + format!("{error}").contains("boom"), + "expected error to surface server message, got: {error}" + ); + + assert!(stream.next().await.is_none()); + assert_eq!(client.stream_response_channels.lock().len(), 0); + } + + #[gpui::test] + async fn test_channel_client_dropping_stream_request_before_response_cleans_up_channel( + cx: &mut TestAppContext, + ) { + let (_incoming_tx, incoming_rx) = mpsc::unbounded::(); + let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::(); + + let client = + cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "test-client", false)); + + let _drain_outgoing = cx + .executor() + .spawn(async move { while outgoing_rx.next().await.is_some() {} }); + + let stream = client + .request_stream_dynamic(proto::Test { id: 0 }.into_envelope(0, None, None), "Test") + .await + .unwrap(); + + assert_eq!(client.stream_response_channels.lock().len(), 1); + + drop(stream); + cx.run_until_parked(); + + assert_eq!( + client.stream_response_channels.lock().len(), + 0, + "dropping a stream before any responses arrive should remove response channel bookkeeping" + ); + } + + #[gpui::test] + async fn test_channel_client_dropping_stream_request_before_completion( + cx: &mut TestAppContext, + ) { + let (incoming_tx, incoming_rx) = mpsc::unbounded::(); + let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::(); + + let client = + cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "test-client", false)); + + let _drain_outgoing = cx + .executor() + .spawn(async move { while outgoing_rx.next().await.is_some() {} }); + + let mut stream = client + .request_stream_dynamic(proto::Test { id: 0 }.into_envelope(0, None, None), "Test") + .await + .unwrap(); + + let request_id = 0; + + incoming_tx + .unbounded_send(proto::Test { id: 1 }.into_envelope(100, Some(request_id), None)) + .unwrap(); + let _ = stream.next().await.unwrap().unwrap(); + + assert_eq!(client.stream_response_channels.lock().len(), 1); + + drop(stream); + + // Inject an orphaned non-terminal response. The read loop should detect + // that the consumer has been dropped and clean up its bookkeeping (no + // EndStream sent here on purpose, otherwise the cleanup would happen + // via the terminal-response path and mask the bug under test). + incoming_tx + .unbounded_send(proto::Test { id: 2 }.into_envelope(101, Some(request_id), None)) + .unwrap(); + + cx.run_until_parked(); + + assert_eq!( + client.stream_response_channels.lock().len(), + 0, + "stream channel should be removed once the consumer has dropped the stream" + ); + } } impl From for RemoteConnectionOptions { @@ -1418,6 +1552,8 @@ pub trait RemoteConnection: Send + Sync { } type ResponseChannels = Mutex)>>>; +type StreamResponseChannels = + Arc, oneshot::Sender<()>)>>>>; struct Signal { tx: Mutex>>, @@ -1455,6 +1591,7 @@ pub(crate) struct ChannelClient { outgoing_tx: Mutex>, buffer: Mutex>, response_channels: ResponseChannels, + stream_response_channels: StreamResponseChannels, message_handlers: Mutex, max_received: AtomicU32, name: &'static str, @@ -1477,6 +1614,7 @@ impl ChannelClient { next_message_id: AtomicU32::new(0), max_received: AtomicU32::new(0), response_channels: ResponseChannels::default(), + stream_response_channels: StreamResponseChannels::default(), message_handlers: Default::default(), buffer: Mutex::new(VecDeque::new()), name, @@ -1550,13 +1688,40 @@ impl ChannelClient { if let Some(request_id) = incoming.responding_to { let request_id = MessageId(request_id); + // An incoming response with no payload is malformed; drop + // it. The request future and any stream consumers will + // remain pending until either a real response arrives or + // the connection is torn down. + if incoming.payload.is_none() { + continue; + } let sender = this.response_channels.lock().remove(&request_id); if let Some(sender) = sender { let (tx, rx) = oneshot::channel(); - if incoming.payload.is_some() { - sender.send((incoming, tx)).ok(); - } + sender.send((incoming, tx)).ok(); rx.await.ok(); + } else { + let terminal_stream_response = matches!( + &incoming.payload, + Some(proto::envelope::Payload::Error(_)) + | Some(proto::envelope::Payload::EndStream(_)) + ); + let sender = if terminal_stream_response { + this.stream_response_channels.lock().remove(&request_id) + } else { + this.stream_response_channels + .lock() + .get(&request_id) + .cloned() + }; + if let Some(sender) = sender { + let (tx, rx) = oneshot::channel(); + if sender.unbounded_send((Ok(incoming), tx)).is_err() { + this.stream_response_channels.lock().remove(&request_id); + continue; + } + rx.await.ok(); + } } } else if let Some(envelope) = build_typed_envelope(peer_id, Instant::now(), incoming) @@ -1721,6 +1886,55 @@ impl ChannelClient { } } + fn request_stream_dynamic( + &self, + mut envelope: proto::Envelope, + type_name: &'static str, + ) -> impl 'static + Future>>> { + envelope.id = self.next_message_id.fetch_add(1, SeqCst); + let message_id = MessageId(envelope.id); + let (tx, rx) = mpsc::unbounded(); + let stream_response_channels = self.stream_response_channels.clone(); + stream_response_channels.lock().insert(message_id, tx); + + let result = self.send_buffered(envelope); + async move { + if let Err(error) = &result { + log::error!("failed to send message: {error}"); + anyhow::bail!("failed to send message: {error}"); + } + + let cleanup_stream_response_channel = util::defer({ + let stream_response_channels = stream_response_channels.clone(); + move || { + stream_response_channels.lock().remove(&message_id); + } + }); + + Ok(rx + .filter_map(move |(response, _barrier)| { + // Keep the cleanup guard alive until the returned stream is dropped. + let _keep_cleanup_guard_alive = &cleanup_stream_response_channel; + futures::future::ready(match response { + Ok(response) => { + if let Some(proto::envelope::Payload::Error(error)) = &response.payload + { + Some(Err(RpcError::from_proto(error, type_name))) + } else if let Some(proto::envelope::Payload::EndStream(_)) = + &response.payload + { + None + } else { + Some(Ok(response)) + } + } + Err(error) => Some(Err(error)), + }) + }) + .boxed()) + } + } + pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> { envelope.id = self.next_message_id.fetch_add(1, SeqCst); self.send_buffered(envelope) @@ -1751,6 +1965,14 @@ impl ProtoClient for ChannelClient { self.request_dynamic(envelope, request_type, true).boxed() } + fn request_stream( + &self, + envelope: proto::Envelope, + request_type: &'static str, + ) -> BoxFuture<'static, Result>>> { + self.request_stream_dynamic(envelope, request_type).boxed() + } + fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> { self.send_dynamic(envelope) } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 825c0ba26c0..d31403275cb 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -11,7 +11,10 @@ use languages::rust_lang; use extension::ExtensionHostProxy; use fs::{FakeFs, Fs}; -use git::repository::Worktree as GitWorktree; +use git::{ + Oid, + repository::{CommitData, Worktree as GitWorktree}, +}; use gpui::{AppContext as _, Entity, SharedString, TestAppContext}; use http_client::{BlockedHttpClient, FakeHttpClient}; use language::{ @@ -29,11 +32,13 @@ use project::{ search::{SearchQuery, SearchResult}, }; use remote::RemoteClient; +use rpc::proto; use serde_json::json; use settings::{Settings, SettingsLocation, SettingsStore, initial_server_settings_content}; use smol::stream::StreamExt; use std::{ path::{Path, PathBuf}, + str::FromStr, sync::Arc, }; use unindent::Unindent as _; @@ -1626,6 +1631,108 @@ async fn test_remote_root_repo_common_dir(cx: &mut TestAppContext, server_cx: &m assert_eq!(common_dir, None); } +#[gpui::test] +async fn test_remote_search_commits_streams_proto_chunks( + cx: &mut TestAppContext, + server_cx: &mut TestAppContext, +) { + const COMMIT_COUNT: usize = 900; + const RESPONSE_MAX_SIZE: usize = 100; + + let fs = FakeFs::new(server_cx.executor()); + fs.insert_tree( + path!("/code"), + json!({ + "project1": { + ".git": {}, + "file.txt": "content", + }, + }), + ) + .await; + + let commit_data = (0..COMMIT_COUNT) + .map(|index| { + let sha = Oid::from_str(&format!("{:040x}", index + 1)).unwrap(); + ( + CommitData { + sha, + parents: Default::default(), + author_name: SharedString::from("Author"), + author_email: SharedString::from("author@example.com"), + commit_timestamp: index as i64, + subject: SharedString::from(format!("Subject {index}")), + message: SharedString::from(format!("needle commit {index}")), + }, + false, + ) + }) + .collect::>(); + let expected_shas = commit_data + .iter() + .map(|(commit_data, _)| commit_data.sha.to_string()) + .collect::>(); + fs.set_commit_data(Path::new(path!("/code/project1/.git")), commit_data); + + let (project, _headless) = init_test(&fs, cx, server_cx).await; + project + .update(cx, |project, cx| { + project.find_or_create_worktree(path!("/code/project1"), true, cx) + }) + .await + .expect("should open remote worktree"); + server_cx.run_until_parked(); + cx.run_until_parked(); + project + .update(cx, |project, cx| project.git_scans_complete(cx)) + .await; + + let (remote_client, repository_id) = project.read_with(cx, |project, cx| { + let repository = project + .active_repository(cx) + .expect("remote project should have an active repository"); + let repository_id = repository.read(cx).snapshot().id; + let remote_client = project + .remote_client() + .expect("project should have a remote client"); + (remote_client, repository_id) + }); + let proto_client = remote_client.read_with(cx, |remote_client, _| remote_client.proto_client()); + let mut stream = proto_client + .request_stream(proto::SearchCommits { + project_id: proto::REMOTE_SERVER_PROJECT_ID, + repository_id: repository_id.to_proto(), + log_source: Some(proto::GitLogSource { + source: Some(proto::git_log_source::Source::All( + proto::GitLogSourceAll {}, + )), + }), + query: "needle".to_string(), + case_sensitive: true, + }) + .await + .expect("search commits stream should start"); + + let mut chunks = Vec::new(); + while let Some(response) = futures::StreamExt::next(&mut stream).await { + chunks.push(response.expect("search commits chunk should succeed").shas); + } + + assert!( + chunks.len() > 1, + "expected search results to stream in multiple chunks" + ); + for chunk in chunks.iter().take(chunks.len() - 1) { + assert!( + chunk.len() <= RESPONSE_MAX_SIZE, + "non-final chunks should meet the target byte size" + ); + } + + let actual_shas = chunks.into_iter().flatten().collect::>(); + assert_eq!(actual_shas, expected_shas); +} + #[gpui::test] async fn test_remote_archive_git_operations_are_supported( cx: &mut TestAppContext, diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 73be0f19fe2..d9f34d0dc59 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -8,7 +8,7 @@ use super::{ use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; use futures::{ - FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, + FutureExt, SinkExt, StreamExt, TryFutureExt, channel::{mpsc, oneshot}, stream::BoxStream, }; @@ -278,11 +278,23 @@ impl Peer { ); let response_channel = response_channels.lock().as_mut()?.remove(&responding_to); - let stream_response_channel = stream_response_channels - .lock() - .as_ref()? - .get(&responding_to) - .cloned(); + let terminal_stream_response = matches!( + &incoming.payload, + Some(proto::envelope::Payload::Error(_)) + | Some(proto::envelope::Payload::EndStream(_)) + ); + let stream_response_channel = if terminal_stream_response { + stream_response_channels + .lock() + .as_mut()? + .remove(&responding_to) + } else { + stream_response_channels + .lock() + .as_ref()? + .get(&responding_to) + .cloned() + }; if let Some(tx) = response_channel { let requester_resumed = oneshot::channel(); @@ -319,21 +331,15 @@ impl Peer { ?error, "incoming stream response: request future dropped", ); + // The consumer has gone away, so drop the bookkeeping + // for this stream rather than letting it accumulate + // every subsequent message until a terminal frame. + if let Some(channels) = stream_response_channels.lock().as_mut() { + channels.remove(&responding_to); + } + } else { + let _ = requester_resumed.1.await; } - - tracing::debug!( - %connection_id, - message_id, - responding_to, - "incoming stream response: waiting to resume requester" - ); - let _ = requester_resumed.1.await; - tracing::debug!( - %connection_id, - message_id, - responding_to, - "incoming stream response: requester resumed" - ); } else { let message_type = proto::build_typed_envelope( connection_id.into(), @@ -484,55 +490,96 @@ impl Peer { &self, receiver_id: ConnectionId, request: T, - ) -> impl Future>>> { + ) -> impl Future>>> { + let stream = + self.request_stream_dynamic(receiver_id, request.into_envelope(0, None, None), T::NAME); + + async move { + Ok(stream + .await? + .map(|response| { + T::Response::from_envelope(response?) + .context("received response of the wrong type") + }) + .boxed()) + } + } + + pub fn request_stream_dynamic( + &self, + receiver_id: ConnectionId, + mut envelope: proto::Envelope, + request_type: &'static str, + ) -> impl Future>>> + use<> { let (tx, rx) = mpsc::unbounded(); let send = self.connection_state(receiver_id).and_then(|connection| { let message_id = connection.next_message_id.fetch_add(1, SeqCst); + envelope.id = message_id; let stream_response_channels = connection.stream_response_channels.clone(); stream_response_channels .lock() .as_mut() .context("connection was closed")? .insert(message_id, tx); - connection + if let Err(error) = connection .outgoing_tx - .unbounded_send(Message::Envelope( - request.into_envelope(message_id, None, None), - )) - .context("connection was closed")?; + .unbounded_send(Message::Envelope(envelope)) + { + if let Some(channels) = stream_response_channels.lock().as_mut() { + channels.remove(&message_id); + } + return Err(error).context("connection was closed"); + } Ok((message_id, stream_response_channels)) }); async move { let (message_id, stream_response_channels) = send?; let stream_response_channels = Arc::downgrade(&stream_response_channels); - - Ok(rx.filter_map(move |(response, _barrier)| { + let cleanup_stream_response_channel = util::defer({ let stream_response_channels = stream_response_channels.clone(); - future::ready(match response { - Ok(response) => { - if let Some(proto::envelope::Payload::Error(error)) = &response.payload { - Some(Err(RpcError::from_proto(error, T::NAME))) - } else if let Some(proto::envelope::Payload::EndStream(_)) = - &response.payload - { - // Remove the transmitting end of the response channel to end the stream. - if let Some(channels) = stream_response_channels.upgrade() - && let Some(channels) = channels.lock().as_mut() - { - channels.remove(&message_id); - } - None - } else { - Some( - T::Response::from_envelope(response) - .context("received response of the wrong type"), - ) - } + move || { + if let Some(channels) = stream_response_channels.upgrade() + && let Some(channels) = channels.lock().as_mut() + { + channels.remove(&message_id); } - Err(error) => Some(Err(error)), + } + }); + + Ok(rx + .filter_map(move |(response, _barrier)| { + let _keep_cleanup_guard_alive = &cleanup_stream_response_channel; + let stream_response_channels = stream_response_channels.clone(); + future::ready(match response { + Ok(response) => { + if let Some(proto::envelope::Payload::Error(error)) = &response.payload + { + // Remove the transmitting end of the response channel to end the stream. + if let Some(channels) = stream_response_channels.upgrade() + && let Some(channels) = channels.lock().as_mut() + { + channels.remove(&message_id); + } + Some(Err(RpcError::from_proto(error, request_type))) + } else if let Some(proto::envelope::Payload::EndStream(_)) = + &response.payload + { + // Remove the transmitting end of the response channel to end the stream. + if let Some(channels) = stream_response_channels.upgrade() + && let Some(channels) = channels.lock().as_mut() + { + channels.remove(&message_id); + } + None + } else { + Some(Ok(response)) + } + } + Err(error) => Some(Err(error)), + }) }) - })) + .boxed()) } } @@ -661,6 +708,13 @@ impl Peer { .with_context(|| format!("no such connection: {connection_id}"))?; Ok(connection.clone()) } + + #[cfg(any(test, feature = "test-support"))] + pub fn pending_stream_request_count(&self, connection_id: ConnectionId) -> Option { + let connection = self.connection_state(connection_id).ok()?; + let channels = connection.stream_response_channels.lock(); + Some(channels.as_ref()?.len()) + } } impl Serialize for Peer { @@ -992,6 +1046,268 @@ mod tests { ); } + #[gpui::test(iterations = 50)] + async fn test_request_stream(cx: &mut TestAppContext) { + init_logger(); + + let executor = cx.executor(); + let server = Peer::new(0); + let client = Peer::new(0); + + let (client_to_server_conn, server_to_client_conn, _kill) = + Connection::in_memory(executor.clone()); + let (client_to_server_conn_id, io_task1, mut client_incoming) = + client.add_test_connection(client_to_server_conn, executor.clone()); + let (_, io_task2, mut server_incoming) = + server.add_test_connection(server_to_client_conn, executor.clone()); + + executor.spawn(io_task1).detach(); + executor.spawn(io_task2).detach(); + executor + .spawn(async move { while client_incoming.next().await.is_some() {} }) + .detach(); + + executor + .spawn({ + let server = server.clone(); + async move { + let request = server_incoming + .next() + .await + .unwrap() + .into_any() + .downcast::>() + .unwrap(); + let receipt = request.receipt(); + server.respond(receipt, proto::Test { id: 1 }).unwrap(); + server.respond(receipt, proto::Test { id: 2 }).unwrap(); + server.respond(receipt, proto::Test { id: 3 }).unwrap(); + server.end_stream(receipt).unwrap(); + + // Prevent the connection from being dropped. + server_incoming.next().await; + } + }) + .detach(); + + let mut stream = client + .request_stream(client_to_server_conn_id, proto::Test { id: 0 }) + .await + .unwrap(); + + let mut received = Vec::new(); + while let Some(item) = stream.next().await { + received.push(item.unwrap()); + } + + assert_eq!( + received, + vec![ + proto::Test { id: 1 }, + proto::Test { id: 2 }, + proto::Test { id: 3 }, + ] + ); + assert_eq!( + client.pending_stream_request_count(client_to_server_conn_id), + Some(0) + ); + } + + #[gpui::test] + async fn test_request_stream_send_failure_cleans_up_response_channel(cx: &mut TestAppContext) { + init_logger(); + + let executor = cx.executor(); + let client = Peer::new(0); + + let (client_to_server_conn, _server_to_client_conn, _kill) = + Connection::in_memory(executor.clone()); + let (client_to_server_conn_id, io_task, _client_incoming) = + client.add_test_connection(client_to_server_conn, executor.clone()); + + drop(io_task); + + let result = client + .request_stream(client_to_server_conn_id, proto::Test { id: 0 }) + .await; + + assert!( + result.is_err(), + "stream request should fail when the connection write task has gone away" + ); + assert_eq!( + client.pending_stream_request_count(client_to_server_conn_id), + Some(0), + "failed stream request should not leave response channel bookkeeping behind" + ); + } + + #[gpui::test(iterations = 50)] + async fn test_request_stream_terminates_on_error(cx: &mut TestAppContext) { + init_logger(); + + let executor = cx.executor(); + let server = Peer::new(0); + let client = Peer::new(0); + + let (client_to_server_conn, server_to_client_conn, _kill) = + Connection::in_memory(executor.clone()); + let (client_to_server_conn_id, io_task1, mut client_incoming) = + client.add_test_connection(client_to_server_conn, executor.clone()); + let (_, io_task2, mut server_incoming) = + server.add_test_connection(server_to_client_conn, executor.clone()); + + executor.spawn(io_task1).detach(); + executor.spawn(io_task2).detach(); + executor + .spawn(async move { while client_incoming.next().await.is_some() {} }) + .detach(); + + executor + .spawn({ + let server = server.clone(); + async move { + let request = server_incoming + .next() + .await + .unwrap() + .into_any() + .downcast::>() + .unwrap(); + let receipt = request.receipt(); + server.respond(receipt, proto::Test { id: 1 }).unwrap(); + // Send an Error without a trailing EndStream. The Error alone + // should be treated as a terminal stream response. + server + .respond_with_error( + receipt, + ErrorCode::Internal.message("boom".to_string()).to_proto(), + ) + .unwrap(); + + // Prevent the connection from being dropped. + server_incoming.next().await; + } + }) + .detach(); + + let mut stream = client + .request_stream(client_to_server_conn_id, proto::Test { id: 0 }) + .await + .unwrap(); + + assert_eq!(stream.next().await.unwrap().unwrap(), proto::Test { id: 1 }); + + let error = stream.next().await.unwrap().unwrap_err(); + assert!( + format!("{error}").contains("boom"), + "expected error to surface server message, got: {error}" + ); + + // The error alone (without an EndStream) should terminate the stream. + assert!(stream.next().await.is_none()); + assert_eq!( + client.pending_stream_request_count(client_to_server_conn_id), + Some(0) + ); + } + + #[gpui::test(iterations = 50)] + async fn test_dropping_stream_request_before_completion(cx: &mut TestAppContext) { + init_logger(); + + let executor = cx.executor(); + let server = Peer::new(0); + let client = Peer::new(0); + + let (client_to_server_conn, server_to_client_conn, _kill) = + Connection::in_memory(executor.clone()); + let (client_to_server_conn_id, io_task1, mut client_incoming) = + client.add_test_connection(client_to_server_conn, executor.clone()); + let (_, io_task2, mut server_incoming) = + server.add_test_connection(server_to_client_conn, executor.clone()); + + executor.spawn(io_task1).detach(); + executor.spawn(io_task2).detach(); + executor + .spawn(async move { while client_incoming.next().await.is_some() {} }) + .detach(); + + let (drop_signal_tx, drop_signal_rx) = oneshot::channel::<()>(); + let server_task = executor.spawn({ + let server = server.clone(); + async move { + let request = server_incoming + .next() + .await + .unwrap() + .into_any() + .downcast::>() + .unwrap(); + let receipt = request.receipt(); + server.respond(receipt, proto::Test { id: 1 }).unwrap(); + + // Wait until the consumer has dropped the stream. + drop_signal_rx.await.ok(); + + // Send a non-terminal response after the consumer is gone. The + // peer should detect that the receiver has been dropped and clean + // up its bookkeeping. Crucially, we do NOT send EndStream here + // because that would clean up via the terminal-response path and + // mask the bug. + server.respond(receipt, proto::Test { id: 2 }).unwrap(); + + // A Ping/Ack round-trip after the response acts as a sync + // barrier: because messages over the in-memory connection are + // delivered in order, by the time the client observes the Ack, + // it has already processed the dropped response above. + let ping = server_incoming + .next() + .await + .unwrap() + .into_any() + .downcast::>() + .unwrap(); + server.respond(ping.receipt(), proto::Ack {}).unwrap(); + + // Prevent the connection from being dropped. + server_incoming.next().await; + } + }); + + let mut stream = client + .request_stream(client_to_server_conn_id, proto::Test { id: 0 }) + .await + .unwrap(); + + assert_eq!(stream.next().await.unwrap().unwrap(), proto::Test { id: 1 }); + + // The stream is mid-flight, so the channel should be tracked. + assert_eq!( + client.pending_stream_request_count(client_to_server_conn_id), + Some(1) + ); + + drop(stream); + drop_signal_tx.send(()).ok(); + + // Synchronization barrier: once this Ack arrives, the read loop has + // already processed the orphaned stream response that came before it. + client + .request(client_to_server_conn_id, proto::Ping {}) + .await + .unwrap(); + + assert_eq!( + client.pending_stream_request_count(client_to_server_conn_id), + Some(0), + "stream channel should be removed once the consumer has dropped the stream" + ); + + drop(server_task); + } + #[gpui::test(iterations = 50)] async fn test_disconnect(cx: &mut TestAppContext) { let executor = cx.executor(); diff --git a/crates/rpc/src/proto_client.rs b/crates/rpc/src/proto_client.rs index ba8b8782725..cb45948d5cd 100644 --- a/crates/rpc/src/proto_client.rs +++ b/crates/rpc/src/proto_client.rs @@ -1,9 +1,10 @@ use anyhow::{Context, Result}; use collections::HashMap; use futures::{ - Future, FutureExt as _, + Future, FutureExt as _, Stream, StreamExt as _, channel::oneshot, future::{BoxFuture, LocalBoxFuture}, + stream::BoxStream, }; use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _}; use parking_lot::Mutex; @@ -61,6 +62,20 @@ pub trait ProtoClient: Send + Sync { request_type: &'static str, ) -> BoxFuture<'static, Result>; + fn request_stream( + &self, + envelope: Envelope, + request_type: &'static str, + ) -> BoxFuture<'static, Result>>> { + async move { + anyhow::bail!( + "stream requests are not supported for {request_type}: {:?}", + envelope.payload + ) + } + .boxed() + } + fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>; fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>; @@ -223,6 +238,23 @@ impl AnyProtoClient { } } + pub fn request_stream( + &self, + request: T, + ) -> impl Future>>> + use { + let envelope = request.into_envelope(0, None, None); + let response_stream = self.0.client.request_stream(envelope, T::NAME); + async move { + Ok(response_stream + .await? + .map(|response| { + T::Response::from_envelope(response?) + .context("received response of the wrong type") + }) + .boxed()) + } + } + pub fn send(&self, request: T) -> Result<()> { let envelope = request.into_envelope(0, None, None); self.0.client.send(envelope, T::NAME) @@ -479,6 +511,68 @@ impl AnyProtoClient { ); } + pub fn add_entity_stream_request_handler(&self, handler: H) + where + M: EnvelopedMessage + RequestMessage + EntityMessage, + E: 'static, + H: 'static + Sync + Send + Fn(gpui::Entity, TypedEnvelope, AsyncApp) -> F, + F: 'static + Future>, + S: 'static + Stream>, + { + let message_type_id = TypeId::of::(); + let entity_type_id = TypeId::of::(); + let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| { + (envelope as &dyn Any) + .downcast_ref::>() + .unwrap() + .payload + .remote_entity_id() + }; + self.0 + .client + .message_handler_set() + .lock() + .add_entity_message_handler( + message_type_id, + entity_type_id, + entity_id_extractor, + Arc::new(move |entity, envelope, client, cx| { + let entity = entity.downcast::().unwrap(); + let envelope = envelope.into_any().downcast::>().unwrap(); + let request_id = envelope.message_id(); + let stream = handler(entity, *envelope, cx); + async move { + // An Error response is itself a terminal stream frame on + // both transports (Peer and ChannelClient), so we don't + // need to follow it with an EndStream. + match stream.await { + Ok(stream) => { + futures::pin_mut!(stream); + while let Some(result) = stream.next().await { + match result { + Ok(response) => { + client.send_response(request_id, response)? + } + Err(error) => { + client.send_response(request_id, error.to_proto())?; + return Err(error); + } + } + } + client.send_response(request_id, proto::EndStream {})?; + Ok(()) + } + Err(error) => { + client.send_response(request_id, error.to_proto())?; + Err(error) + } + } + } + .boxed_local() + }), + ); + } + pub fn add_entity_message_handler(&self, handler: H) where M: EnvelopedMessage + EntityMessage,