mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00
kt-kernel: enable CPUInfer stream bridge for ROCm (#1918)
Some checks failed
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Release sglang-kt to PyPI / Build sglang-kt wheel (push) Has been cancelled
Release sglang-kt to PyPI / Publish sglang-kt to PyPI (push) Has been cancelled
Some checks failed
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Release sglang-kt to PyPI / Build sglang-kt wheel (push) Has been cancelled
Release sglang-kt to PyPI / Publish sglang-kt to PyPI (push) Has been cancelled
This commit is contained in:
parent
9b2d3b687b
commit
1dd0a78899
1 changed file with 8 additions and 3 deletions
|
|
@@ -83,11 +83,13 @@ class CPUInfer {
|
||||||
}
|
}
|
||||||
#ifndef KTRANSFORMERS_CPU_ONLY
|
#ifndef KTRANSFORMERS_CPU_ONLY
|
||||||
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
|
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
|
||||||
#if defined(KTRANSFORMERS_USE_CUDA)
|
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_ROCM)
|
||||||
void (*func)(void*) = (void (*)(void*))params.first;
|
void (*func)(void*) = (void (*)(void*))params.first;
|
||||||
void* args = (void*)params.second;
|
void* args = (void*)params.second;
|
||||||
*((CPUInfer**)args) = this;
|
*((CPUInfer**)args) = this;
|
||||||
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
|
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
|
||||||
|
#else
|
||||||
|
submit(params);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
@@ -100,6 +102,7 @@ class CPUInfer {
|
||||||
static void sync_(void* sync_args) {
|
static void sync_(void* sync_args) {
|
||||||
SyncArgs* args = (SyncArgs*)sync_args;
|
SyncArgs* args = (SyncArgs*)sync_args;
|
||||||
args->cpuinfer->task_queue_->sync(args->allow_n_pending);
|
args->cpuinfer->task_queue_->sync(args->allow_n_pending);
|
||||||
|
delete args;
|
||||||
}
|
}
|
||||||
|
|
||||||
void sync(size_t allow_n_pending = 0) {
|
void sync(size_t allow_n_pending = 0) {
|
||||||
|
|
@@ -108,9 +111,11 @@ class CPUInfer {
|
||||||
}
|
}
|
||||||
#ifndef KTRANSFORMERS_CPU_ONLY
|
#ifndef KTRANSFORMERS_CPU_ONLY
|
||||||
void sync_with_cuda_stream(intptr_t user_cuda_stream, size_t allow_n_pending = 0) {
|
void sync_with_cuda_stream(intptr_t user_cuda_stream, size_t allow_n_pending = 0) {
|
||||||
#if defined(KTRANSFORMERS_USE_CUDA)
|
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_ROCM)
|
||||||
SyncArgs* args = new SyncArgs{this, allow_n_pending};
|
SyncArgs* args = new SyncArgs{this, allow_n_pending};
|
||||||
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)args);
|
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)args);
|
||||||
|
#else
|
||||||
|
sync(allow_n_pending);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
@@ -119,4 +124,4 @@ class CPUInfer {
|
||||||
TaskQueue* task_queue_;
|
TaskQueue* task_queue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue