kt-kernel: enable CPUInfer stream bridge for ROCm (#1918)
Some checks failed
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Release sglang-kt to PyPI / Build sglang-kt wheel (push) Has been cancelled
Release sglang-kt to PyPI / Publish sglang-kt to PyPI (push) Has been cancelled

This commit is contained in:
guanjiawei 2026-04-09 12:20:04 +08:00 committed by GitHub
parent 9b2d3b687b
commit 1dd0a78899
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -83,11 +83,13 @@ class CPUInfer {
}
#ifndef KTRANSFORMERS_CPU_ONLY
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
#if defined(KTRANSFORMERS_USE_CUDA)
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_ROCM)
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
#else
submit(params);
#endif
}
#endif
@ -100,6 +102,7 @@ class CPUInfer {
static void sync_(void* sync_args) {
SyncArgs* args = (SyncArgs*)sync_args;
args->cpuinfer->task_queue_->sync(args->allow_n_pending);
delete args;
}
void sync(size_t allow_n_pending = 0) {
@ -108,9 +111,11 @@ class CPUInfer {
}
#ifndef KTRANSFORMERS_CPU_ONLY
void sync_with_cuda_stream(intptr_t user_cuda_stream, size_t allow_n_pending = 0) {
#if defined(KTRANSFORMERS_USE_CUDA)
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_ROCM)
SyncArgs* args = new SyncArgs{this, allow_n_pending};
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)args);
#else
sync(allow_n_pending);
#endif
}
#endif
@ -119,4 +124,4 @@ class CPUInfer {
TaskQueue* task_queue_;
};
#endif
#endif