cont: tentative fix for allReduce compile error (sum the operands in float, then cast the result to T_dst)

This commit is contained in:
Concedo 2026-05-12 21:05:07 +08:00
parent 31c18ed12d
commit 216901034a

View file

@@ -184,13 +184,21 @@ static __global__ void ggml_cuda_ar_kernel(
#pragma unroll
for (int k = 0; k < ELEMS_PER_VEC; ++k) {
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[off + k]);
recvbuf[off + k] = ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(wire[k]);
const float a = ggml_cuda_cast<float>(d_low);
const float b = ggml_cuda_cast<float>(wire[k]);
recvbuf[off + k] = ggml_cuda_cast<T_dst>(a + b);
}
}
if (bid == 0 && tid < count - tail) {
const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[tail + tid]);
recvbuf[tail + tid] =
ggml_cuda_cast<T_dst>(d_low) + ggml_cuda_cast<T_dst>(host_other[tail + tid]);
const float a = ggml_cuda_cast<float>(d_low);
const float b = ggml_cuda_cast<float>(host_other[tail + tid]);
recvbuf[tail + tid] = ggml_cuda_cast<T_dst>(a + b);
}
}
}