diff --git a/kt-kernel/cpu_backend/cpuinfer.h b/kt-kernel/cpu_backend/cpuinfer.h index f95cad27..5c3210a6 100644 --- a/kt-kernel/cpu_backend/cpuinfer.h +++ b/kt-kernel/cpu_backend/cpuinfer.h @@ -83,13 +83,11 @@ class CPUInfer { } #ifndef KTRANSFORMERS_CPU_ONLY void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) { -#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_ROCM) +#if defined(KTRANSFORMERS_USE_CUDA) void (*func)(void*) = (void (*)(void*))params.first; void* args = (void*)params.second; *((CPUInfer**)args) = this; cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args); -#else - submit(params); #endif } #endif @@ -102,7 +100,6 @@ class CPUInfer { static void sync_(void* sync_args) { SyncArgs* args = (SyncArgs*)sync_args; args->cpuinfer->task_queue_->sync(args->allow_n_pending); - delete args; } void sync(size_t allow_n_pending = 0) { @@ -111,11 +108,9 @@ class CPUInfer { } #ifndef KTRANSFORMERS_CPU_ONLY void sync_with_cuda_stream(intptr_t user_cuda_stream, size_t allow_n_pending = 0) { -#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_ROCM) +#if defined(KTRANSFORMERS_USE_CUDA) SyncArgs* args = new SyncArgs{this, allow_n_pending}; cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)args); -#else - sync(allow_n_pending); #endif } #endif @@ -124,4 +119,4 @@ class CPUInfer { TaskQueue* task_queue_; }; -#endif +#endif \ No newline at end of file