mmvq Optim: add MMVQ_PARAMETERS_TURING(mmvq_parameter_table_id) for SM75 TURING (#23729)

* mmvq Optim: add MMVQ_PARAMETERS_TURING(mmvq_parameter_table_id) for SM75 TURING

* avoid a mismatch between the host-side and device-side parameter table selection when Turing device code is JIT-compiled for Ampere or newer

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
redfox 2026-05-28 20:51:14 +08:00 committed by GitHub
parent bc81d47aba
commit d7be46189f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -63,6 +63,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GENERIC = 0,
MMVQ_PARAMETERS_TURING,
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2,
MMVQ_PARAMETERS_RDNA3_0,
@ -78,6 +79,8 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING && __CUDA_ARCH__ < GGML_CUDA_CC_AMPERE
return MMVQ_PARAMETERS_TURING;
#else
return MMVQ_PARAMETERS_GENERIC;
#endif
@ -96,6 +99,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
return MMVQ_PARAMETERS_GCN;
}
if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING && ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_AMPERE) {
return MMVQ_PARAMETERS_TURING;
}
return MMVQ_PARAMETERS_GENERIC;
}
@ -417,11 +423,38 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
}
return 1;
}
if (table_id == MMVQ_PARAMETERS_TURING) {
if (ncols_dst == 1) {
switch (type) {
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return 2;
default:
return 4;
}
}
switch (ncols_dst) {
case 2:
case 3:
case 4:
return 4;
case 5:
case 6:
case 7:
case 8:
return 2;
default:
return 1;
}
}
return 1;
}
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN || table_id == MMVQ_PARAMETERS_TURING) {
switch (ncols_dst) {
case 1:
return small_k ? nwarps : 1;