Mirror of https://github.com/LostRuins/koboldcpp.git
Synced 2025-09-12 09:59:41 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts:
#	.devops/rocm.Dockerfile
#	docs/build-s390x.md
#	docs/development/HOWTO-add-model.md
#	docs/ops.md
#	docs/ops/CPU.csv
#	docs/ops/CUDA.csv
#	ggml/CMakeLists.txt
#	ggml/src/ggml-cann/acl_tensor.cpp
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/aclnn_ops.h
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/rms_norm.cl
#	scripts/create_ops_docs.py
#	tests/test-backend-ops.cpp
#	tools/export-lora/export-lora.cpp
This commit is contained in: commit 21b7d0a899
19 changed files with 1529 additions and 800 deletions
@@ -3791,7 +3791,7 @@ class Plamo2Model(TextModel):
         self.gguf_writer.add_block_count(block_count)
         self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
         self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
-        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))

         # Mamba parameters
         self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
@@ -3802,7 +3802,7 @@ class Plamo2Model(TextModel):
         self.gguf_writer.add_ssm_group_count(0)

         # MLP feed forward parameters (for attention layers)
-        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 13312))
         self.gguf_writer.add_file_type(self.ftype)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@@ -56,7 +56,7 @@
 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
-#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
 #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
 #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

@@ -72,8 +72,9 @@
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
 #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
-#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
-#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
+#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

 // Moore Threads
 #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
@@ -230,6 +231,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
+#define AMD_MFMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -292,6 +297,11 @@ static bool fp32_mma_hardware_available(const int cc) {
     return GGML_CUDA_CC_IS_CDNA(cc);
 }

+// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
+static bool amd_mfma_available(const int cc) {
+    return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc);
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
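Note (not part of the patch): these capability helpers are what host-side dispatch logic keys off. A minimal sketch of such a gate, assuming only the functions defined in this hunk (pick_mma_path and its strings are hypothetical):

    // Illustrative sketch only; ggml's real dispatch lives in ggml_cuda_should_use_mmq() and related code.
    static const char * pick_mma_path(const int cc) {
        if (new_mma_available(cc)) {
            return "nvidia-mma"; // Turing or newer tensor cores
        }
        if (amd_mfma_available(cc)) {
            return "amd-mfma";   // CDNA3 matrix cores
        }
        return "fallback";       // non-MMA kernels
    }

The same combined test appears later in this commit as new_mma_available(cc) || amd_mfma_available(cc).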
@@ -1330,14 +1330,16 @@ static __global__ void flash_attn_ext_f16(
                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
-    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
-    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
-    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
@@ -37,16 +37,16 @@ static __global__ void flash_attn_tile_ext_f32(
 #endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-        GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+        GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
-        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
-        GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
-        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
-        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+        GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+        GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+        GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+        GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+        GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
         NO_DEVICE_CODE;
         return;
     }
@@ -282,16 +282,16 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
-    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
@@ -329,16 +329,16 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
-    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
-    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
-    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
-    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
@@ -12,7 +12,8 @@
 // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
-// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.

 #include "common.cuh"

@@ -66,7 +67,44 @@ namespace ggml_cuda_mma {
     struct tile {
         static constexpr int I = I_;
         static constexpr int J = J_;
-        static constexpr int ne = I * J / WARP_SIZE;
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+        static constexpr int ne = I * J / 64;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 4) {
+                return threadIdx.x % 32;
+            } else if constexpr (I == 16 && J == 16) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 32) {
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return (2 * ((threadIdx.x / 16) % 2) + l);
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 4) {
+                return 2 * (threadIdx.x / 32) + l;
+            } else if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 32) {
+                return threadIdx.x % 32;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#else
+        static constexpr int ne = I * J / 32;
         T x[ne] = {0};

         static __device__ __forceinline__ int get_i(const int l) {
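A quick sanity check on the per-thread element count introduced above (an illustrative snippet, not from the patch): an I x J tile stores I*J 32-bit elements spread over a 64-lane wavefront on HIP and a 32-lane warp elsewhere, so the same tile shape holds half as many elements per thread on AMD.

    // Hypothetical host-side check of the tile<I, J, T>::ne arithmetic.
    #include <cstdio>

    constexpr int ne_hip(int I, int J)  { return I * J / 64; } // HIP wavefront = 64 lanes
    constexpr int ne_cuda(int I, int J) { return I * J / 32; } // CUDA warp = 32 lanes

    int main() {
        static_assert(ne_cuda(16, 8) == 4,  "tile<16, 8>: 4 ints per CUDA thread");
        static_assert(ne_hip(16, 8)  == 2,  "tile<16, 8>: 2 ints per HIP thread");
        static_assert(ne_hip(16, 16) == 4,  "tile<16, 16> accumulator: 4 ints per HIP thread");
        static_assert(ne_hip(32, 32) == 16, "tile<32, 32> accumulator: 16 ints per HIP thread");
        std::printf("tile element counts check out\n");
        return 0;
    }

The 4- and 16-int accumulator counts are exactly what the int32x4_t and int32x16_t vectors in the MFMA wrappers later in this file rely on.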
@@ -94,6 +132,7 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     };

     template <int I_, int J_>
@@ -148,10 +187,23 @@ namespace ggml_cuda_mma {

     template <int I, int J, typename T>
     static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_MFMA_AVAILABLE)
+        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
+        } else {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        }
+#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
+#endif // defined(AMD_MFMA_AVAILABLE)
     }

     template <typename T>
@@ -186,7 +238,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(NEW_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -393,4 +445,60 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
+                                                      B.x[0],
+                                                      acc[0],
+                                                      0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
+                                                      B.x[1],
+                                                      acc[0],
+                                                      0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
+        int32x16_t * acc = (int32x16_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
+                                                     B.x[0],
+                                                     acc[0],
+                                                     0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
+                                                     B.x[1],
+                                                     acc[0],
+                                                     0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
 }
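For orientation, here is how the new integer tiles and the MFMA-backed mma() overloads above would typically be combined on the AMD path (an illustrative sketch, not code from this commit; the buffers As, Bs, out, the function name and the assumed 16-byte-aligned layout are hypothetical):

    using namespace ggml_cuda_mma;

    // Sketch of one 16x16 int8 tile product on the MFMA path added above.
    __device__ void mul_tile_q8(const int * As, const int * Bs, int * out, const int stride) {
        tile<16, 16, int> D; // accumulator, zero-initialized by the member initializer
        tile<16,  8, int> A;
        tile<16,  8, int> B;

        load_generic(A, As, stride); // per-thread gather via get_i()/get_j(), or one 64-bit copy per thread on MFMA
        load_generic(B, Bs, stride);
        mma(D, A, B);                // one 16x16x32 int8 MFMA on CDNA3, two 16x16x16 MFMAs on older CDNA

    #pragma unroll
        for (int l = 0; l < D.ne; ++l) {
            out[D.get_i(l)*stride + D.get_j(l)] = D.x[l];
        }
    }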
@@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q(
     const int64_t s03 = src0->nb[3] / ts_src0;
     const int64_t s3 = dst->nb[3] / ts_dst;

-    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+        || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));

     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
+    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+        || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
+        && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
         ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
@@ -306,7 +308,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }

-    if (new_mma_available(cc)) {
+    if (new_mma_available(cc) || amd_mfma_available(cc)) {
         return true;
     }

(One file's diff is suppressed here because it is too large.)
@@ -44,6 +44,9 @@ static __global__ void k_set_rows_quant(
     block_type * dst_block = dst_row_ptr + i00 / qk;

     quantize_func(src_block, dst_block);
+
+    GGML_UNUSED(ne10);
+    GGML_UNUSED(ne13);
 }

 // Template dispatch function for quantized set_rows
ggml/src/ggml-cuda/vendors/hip.h (vendored, 14 changed lines)
@@ -160,7 +160,19 @@
 #endif

 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
-#define CDNA
+#define CDNA // For the entire family
+#endif
+
+#if defined(__gfx942__)
+#define CDNA3
+#endif
+
+#if defined(__gfx90a__)
+#define CDNA2
+#endif
+
+#if defined(__gfx908__)
+#define CDNA1
 #endif

 #if defined(__GFX12__)
@@ -528,6 +528,7 @@ typedef struct {
     int64_t n_group;
     int64_t n_seq_tokens;
     int64_t n_seqs;
+    int64_t s_off;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
@@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node(
                 /*.n_group      =*/ n_group,
                 /*.n_seq_tokens =*/ n_seq_tokens,
                 /*.n_seqs       =*/ n_seqs,
+                /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
                 /*.nb01         =*/ nb01,
                 /*.nb02         =*/ nb02,
                 /*.nb03         =*/ nb03,
@@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
                 [encoder setBytes:&args length:sizeof(args) atIndex:8];

+                // One shared memory bucket for each simd group in the threadgroup
+                // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
+                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                if (d_state >= 32) {
+                    GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+                    const int64_t shmem_size = 32;
+                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
+
                 if (ne30 == 1) {
                     // Mamba-2
-                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 } else {
                     GGML_ASSERT(d_inner == 1);
-                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 }
             } break;
         case GGML_OP_RWKV_WKV6:
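Rough arithmetic behind the threadgroup memory reservation above (illustrative, not from the patch): the kernel is now dispatched with d_state threads per threadgroup, so there are d_state / 32 SIMD groups, each needing one float slot for its partial sum. A fixed 32-float buffer (128 bytes) covers up to 32 SIMD groups, i.e. d_state up to 1024, and 128 is a multiple of 16 as Metal requires.

    // Hypothetical helper mirroring the sizing logic in the encoder change above.
    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    static size_t ssm_scan_shmem_bytes(int64_t d_state) {
        if (d_state < 32) {
            return 0;                        // a single SIMD group can reduce with simd_sum alone
        }
        assert(d_state / 32 <= 32);          // at most 32 SIMD groups per threadgroup
        const size_t shmem_floats = 32;      // one bucket per SIMD group, fixed upper bound
        return shmem_floats * sizeof(float); // 128 bytes, a multiple of 16
    }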
@@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3 tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s = s_buff[i];
+
+    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
+        device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
-
-        y[0] = sumf;
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point. This
+            // ensures that all elements of the shared buffer are populated with the
+            // sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // For simd group 0 at indices < num simd groups, extract the shared
+            // simd sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
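The comments above describe a standard two-stage reduction: each SIMD group reduces its lanes with simd_sum, writes one partial per group into threadgroup memory, and the first group then reduces the partials. For readers who know CUDA better than Metal, a rough analogue looks like this (illustrative only, not part of the patch; assumes blockDim.x is a multiple of 32 and at most 1024):

    // Two-stage sum of one float per thread across a thread block (CUDA sketch).
    __global__ void block_sum(const float * in, float * out) {
        __shared__ float partial[32];              // one slot per warp, like shared[sgitg]

        float v = in[blockIdx.x * blockDim.x + threadIdx.x];

        // Stage 1: warp-level reduction, the counterpart of simd_sum.
        for (int offset = 16; offset > 0; offset >>= 1) {
            v += __shfl_down_sync(0xffffffff, v, offset);
        }

        const int lane = threadIdx.x % 32;         // like tiisg
        const int warp = threadIdx.x / 32;         // like sgitg
        if (lane == 0) {
            partial[warp] = v;                     // one partial per warp
        }
        __syncthreads();                           // like threadgroup_barrier

        // Stage 2: the first warp sums the per-warp partials.
        if (warp == 0) {
            const int nwarps = blockDim.x / 32;    // like sgptg
            v = lane < nwarps ? partial[lane] : 0.0f;
            for (int offset = 16; offset > 0; offset >>= 1) {
                v += __shfl_down_sync(0xffffffff, v, offset);
            }
            if (lane == 0) {
                out[blockIdx.x] = v;               // like y[0] = sumf
            }
        }
    }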
@@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3 tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s = s_buff[i];
+
+    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-        device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
+        device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
-
-        y[0] = sumf;
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
+        }
+
+        // Wait for all threads in the threadgroup to reach this point. This
+        // ensures that all elements of the shared buffer are populated with the
+        // sum of the individual simd groups.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // For simd group 0 at indices < num simd groups, extract the shared
+        // simd sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 kernel void kernel_rwkv_wkv6_f32(
@@ -6656,20 +6656,18 @@ static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgr
 static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
     struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
     struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);
-    fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
+    fprintf(fp, " \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
             gparent0 ? (void *) gparent0 : (void *) parent,
-            gparent0 ? "g" : "x",
             gparent ? (void *) gparent : (void *) node,
-            gparent ? "g" : "x",
             gparent ? "empty" : "vee",
             gparent ? "dashed" : "solid",
             label);
 }

 static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
-    fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n",
-            (void *) parent, "x",
-            (void *) node, "x",
+    fprintf(fp, " \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
+            (void *) parent,
+            (void *) node,
             label);
 }

@@ -105,7 +105,7 @@ llama_context::llama_context(

     {
         const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-        const bool supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
+        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;

         if (!supports_set_rows && !cparams.kv_unified) {
             LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
@@ -899,6 +899,12 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }

+    if (!supports_set_rows) {
+        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+        // overlap with device computation.
+        ggml_backend_sched_reset(sched.get());
+    }
+
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
@@ -1229,6 +1235,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();

+    if (!supports_set_rows) {
+        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+        // overlap with device computation.
+        ggml_backend_sched_reset(sched.get());
+    }
+
     return 0;
 }

@@ -287,6 +287,10 @@ private:

     bool has_evaluated_once = false;

+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    bool supports_set_rows = false;
+
     // perf
     mutable int64_t t_start_us = 0;
     mutable int64_t t_load_us = 0;
@@ -98,7 +98,7 @@ struct llama_hparams {
     float rope_freq_scale_train;
     float rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
-    float rope_yarn_log_mul;
+    float rope_yarn_log_mul = 0.0f;

     std::array<int, 4> rope_sections;

@@ -1374,7 +1374,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // that have no expert_gating_func model parameter set
                 hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
             }
-            ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
+            ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);

             switch (hparams.n_layer) {
                 case 27: type = LLM_TYPE_16B; break;
@@ -16291,7 +16291,7 @@ private:
         {
             // PLaMo-2 uses combined QKV tensor
             ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
-            cb(qkv, "qkv", il);
+            cb(qkv, "wqkv", il);

             // split QKV tensor into Q, K, V
             const int64_t n_embd_head_q = hparams.n_embd_head_k;
@@ -16331,7 +16331,7 @@ private:
                     ext_factor, attn_factor, beta_fast, beta_slow
                     );

-            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il);
+            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
         }

         cb(cur, "attn_out", il);
@@ -16406,8 +16406,9 @@ private:
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0, last_conv,
                 ggml_view_1d(ctx0, conv_states_all,
-                    (d_conv - 1)*(d_inner)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+                    (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                    kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+        cb(conv_states_all, "mamba_conv1d_state", il);

         // 1D convolution
         x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
@@ -16470,9 +16471,9 @@ private:
         // store last states
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0,
-                ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
-                ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
-                    kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+                ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)),
+                ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all))));
+        cb(ssm_states_all, "mamba_ssm_states", il);

         ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
         cb(y, "mamba_y_view", il);
@@ -2366,7 +2366,7 @@ struct clip_model_loader {

         // create data context
         struct ggml_init_params params = {
-            /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(),
+            /*.mem_size =*/ static_cast<size_t>(gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(),
             /*.mem_buffer =*/ NULL,
             /*.no_alloc =*/ true,
         };