mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-23 12:37:09 +00:00
git_graph: Add remote support for search operations (#55167)
### Motivation This is the second of three PRs to add remote/collab support for the git graph and is a follow-up to #54468. I'm adding remote support for the search because it's not user accessible without the initial graph fetch having remote support, so it allows us to merge this without having to add full remote support. Collab guest support will be added in a follow-up PR. #### Summary For large repos, searching can take a while to fully stream in all matched results. For example, running a basic search on the Linux repo took over 10s for me. Because of that, we want to stream search results in chunks to downstream users to keep the time-to-first-match low. After this change, the first chunk gets sent back after ~50ms on the Linux repo from receiving the request. In order to accomplish that, I added a new proto client API that allows for a request to map to n responses. e.g. ```/dev/null/example.rs#L1-1 client.add_entity_stream_request_handler(Self::handle_search_commits); ``` Note: The proto API isn't supported over collab yet, that will be another PR Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A --------- Co-authored-by: cameron <cameron.studdstreet@gmail.com>
This commit is contained in:
parent
8aedcbf410
commit
149cd4e2bc
10 changed files with 1018 additions and 62 deletions
|
|
@ -24,6 +24,7 @@ use futures::{
|
|||
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
|
||||
channel::{mpsc, oneshot},
|
||||
future::BoxFuture,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
|
||||
|
|
@ -1789,6 +1790,34 @@ impl ProtoClient for Client {
|
|||
self.request_dynamic(envelope, request_type).boxed()
|
||||
}
|
||||
|
||||
fn request_stream(
|
||||
&self,
|
||||
envelope: proto::Envelope,
|
||||
request_type: &'static str,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::Envelope>>>> {
|
||||
let client_id = self.id();
|
||||
let response = self.connection_id().map(|connection_id| {
|
||||
self.peer
|
||||
.request_stream_dynamic(connection_id, envelope, request_type)
|
||||
});
|
||||
|
||||
async move {
|
||||
log::debug!(
|
||||
"rpc stream request start. client_id:{}. name:{}",
|
||||
client_id,
|
||||
request_type
|
||||
);
|
||||
let response = response?.await;
|
||||
log::debug!(
|
||||
"rpc stream request opened. client_id:{}. name:{}",
|
||||
client_id,
|
||||
request_type
|
||||
);
|
||||
response
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
|
||||
log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
|
||||
let connection_id = self.connection_id()?;
|
||||
|
|
|
|||
|
|
@ -1435,10 +1435,43 @@ impl GitRepository for FakeGitRepository {
|
|||
fn search_commits(
|
||||
&self,
|
||||
_log_source: LogSource,
|
||||
_search_args: SearchCommitArgs,
|
||||
_request_tx: Sender<Oid>,
|
||||
search_args: SearchCommitArgs,
|
||||
request_tx: Sender<Oid>,
|
||||
) -> BoxFuture<'_, Result<()>> {
|
||||
async { bail!("search_commits not supported for FakeGitRepository") }.boxed()
|
||||
async move {
|
||||
let query = if search_args.case_sensitive {
|
||||
search_args.query.to_string()
|
||||
} else {
|
||||
search_args.query.to_lowercase()
|
||||
};
|
||||
|
||||
let matching_shas = self.fs.with_git_state(&self.dot_git_path, false, |state| {
|
||||
state
|
||||
.commit_data
|
||||
.iter()
|
||||
.filter_map(|(sha, entry)| {
|
||||
let FakeCommitDataEntry::Success(commit_data) = entry else {
|
||||
return None;
|
||||
};
|
||||
let message = if search_args.case_sensitive {
|
||||
commit_data.message.to_string()
|
||||
} else {
|
||||
commit_data.message.to_lowercase()
|
||||
};
|
||||
message.contains(&query).then_some(*sha)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})?;
|
||||
|
||||
for sha in matching_shas {
|
||||
if request_tx.send(sha).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn commit_data_reader(&self) -> Result<CommitDataReader> {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ use collections::HashMap;
|
|||
pub use conflict_set::{ConflictRegion, ConflictSet, ConflictSetSnapshot, ConflictSetUpdate};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::{
|
||||
FutureExt, StreamExt,
|
||||
FutureExt, SinkExt, Stream, StreamExt,
|
||||
channel::{
|
||||
mpsc,
|
||||
oneshot::{self, Canceled},
|
||||
|
|
@ -680,6 +680,7 @@ impl GitStore {
|
|||
client.add_entity_request_handler(Self::handle_edit_ref);
|
||||
client.add_entity_request_handler(Self::handle_repair_worktrees);
|
||||
client.add_entity_request_handler(Self::handle_get_commit_data);
|
||||
client.add_entity_stream_request_handler(Self::handle_search_commits);
|
||||
}
|
||||
|
||||
pub fn is_local(&self) -> bool {
|
||||
|
|
@ -2669,6 +2670,63 @@ impl GitStore {
|
|||
Ok(proto::GetCommitDataResponse { commits })
|
||||
}
|
||||
|
||||
async fn handle_search_commits(
|
||||
this: Entity<Self>,
|
||||
envelope: TypedEnvelope<proto::SearchCommits>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<impl Stream<Item = Result<proto::SearchCommitsResponse>>> {
|
||||
const CHUNK_SIZE: usize = 100;
|
||||
|
||||
let repository_id = RepositoryId::from_proto(envelope.payload.repository_id);
|
||||
let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?;
|
||||
let log_source = log_source_from_proto(
|
||||
envelope
|
||||
.payload
|
||||
.log_source
|
||||
.context("missing search commit log source")?,
|
||||
)?;
|
||||
let search_args = SearchCommitArgs {
|
||||
query: SharedString::from(envelope.payload.query),
|
||||
case_sensitive: envelope.payload.case_sensitive,
|
||||
};
|
||||
|
||||
let (request_tx, request_rx) = async_channel::unbounded();
|
||||
repository_handle.update(&mut cx, |repository, cx| {
|
||||
repository.search_commits(log_source, search_args, request_tx, cx);
|
||||
});
|
||||
|
||||
let (mut response_tx, response_rx) = mpsc::unbounded();
|
||||
cx.background_spawn(async move {
|
||||
let mut shas = Vec::new();
|
||||
|
||||
while let Ok(sha) = request_rx.recv().await {
|
||||
shas.push(sha.to_string());
|
||||
|
||||
if shas.len() >= CHUNK_SIZE {
|
||||
if response_tx
|
||||
.send(Ok(proto::SearchCommitsResponse {
|
||||
shas: mem::take(&mut shas),
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !shas.is_empty() {
|
||||
response_tx
|
||||
.send(Ok(proto::SearchCommitsResponse { shas }))
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(response_rx)
|
||||
}
|
||||
|
||||
async fn handle_edit_ref(
|
||||
this: Entity<Self>,
|
||||
envelope: TypedEnvelope<proto::GitEditRef>,
|
||||
|
|
@ -4974,6 +5032,7 @@ impl Repository {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let repository_state = self.repository_state.clone();
|
||||
let repository_id = self.id;
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let repo_state = repository_state.await;
|
||||
|
|
@ -4985,8 +5044,50 @@ impl Repository {
|
|||
.await
|
||||
.log_err();
|
||||
}
|
||||
Ok(RepositoryState::Remote(_)) => {}
|
||||
Err(_) => {}
|
||||
|
||||
Ok(RepositoryState::Remote(RemoteRepositoryState { client, project_id })) => {
|
||||
let result = client
|
||||
.request_stream(proto::SearchCommits {
|
||||
project_id: project_id.to_proto(),
|
||||
repository_id: repository_id.to_proto(),
|
||||
log_source: Some(log_source_to_proto(&log_source)),
|
||||
query: search_args.query.to_string(),
|
||||
case_sensitive: search_args.case_sensitive,
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut stream = match result {
|
||||
Ok(stream) => stream,
|
||||
Err(error) => {
|
||||
log::error!("failed to search commits remotely: {error:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(error) => {
|
||||
log::error!(
|
||||
"failed to receive remote commit search results: {error:?}"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
for sha in &response.shas {
|
||||
let Ok(oid) = Oid::from_str(sha) else {
|
||||
return;
|
||||
};
|
||||
if request_tx.send(oid).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("failed to get repository state for commit search: {error}");
|
||||
}
|
||||
};
|
||||
})
|
||||
.detach();
|
||||
|
|
@ -8119,6 +8220,31 @@ fn deserialize_blame_buffer_response(
|
|||
Some(Blame { entries, messages })
|
||||
}
|
||||
|
||||
fn log_source_to_proto(log_source: &LogSource) -> proto::GitLogSource {
|
||||
proto::GitLogSource {
|
||||
source: Some(match log_source {
|
||||
LogSource::All => proto::git_log_source::Source::All(proto::GitLogSourceAll {}),
|
||||
LogSource::Branch(branch) => proto::git_log_source::Source::Branch(branch.to_string()),
|
||||
LogSource::Sha(sha) => proto::git_log_source::Source::Sha(sha.to_string()),
|
||||
LogSource::Path(path) => proto::git_log_source::Source::Path(path.to_proto()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_source_from_proto(log_source: proto::GitLogSource) -> Result<LogSource> {
|
||||
match log_source
|
||||
.source
|
||||
.context("git log source is missing source")?
|
||||
{
|
||||
proto::git_log_source::Source::All(_) => Ok(LogSource::All),
|
||||
proto::git_log_source::Source::Branch(branch) => Ok(LogSource::Branch(branch.into())),
|
||||
proto::git_log_source::Source::Sha(sha) => Ok(LogSource::Sha(Oid::from_str(&sha)?)),
|
||||
proto::git_log_source::Source::Path(path) => {
|
||||
Ok(LogSource::Path(RepoPath::from_proto(&path)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn commit_data_to_proto(commit: &CommitData) -> proto::CommitData {
|
||||
proto::CommitData {
|
||||
sha: commit.sha.to_string(),
|
||||
|
|
|
|||
|
|
@ -693,3 +693,26 @@ message CommitData {
|
|||
message GetCommitDataResponse {
|
||||
repeated CommitData commits = 1;
|
||||
}
|
||||
|
||||
message GitLogSourceAll {}
|
||||
|
||||
message GitLogSource {
|
||||
oneof source {
|
||||
GitLogSourceAll all = 1;
|
||||
string branch = 2;
|
||||
string sha = 3;
|
||||
string path = 4;
|
||||
}
|
||||
}
|
||||
|
||||
message SearchCommits {
|
||||
uint64 project_id = 1;
|
||||
uint64 repository_id = 2;
|
||||
GitLogSource log_source = 3;
|
||||
string query = 4;
|
||||
bool case_sensitive = 5;
|
||||
}
|
||||
|
||||
message SearchCommitsResponse {
|
||||
repeated string shas = 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -480,7 +480,9 @@ message Envelope {
|
|||
GitCreateArchiveCheckpointResponse git_create_archive_checkpoint_response = 445;
|
||||
GitRestoreArchiveCheckpoint git_restore_archive_checkpoint = 446;
|
||||
GetCommitData get_commit_data = 447;
|
||||
GetCommitDataResponse get_commit_data_response = 448; // current max
|
||||
GetCommitDataResponse get_commit_data_response = 448;
|
||||
SearchCommits search_commits = 449;
|
||||
SearchCommitsResponse search_commits_response = 450; // current max
|
||||
}
|
||||
|
||||
reserved 87 to 88;
|
||||
|
|
|
|||
|
|
@ -358,6 +358,8 @@ messages!(
|
|||
(GitRepairWorktrees, Background),
|
||||
(GetCommitData, Background),
|
||||
(GetCommitDataResponse, Background),
|
||||
(SearchCommits, Background),
|
||||
(SearchCommitsResponse, Background),
|
||||
(GitWorktreesResponse, Background),
|
||||
(GitCreateWorktree, Background),
|
||||
(GitRemoveWorktree, Background),
|
||||
|
|
@ -573,6 +575,7 @@ request_messages!(
|
|||
(GitEditRef, Ack),
|
||||
(GitRepairWorktrees, Ack),
|
||||
(GetCommitData, GetCommitDataResponse),
|
||||
(SearchCommits, SearchCommitsResponse),
|
||||
(GitCreateWorktree, Ack),
|
||||
(GitRemoveWorktree, Ack),
|
||||
(GitRenameWorktree, Ack),
|
||||
|
|
@ -767,6 +770,7 @@ entity_messages!(
|
|||
GitEditRef,
|
||||
GitRepairWorktrees,
|
||||
GetCommitData,
|
||||
SearchCommits,
|
||||
GitCreateArchiveCheckpoint,
|
||||
GitRestoreArchiveCheckpoint,
|
||||
GitCreateWorktree,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ use futures::{
|
|||
},
|
||||
future::{BoxFuture, Shared, WeakShared},
|
||||
select, select_biased,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{
|
||||
App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
|
||||
|
|
@ -1320,6 +1321,8 @@ impl RemoteConnectionOptions {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use rpc::{ErrorCodeExt, proto::ErrorCode};
|
||||
|
||||
#[test]
|
||||
fn test_ssh_display_name_prefers_nickname() {
|
||||
|
|
@ -1341,6 +1344,137 @@ mod tests {
|
|||
|
||||
assert_eq!(options.display_name(), "1.2.3.4");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_channel_client_request_stream_terminates_on_error(cx: &mut TestAppContext) {
|
||||
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
|
||||
|
||||
let client =
|
||||
cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "test-client", false));
|
||||
|
||||
// The client sends RemoteStarted on startup; drain the outgoing channel
|
||||
// so it doesn't block.
|
||||
let _drain_outgoing = cx
|
||||
.executor()
|
||||
.spawn(async move { while outgoing_rx.next().await.is_some() {} });
|
||||
|
||||
let mut stream = client
|
||||
.request_stream_dynamic(proto::Test { id: 0 }.into_envelope(0, None, None), "Test")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_id = 0;
|
||||
|
||||
incoming_tx
|
||||
.unbounded_send(proto::Test { id: 1 }.into_envelope(100, Some(request_id), None))
|
||||
.unwrap();
|
||||
|
||||
let first = stream.next().await.unwrap().unwrap();
|
||||
assert_eq!(
|
||||
proto::Test::from_envelope(first).unwrap(),
|
||||
proto::Test { id: 1 }
|
||||
);
|
||||
|
||||
// Send an Error without a trailing EndStream. The Error alone should
|
||||
// terminate the stream.
|
||||
incoming_tx
|
||||
.unbounded_send(
|
||||
ErrorCode::Internal
|
||||
.message("boom".to_string())
|
||||
.to_proto()
|
||||
.into_envelope(101, Some(request_id), None),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let second = stream.next().await.unwrap();
|
||||
let error = second.unwrap_err();
|
||||
assert!(
|
||||
format!("{error}").contains("boom"),
|
||||
"expected error to surface server message, got: {error}"
|
||||
);
|
||||
|
||||
assert!(stream.next().await.is_none());
|
||||
assert_eq!(client.stream_response_channels.lock().len(), 0);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_channel_client_dropping_stream_request_before_response_cleans_up_channel(
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let (_incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
|
||||
|
||||
let client =
|
||||
cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "test-client", false));
|
||||
|
||||
let _drain_outgoing = cx
|
||||
.executor()
|
||||
.spawn(async move { while outgoing_rx.next().await.is_some() {} });
|
||||
|
||||
let stream = client
|
||||
.request_stream_dynamic(proto::Test { id: 0 }.into_envelope(0, None, None), "Test")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(client.stream_response_channels.lock().len(), 1);
|
||||
|
||||
drop(stream);
|
||||
cx.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
client.stream_response_channels.lock().len(),
|
||||
0,
|
||||
"dropping a stream before any responses arrive should remove response channel bookkeeping"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_channel_client_dropping_stream_request_before_completion(
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
|
||||
|
||||
let client =
|
||||
cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "test-client", false));
|
||||
|
||||
let _drain_outgoing = cx
|
||||
.executor()
|
||||
.spawn(async move { while outgoing_rx.next().await.is_some() {} });
|
||||
|
||||
let mut stream = client
|
||||
.request_stream_dynamic(proto::Test { id: 0 }.into_envelope(0, None, None), "Test")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_id = 0;
|
||||
|
||||
incoming_tx
|
||||
.unbounded_send(proto::Test { id: 1 }.into_envelope(100, Some(request_id), None))
|
||||
.unwrap();
|
||||
let _ = stream.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(client.stream_response_channels.lock().len(), 1);
|
||||
|
||||
drop(stream);
|
||||
|
||||
// Inject an orphaned non-terminal response. The read loop should detect
|
||||
// that the consumer has been dropped and clean up its bookkeeping (no
|
||||
// EndStream sent here on purpose, otherwise the cleanup would happen
|
||||
// via the terminal-response path and mask the bug under test).
|
||||
incoming_tx
|
||||
.unbounded_send(proto::Test { id: 2 }.into_envelope(101, Some(request_id), None))
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
client.stream_response_channels.lock().len(),
|
||||
0,
|
||||
"stream channel should be removed once the consumer has dropped the stream"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SshConnectionOptions> for RemoteConnectionOptions {
|
||||
|
|
@ -1418,6 +1552,8 @@ pub trait RemoteConnection: Send + Sync {
|
|||
}
|
||||
|
||||
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
|
||||
type StreamResponseChannels =
|
||||
Arc<Mutex<HashMap<MessageId, UnboundedSender<(Result<Envelope>, oneshot::Sender<()>)>>>>;
|
||||
|
||||
struct Signal<T> {
|
||||
tx: Mutex<Option<oneshot::Sender<T>>>,
|
||||
|
|
@ -1455,6 +1591,7 @@ pub(crate) struct ChannelClient {
|
|||
outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
|
||||
buffer: Mutex<VecDeque<Envelope>>,
|
||||
response_channels: ResponseChannels,
|
||||
stream_response_channels: StreamResponseChannels,
|
||||
message_handlers: Mutex<ProtoMessageHandlerSet>,
|
||||
max_received: AtomicU32,
|
||||
name: &'static str,
|
||||
|
|
@ -1477,6 +1614,7 @@ impl ChannelClient {
|
|||
next_message_id: AtomicU32::new(0),
|
||||
max_received: AtomicU32::new(0),
|
||||
response_channels: ResponseChannels::default(),
|
||||
stream_response_channels: StreamResponseChannels::default(),
|
||||
message_handlers: Default::default(),
|
||||
buffer: Mutex::new(VecDeque::new()),
|
||||
name,
|
||||
|
|
@ -1550,13 +1688,40 @@ impl ChannelClient {
|
|||
|
||||
if let Some(request_id) = incoming.responding_to {
|
||||
let request_id = MessageId(request_id);
|
||||
// An incoming response with no payload is malformed; drop
|
||||
// it. The request future and any stream consumers will
|
||||
// remain pending until either a real response arrives or
|
||||
// the connection is torn down.
|
||||
if incoming.payload.is_none() {
|
||||
continue;
|
||||
}
|
||||
let sender = this.response_channels.lock().remove(&request_id);
|
||||
if let Some(sender) = sender {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if incoming.payload.is_some() {
|
||||
sender.send((incoming, tx)).ok();
|
||||
}
|
||||
sender.send((incoming, tx)).ok();
|
||||
rx.await.ok();
|
||||
} else {
|
||||
let terminal_stream_response = matches!(
|
||||
&incoming.payload,
|
||||
Some(proto::envelope::Payload::Error(_))
|
||||
| Some(proto::envelope::Payload::EndStream(_))
|
||||
);
|
||||
let sender = if terminal_stream_response {
|
||||
this.stream_response_channels.lock().remove(&request_id)
|
||||
} else {
|
||||
this.stream_response_channels
|
||||
.lock()
|
||||
.get(&request_id)
|
||||
.cloned()
|
||||
};
|
||||
if let Some(sender) = sender {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if sender.unbounded_send((Ok(incoming), tx)).is_err() {
|
||||
this.stream_response_channels.lock().remove(&request_id);
|
||||
continue;
|
||||
}
|
||||
rx.await.ok();
|
||||
}
|
||||
}
|
||||
} else if let Some(envelope) =
|
||||
build_typed_envelope(peer_id, Instant::now(), incoming)
|
||||
|
|
@ -1721,6 +1886,55 @@ impl ChannelClient {
|
|||
}
|
||||
}
|
||||
|
||||
fn request_stream_dynamic(
|
||||
&self,
|
||||
mut envelope: proto::Envelope,
|
||||
type_name: &'static str,
|
||||
) -> impl 'static + Future<Output = Result<BoxStream<'static, Result<proto::Envelope>>>> {
|
||||
envelope.id = self.next_message_id.fetch_add(1, SeqCst);
|
||||
let message_id = MessageId(envelope.id);
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let stream_response_channels = self.stream_response_channels.clone();
|
||||
stream_response_channels.lock().insert(message_id, tx);
|
||||
|
||||
let result = self.send_buffered(envelope);
|
||||
async move {
|
||||
if let Err(error) = &result {
|
||||
log::error!("failed to send message: {error}");
|
||||
anyhow::bail!("failed to send message: {error}");
|
||||
}
|
||||
|
||||
let cleanup_stream_response_channel = util::defer({
|
||||
let stream_response_channels = stream_response_channels.clone();
|
||||
move || {
|
||||
stream_response_channels.lock().remove(&message_id);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx
|
||||
.filter_map(move |(response, _barrier)| {
|
||||
// Keep the cleanup guard alive until the returned stream is dropped.
|
||||
let _keep_cleanup_guard_alive = &cleanup_stream_response_channel;
|
||||
futures::future::ready(match response {
|
||||
Ok(response) => {
|
||||
if let Some(proto::envelope::Payload::Error(error)) = &response.payload
|
||||
{
|
||||
Some(Err(RpcError::from_proto(error, type_name)))
|
||||
} else if let Some(proto::envelope::Payload::EndStream(_)) =
|
||||
&response.payload
|
||||
{
|
||||
None
|
||||
} else {
|
||||
Some(Ok(response))
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(error)),
|
||||
})
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
|
||||
envelope.id = self.next_message_id.fetch_add(1, SeqCst);
|
||||
self.send_buffered(envelope)
|
||||
|
|
@ -1751,6 +1965,14 @@ impl ProtoClient for ChannelClient {
|
|||
self.request_dynamic(envelope, request_type, true).boxed()
|
||||
}
|
||||
|
||||
fn request_stream(
|
||||
&self,
|
||||
envelope: proto::Envelope,
|
||||
request_type: &'static str,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::Envelope>>>> {
|
||||
self.request_stream_dynamic(envelope, request_type).boxed()
|
||||
}
|
||||
|
||||
fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
|
||||
self.send_dynamic(envelope)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,10 @@ use languages::rust_lang;
|
|||
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::{FakeFs, Fs};
|
||||
use git::repository::Worktree as GitWorktree;
|
||||
use git::{
|
||||
Oid,
|
||||
repository::{CommitData, Worktree as GitWorktree},
|
||||
};
|
||||
use gpui::{AppContext as _, Entity, SharedString, TestAppContext};
|
||||
use http_client::{BlockedHttpClient, FakeHttpClient};
|
||||
use language::{
|
||||
|
|
@ -29,11 +32,13 @@ use project::{
|
|||
search::{SearchQuery, SearchResult},
|
||||
};
|
||||
use remote::RemoteClient;
|
||||
use rpc::proto;
|
||||
use serde_json::json;
|
||||
use settings::{Settings, SettingsLocation, SettingsStore, initial_server_settings_content};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
use unindent::Unindent as _;
|
||||
|
|
@ -1626,6 +1631,108 @@ async fn test_remote_root_repo_common_dir(cx: &mut TestAppContext, server_cx: &m
|
|||
assert_eq!(common_dir, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_remote_search_commits_streams_proto_chunks(
|
||||
cx: &mut TestAppContext,
|
||||
server_cx: &mut TestAppContext,
|
||||
) {
|
||||
const COMMIT_COUNT: usize = 900;
|
||||
const RESPONSE_MAX_SIZE: usize = 100;
|
||||
|
||||
let fs = FakeFs::new(server_cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/code"),
|
||||
json!({
|
||||
"project1": {
|
||||
".git": {},
|
||||
"file.txt": "content",
|
||||
},
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let commit_data = (0..COMMIT_COUNT)
|
||||
.map(|index| {
|
||||
let sha = Oid::from_str(&format!("{:040x}", index + 1)).unwrap();
|
||||
(
|
||||
CommitData {
|
||||
sha,
|
||||
parents: Default::default(),
|
||||
author_name: SharedString::from("Author"),
|
||||
author_email: SharedString::from("author@example.com"),
|
||||
commit_timestamp: index as i64,
|
||||
subject: SharedString::from(format!("Subject {index}")),
|
||||
message: SharedString::from(format!("needle commit {index}")),
|
||||
},
|
||||
false,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let expected_shas = commit_data
|
||||
.iter()
|
||||
.map(|(commit_data, _)| commit_data.sha.to_string())
|
||||
.collect::<HashSet<_>>();
|
||||
fs.set_commit_data(Path::new(path!("/code/project1/.git")), commit_data);
|
||||
|
||||
let (project, _headless) = init_test(&fs, cx, server_cx).await;
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.find_or_create_worktree(path!("/code/project1"), true, cx)
|
||||
})
|
||||
.await
|
||||
.expect("should open remote worktree");
|
||||
server_cx.run_until_parked();
|
||||
cx.run_until_parked();
|
||||
project
|
||||
.update(cx, |project, cx| project.git_scans_complete(cx))
|
||||
.await;
|
||||
|
||||
let (remote_client, repository_id) = project.read_with(cx, |project, cx| {
|
||||
let repository = project
|
||||
.active_repository(cx)
|
||||
.expect("remote project should have an active repository");
|
||||
let repository_id = repository.read(cx).snapshot().id;
|
||||
let remote_client = project
|
||||
.remote_client()
|
||||
.expect("project should have a remote client");
|
||||
(remote_client, repository_id)
|
||||
});
|
||||
let proto_client = remote_client.read_with(cx, |remote_client, _| remote_client.proto_client());
|
||||
let mut stream = proto_client
|
||||
.request_stream(proto::SearchCommits {
|
||||
project_id: proto::REMOTE_SERVER_PROJECT_ID,
|
||||
repository_id: repository_id.to_proto(),
|
||||
log_source: Some(proto::GitLogSource {
|
||||
source: Some(proto::git_log_source::Source::All(
|
||||
proto::GitLogSourceAll {},
|
||||
)),
|
||||
}),
|
||||
query: "needle".to_string(),
|
||||
case_sensitive: true,
|
||||
})
|
||||
.await
|
||||
.expect("search commits stream should start");
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
while let Some(response) = futures::StreamExt::next(&mut stream).await {
|
||||
chunks.push(response.expect("search commits chunk should succeed").shas);
|
||||
}
|
||||
|
||||
assert!(
|
||||
chunks.len() > 1,
|
||||
"expected search results to stream in multiple chunks"
|
||||
);
|
||||
for chunk in chunks.iter().take(chunks.len() - 1) {
|
||||
assert!(
|
||||
chunk.len() <= RESPONSE_MAX_SIZE,
|
||||
"non-final chunks should meet the target byte size"
|
||||
);
|
||||
}
|
||||
|
||||
let actual_shas = chunks.into_iter().flatten().collect::<HashSet<_>>();
|
||||
assert_eq!(actual_shas, expected_shas);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_remote_archive_git_operations_are_supported(
|
||||
cx: &mut TestAppContext,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use super::{
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
|
||||
FutureExt, SinkExt, StreamExt, TryFutureExt,
|
||||
channel::{mpsc, oneshot},
|
||||
stream::BoxStream,
|
||||
};
|
||||
|
|
@ -278,11 +278,23 @@ impl Peer {
|
|||
);
|
||||
let response_channel =
|
||||
response_channels.lock().as_mut()?.remove(&responding_to);
|
||||
let stream_response_channel = stream_response_channels
|
||||
.lock()
|
||||
.as_ref()?
|
||||
.get(&responding_to)
|
||||
.cloned();
|
||||
let terminal_stream_response = matches!(
|
||||
&incoming.payload,
|
||||
Some(proto::envelope::Payload::Error(_))
|
||||
| Some(proto::envelope::Payload::EndStream(_))
|
||||
);
|
||||
let stream_response_channel = if terminal_stream_response {
|
||||
stream_response_channels
|
||||
.lock()
|
||||
.as_mut()?
|
||||
.remove(&responding_to)
|
||||
} else {
|
||||
stream_response_channels
|
||||
.lock()
|
||||
.as_ref()?
|
||||
.get(&responding_to)
|
||||
.cloned()
|
||||
};
|
||||
|
||||
if let Some(tx) = response_channel {
|
||||
let requester_resumed = oneshot::channel();
|
||||
|
|
@ -319,21 +331,15 @@ impl Peer {
|
|||
?error,
|
||||
"incoming stream response: request future dropped",
|
||||
);
|
||||
// The consumer has gone away, so drop the bookkeeping
|
||||
// for this stream rather than letting it accumulate
|
||||
// every subsequent message until a terminal frame.
|
||||
if let Some(channels) = stream_response_channels.lock().as_mut() {
|
||||
channels.remove(&responding_to);
|
||||
}
|
||||
} else {
|
||||
let _ = requester_resumed.1.await;
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming stream response: waiting to resume requester"
|
||||
);
|
||||
let _ = requester_resumed.1.await;
|
||||
tracing::debug!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming stream response: requester resumed"
|
||||
);
|
||||
} else {
|
||||
let message_type = proto::build_typed_envelope(
|
||||
connection_id.into(),
|
||||
|
|
@ -484,55 +490,96 @@ impl Peer {
|
|||
&self,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
|
||||
) -> impl Future<Output = Result<BoxStream<'static, Result<T::Response>>>> {
|
||||
let stream =
|
||||
self.request_stream_dynamic(receiver_id, request.into_envelope(0, None, None), T::NAME);
|
||||
|
||||
async move {
|
||||
Ok(stream
|
||||
.await?
|
||||
.map(|response| {
|
||||
T::Response::from_envelope(response?)
|
||||
.context("received response of the wrong type")
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_stream_dynamic(
|
||||
&self,
|
||||
receiver_id: ConnectionId,
|
||||
mut envelope: proto::Envelope,
|
||||
request_type: &'static str,
|
||||
) -> impl Future<Output = Result<BoxStream<'static, Result<proto::Envelope>>>> + use<> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let send = self.connection_state(receiver_id).and_then(|connection| {
|
||||
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
|
||||
envelope.id = message_id;
|
||||
let stream_response_channels = connection.stream_response_channels.clone();
|
||||
stream_response_channels
|
||||
.lock()
|
||||
.as_mut()
|
||||
.context("connection was closed")?
|
||||
.insert(message_id, tx);
|
||||
connection
|
||||
if let Err(error) = connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(Message::Envelope(
|
||||
request.into_envelope(message_id, None, None),
|
||||
))
|
||||
.context("connection was closed")?;
|
||||
.unbounded_send(Message::Envelope(envelope))
|
||||
{
|
||||
if let Some(channels) = stream_response_channels.lock().as_mut() {
|
||||
channels.remove(&message_id);
|
||||
}
|
||||
return Err(error).context("connection was closed");
|
||||
}
|
||||
Ok((message_id, stream_response_channels))
|
||||
});
|
||||
|
||||
async move {
|
||||
let (message_id, stream_response_channels) = send?;
|
||||
let stream_response_channels = Arc::downgrade(&stream_response_channels);
|
||||
|
||||
Ok(rx.filter_map(move |(response, _barrier)| {
|
||||
let cleanup_stream_response_channel = util::defer({
|
||||
let stream_response_channels = stream_response_channels.clone();
|
||||
future::ready(match response {
|
||||
Ok(response) => {
|
||||
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
|
||||
Some(Err(RpcError::from_proto(error, T::NAME)))
|
||||
} else if let Some(proto::envelope::Payload::EndStream(_)) =
|
||||
&response.payload
|
||||
{
|
||||
// Remove the transmitting end of the response channel to end the stream.
|
||||
if let Some(channels) = stream_response_channels.upgrade()
|
||||
&& let Some(channels) = channels.lock().as_mut()
|
||||
{
|
||||
channels.remove(&message_id);
|
||||
}
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
T::Response::from_envelope(response)
|
||||
.context("received response of the wrong type"),
|
||||
)
|
||||
}
|
||||
move || {
|
||||
if let Some(channels) = stream_response_channels.upgrade()
|
||||
&& let Some(channels) = channels.lock().as_mut()
|
||||
{
|
||||
channels.remove(&message_id);
|
||||
}
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx
|
||||
.filter_map(move |(response, _barrier)| {
|
||||
let _keep_cleanup_guard_alive = &cleanup_stream_response_channel;
|
||||
let stream_response_channels = stream_response_channels.clone();
|
||||
future::ready(match response {
|
||||
Ok(response) => {
|
||||
if let Some(proto::envelope::Payload::Error(error)) = &response.payload
|
||||
{
|
||||
// Remove the transmitting end of the response channel to end the stream.
|
||||
if let Some(channels) = stream_response_channels.upgrade()
|
||||
&& let Some(channels) = channels.lock().as_mut()
|
||||
{
|
||||
channels.remove(&message_id);
|
||||
}
|
||||
Some(Err(RpcError::from_proto(error, request_type)))
|
||||
} else if let Some(proto::envelope::Payload::EndStream(_)) =
|
||||
&response.payload
|
||||
{
|
||||
// Remove the transmitting end of the response channel to end the stream.
|
||||
if let Some(channels) = stream_response_channels.upgrade()
|
||||
&& let Some(channels) = channels.lock().as_mut()
|
||||
{
|
||||
channels.remove(&message_id);
|
||||
}
|
||||
None
|
||||
} else {
|
||||
Some(Ok(response))
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(error)),
|
||||
})
|
||||
})
|
||||
}))
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -661,6 +708,13 @@ impl Peer {
|
|||
.with_context(|| format!("no such connection: {connection_id}"))?;
|
||||
Ok(connection.clone())
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn pending_stream_request_count(&self, connection_id: ConnectionId) -> Option<usize> {
|
||||
let connection = self.connection_state(connection_id).ok()?;
|
||||
let channels = connection.stream_response_channels.lock();
|
||||
Some(channels.as_ref()?.len())
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Peer {
|
||||
|
|
@ -992,6 +1046,268 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_request_stream(cx: &mut TestAppContext) {
|
||||
init_logger();
|
||||
|
||||
let executor = cx.executor();
|
||||
let server = Peer::new(0);
|
||||
let client = Peer::new(0);
|
||||
|
||||
let (client_to_server_conn, server_to_client_conn, _kill) =
|
||||
Connection::in_memory(executor.clone());
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) =
|
||||
client.add_test_connection(client_to_server_conn, executor.clone());
|
||||
let (_, io_task2, mut server_incoming) =
|
||||
server.add_test_connection(server_to_client_conn, executor.clone());
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
executor
|
||||
.spawn(async move { while client_incoming.next().await.is_some() {} })
|
||||
.detach();
|
||||
|
||||
executor
|
||||
.spawn({
|
||||
let server = server.clone();
|
||||
async move {
|
||||
let request = server_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Test>>()
|
||||
.unwrap();
|
||||
let receipt = request.receipt();
|
||||
server.respond(receipt, proto::Test { id: 1 }).unwrap();
|
||||
server.respond(receipt, proto::Test { id: 2 }).unwrap();
|
||||
server.respond(receipt, proto::Test { id: 3 }).unwrap();
|
||||
server.end_stream(receipt).unwrap();
|
||||
|
||||
// Prevent the connection from being dropped.
|
||||
server_incoming.next().await;
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let mut stream = client
|
||||
.request_stream(client_to_server_conn_id, proto::Test { id: 0 })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut received = Vec::new();
|
||||
while let Some(item) = stream.next().await {
|
||||
received.push(item.unwrap());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
received,
|
||||
vec![
|
||||
proto::Test { id: 1 },
|
||||
proto::Test { id: 2 },
|
||||
proto::Test { id: 3 },
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
client.pending_stream_request_count(client_to_server_conn_id),
|
||||
Some(0)
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_request_stream_send_failure_cleans_up_response_channel(cx: &mut TestAppContext) {
|
||||
init_logger();
|
||||
|
||||
let executor = cx.executor();
|
||||
let client = Peer::new(0);
|
||||
|
||||
let (client_to_server_conn, _server_to_client_conn, _kill) =
|
||||
Connection::in_memory(executor.clone());
|
||||
let (client_to_server_conn_id, io_task, _client_incoming) =
|
||||
client.add_test_connection(client_to_server_conn, executor.clone());
|
||||
|
||||
drop(io_task);
|
||||
|
||||
let result = client
|
||||
.request_stream(client_to_server_conn_id, proto::Test { id: 0 })
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"stream request should fail when the connection write task has gone away"
|
||||
);
|
||||
assert_eq!(
|
||||
client.pending_stream_request_count(client_to_server_conn_id),
|
||||
Some(0),
|
||||
"failed stream request should not leave response channel bookkeeping behind"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_request_stream_terminates_on_error(cx: &mut TestAppContext) {
|
||||
init_logger();
|
||||
|
||||
let executor = cx.executor();
|
||||
let server = Peer::new(0);
|
||||
let client = Peer::new(0);
|
||||
|
||||
let (client_to_server_conn, server_to_client_conn, _kill) =
|
||||
Connection::in_memory(executor.clone());
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) =
|
||||
client.add_test_connection(client_to_server_conn, executor.clone());
|
||||
let (_, io_task2, mut server_incoming) =
|
||||
server.add_test_connection(server_to_client_conn, executor.clone());
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
executor
|
||||
.spawn(async move { while client_incoming.next().await.is_some() {} })
|
||||
.detach();
|
||||
|
||||
executor
|
||||
.spawn({
|
||||
let server = server.clone();
|
||||
async move {
|
||||
let request = server_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Test>>()
|
||||
.unwrap();
|
||||
let receipt = request.receipt();
|
||||
server.respond(receipt, proto::Test { id: 1 }).unwrap();
|
||||
// Send an Error without a trailing EndStream. The Error alone
|
||||
// should be treated as a terminal stream response.
|
||||
server
|
||||
.respond_with_error(
|
||||
receipt,
|
||||
ErrorCode::Internal.message("boom".to_string()).to_proto(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Prevent the connection from being dropped.
|
||||
server_incoming.next().await;
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let mut stream = client
|
||||
.request_stream(client_to_server_conn_id, proto::Test { id: 0 })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stream.next().await.unwrap().unwrap(), proto::Test { id: 1 });
|
||||
|
||||
let error = stream.next().await.unwrap().unwrap_err();
|
||||
assert!(
|
||||
format!("{error}").contains("boom"),
|
||||
"expected error to surface server message, got: {error}"
|
||||
);
|
||||
|
||||
// The error alone (without an EndStream) should terminate the stream.
|
||||
assert!(stream.next().await.is_none());
|
||||
assert_eq!(
|
||||
client.pending_stream_request_count(client_to_server_conn_id),
|
||||
Some(0)
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_dropping_stream_request_before_completion(cx: &mut TestAppContext) {
|
||||
init_logger();
|
||||
|
||||
let executor = cx.executor();
|
||||
let server = Peer::new(0);
|
||||
let client = Peer::new(0);
|
||||
|
||||
let (client_to_server_conn, server_to_client_conn, _kill) =
|
||||
Connection::in_memory(executor.clone());
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) =
|
||||
client.add_test_connection(client_to_server_conn, executor.clone());
|
||||
let (_, io_task2, mut server_incoming) =
|
||||
server.add_test_connection(server_to_client_conn, executor.clone());
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
executor
|
||||
.spawn(async move { while client_incoming.next().await.is_some() {} })
|
||||
.detach();
|
||||
|
||||
let (drop_signal_tx, drop_signal_rx) = oneshot::channel::<()>();
|
||||
let server_task = executor.spawn({
|
||||
let server = server.clone();
|
||||
async move {
|
||||
let request = server_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Test>>()
|
||||
.unwrap();
|
||||
let receipt = request.receipt();
|
||||
server.respond(receipt, proto::Test { id: 1 }).unwrap();
|
||||
|
||||
// Wait until the consumer has dropped the stream.
|
||||
drop_signal_rx.await.ok();
|
||||
|
||||
// Send a non-terminal response after the consumer is gone. The
|
||||
// peer should detect that the receiver has been dropped and clean
|
||||
// up its bookkeeping. Crucially, we do NOT send EndStream here
|
||||
// because that would clean up via the terminal-response path and
|
||||
// mask the bug.
|
||||
server.respond(receipt, proto::Test { id: 2 }).unwrap();
|
||||
|
||||
// A Ping/Ack round-trip after the response acts as a sync
|
||||
// barrier: because messages over the in-memory connection are
|
||||
// delivered in order, by the time the client observes the Ack,
|
||||
// it has already processed the dropped response above.
|
||||
let ping = server_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Ping>>()
|
||||
.unwrap();
|
||||
server.respond(ping.receipt(), proto::Ack {}).unwrap();
|
||||
|
||||
// Prevent the connection from being dropped.
|
||||
server_incoming.next().await;
|
||||
}
|
||||
});
|
||||
|
||||
let mut stream = client
|
||||
.request_stream(client_to_server_conn_id, proto::Test { id: 0 })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stream.next().await.unwrap().unwrap(), proto::Test { id: 1 });
|
||||
|
||||
// The stream is mid-flight, so the channel should be tracked.
|
||||
assert_eq!(
|
||||
client.pending_stream_request_count(client_to_server_conn_id),
|
||||
Some(1)
|
||||
);
|
||||
|
||||
drop(stream);
|
||||
drop_signal_tx.send(()).ok();
|
||||
|
||||
// Synchronization barrier: once this Ack arrives, the read loop has
|
||||
// already processed the orphaned stream response that came before it.
|
||||
client
|
||||
.request(client_to_server_conn_id, proto::Ping {})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
client.pending_stream_request_count(client_to_server_conn_id),
|
||||
Some(0),
|
||||
"stream channel should be removed once the consumer has dropped the stream"
|
||||
);
|
||||
|
||||
drop(server_task);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_disconnect(cx: &mut TestAppContext) {
|
||||
let executor = cx.executor();
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
use anyhow::{Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
Future, FutureExt as _,
|
||||
Future, FutureExt as _, Stream, StreamExt as _,
|
||||
channel::oneshot,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
|
||||
use parking_lot::Mutex;
|
||||
|
|
@ -61,6 +62,20 @@ pub trait ProtoClient: Send + Sync {
|
|||
request_type: &'static str,
|
||||
) -> BoxFuture<'static, Result<Envelope>>;
|
||||
|
||||
fn request_stream(
|
||||
&self,
|
||||
envelope: Envelope,
|
||||
request_type: &'static str,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<Envelope>>>> {
|
||||
async move {
|
||||
anyhow::bail!(
|
||||
"stream requests are not supported for {request_type}: {:?}",
|
||||
envelope.payload
|
||||
)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
|
||||
|
||||
fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
|
||||
|
|
@ -223,6 +238,23 @@ impl AnyProtoClient {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn request_stream<T: RequestMessage>(
|
||||
&self,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<BoxStream<'static, Result<T::Response>>>> + use<T> {
|
||||
let envelope = request.into_envelope(0, None, None);
|
||||
let response_stream = self.0.client.request_stream(envelope, T::NAME);
|
||||
async move {
|
||||
Ok(response_stream
|
||||
.await?
|
||||
.map(|response| {
|
||||
T::Response::from_envelope(response?)
|
||||
.context("received response of the wrong type")
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
|
||||
let envelope = request.into_envelope(0, None, None);
|
||||
self.0.client.send(envelope, T::NAME)
|
||||
|
|
@ -479,6 +511,68 @@ impl AnyProtoClient {
|
|||
);
|
||||
}
|
||||
|
||||
pub fn add_entity_stream_request_handler<M, E, H, F, S>(&self, handler: H)
|
||||
where
|
||||
M: EnvelopedMessage + RequestMessage + EntityMessage,
|
||||
E: 'static,
|
||||
H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
|
||||
F: 'static + Future<Output = Result<S>>,
|
||||
S: 'static + Stream<Item = Result<M::Response>>,
|
||||
{
|
||||
let message_type_id = TypeId::of::<M>();
|
||||
let entity_type_id = TypeId::of::<E>();
|
||||
let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
|
||||
(envelope as &dyn Any)
|
||||
.downcast_ref::<TypedEnvelope<M>>()
|
||||
.unwrap()
|
||||
.payload
|
||||
.remote_entity_id()
|
||||
};
|
||||
self.0
|
||||
.client
|
||||
.message_handler_set()
|
||||
.lock()
|
||||
.add_entity_message_handler(
|
||||
message_type_id,
|
||||
entity_type_id,
|
||||
entity_id_extractor,
|
||||
Arc::new(move |entity, envelope, client, cx| {
|
||||
let entity = entity.downcast::<E>().unwrap();
|
||||
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
|
||||
let request_id = envelope.message_id();
|
||||
let stream = handler(entity, *envelope, cx);
|
||||
async move {
|
||||
// An Error response is itself a terminal stream frame on
|
||||
// both transports (Peer and ChannelClient), so we don't
|
||||
// need to follow it with an EndStream.
|
||||
match stream.await {
|
||||
Ok(stream) => {
|
||||
futures::pin_mut!(stream);
|
||||
while let Some(result) = stream.next().await {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
client.send_response(request_id, response)?
|
||||
}
|
||||
Err(error) => {
|
||||
client.send_response(request_id, error.to_proto())?;
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
client.send_response(request_id, proto::EndStream {})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
client.send_response(request_id, error.to_proto())?;
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
.boxed_local()
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
|
||||
where
|
||||
M: EnvelopedMessage + EntityMessage,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue