diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 829d68368..0d4ea03b4 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1078,17 +1078,76 @@ class MiniCPMModel(Model): self.gguf_writer.add_name("MiniCPM") self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) def set_vocab(self): self._set_vocab_hf() + def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head + + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = self.hparams.get("num_attention_heads") + n_kv_head = self.hparams.get("num_key_value_heads") + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # HF models permute some of the tensors, so we need to undo that + if name.endswith(("q_proj.weight")): + data_torch = self._reverse_hf_permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight")): + data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + class QwenModel(Model): @staticmethod diff --git a/examples/llava/README.md b/examples/llava/README.md index 295181a34..721d5e613 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -14,7 +14,7 @@ Build with cmake or run `make llava-cli` to build it. After building, run: `./llava-cli` to see the usage. For example: ```sh -./llava-cli -m llava-v1.5-7b/ggml-model-q5_k.gguf --mmproj llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg +./llava-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg ``` **note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so. @@ -38,7 +38,7 @@ python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b 3. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF: ```sh -python ./examples/llava/convert-image-encoder-to-gguf -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b +python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b ``` 4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF: diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 53b65e632..27fe34cbd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5311,22 +5311,26 @@ template static __global__ void #endif // __CUDA_ARCH__ >= CC_VOLTA } -template +#define MMVQ_NWARPS_NVIDIA 4 +#define MMVQ_NWARPS_AMD_RDNA2 1 +#define MMVQ_NWARPS_AMD_OLD 4 + +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) { const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par; - const int row = blockIdx.x*blockDim.y + threadIdx.y; - - if (row >= nrows_x) { - return; - } + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int row = blockIdx.x; const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; // partial sum for each thread float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f}; @@ -5334,12 +5338,12 @@ static __global__ void mul_mat_vec_q( const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; - for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) { + for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) { const int ibx = row*blocks_per_row_x + i; // x block index const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int + const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int #pragma unroll for (int j = 0; j < ncols_y; ++j) { @@ -5347,9 +5351,25 @@ static __global__ void mul_mat_vec_q( } } + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE]; + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { + tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j]; + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < nwarps-1; ++i) { + tmp[j] += tmp_shared[i][j][threadIdx.x]; + } tmp[j] = warp_reduce_sum(tmp[j]); if (threadIdx.x == 0) { @@ -6834,46 +6854,65 @@ static void mul_mat_vec_q_cuda( GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_y <= 4); - const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - switch (ncols_y) { - case 1: - mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 2: - mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 3: - mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 4: - mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - // case 5: - // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; - // case 6: - // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; - // case 7: - // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; - // case 8: - // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + int nwarps; + if (g_device_caps[id].cc >= CC_OFFSET_AMD) { + nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD; + } else { + nwarps = MMVQ_NWARPS_NVIDIA; + } + + const dim3 block_nums(nrows_x, 1, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + switch (nwarps) { + case 1: switch(ncols_y) { + case 1: + mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 2: + mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 3: + mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 4: + mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + default: + GGML_ASSERT(false); + break; + } break; + case 4: switch(ncols_y) { + case 1: + mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 2: + mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 3: + mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 4: + mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + default: + GGML_ASSERT(false); + break; + } break; + default: GGML_ASSERT(false); - // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break; } } diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index a03df4c65..dd562a898 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -12148,7 +12148,8 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec( const int64_t src1_ncols, const int64_t src1_padded_row_size, const dpct::queue_ptr &stream) { - const int64_t ne00 = src0->ne[0]; + GGML_TENSOR_BINARY_OP_LOCALS + const int64_t row_diff = row_high - row_low; // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics @@ -12167,8 +12168,9 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec( } else { src1_dfloat = src1_dfloat_a.alloc(ne00); ggml_cpy_f32_f16_sycl((const char *)src1_ddf_i, (char *)src1_dfloat, - ne00, ne00, 1, sizeof(float), 0, 0, ne00, 1, - sizeof(sycl::half), 0, 0, stream); + ne00, ne00, ne01, ne02, nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, + nb13, stream); } } #else diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 9e2846ee4..254f648a6 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -744,6 +744,8 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz } if (memory_type_index >= mem_props.memoryTypeCount) { + ctx->device.lock()->device.destroyBuffer(buf->buffer); + buf->size = 0; throw vk::OutOfDeviceMemoryError("No suitable memory type found"); } @@ -3875,7 +3877,7 @@ static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){ #ifdef GGML_VULKAN_DEBUG - std::cerr << "ggml_ctx->preallocate_buffers_graph(" << node << ")" << std::endl; + std::cerr << "ggml_vk_preallocate_buffers_graph(" << node << ")" << std::endl; #endif const bool any_on_device = node->backend == GGML_BACKEND_GPU || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) @@ -3994,8 +3996,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { return; } #ifdef GGML_VULKAN_DEBUG - std::cerr << "ggml_ctx->preallocate_buffers()" << std::endl; - std::cerr << "qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << std::endl; + std::cerr << "ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << ")" << std::endl; #endif #if defined(GGML_VULKAN_RUN_TESTS) ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached); diff --git a/llama.cpp b/llama.cpp index cc59f781c..c99e091e2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2993,6 +2993,8 @@ static void llm_load_hparams( } break; case LLM_ARCH_MINICPM: { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { case 40: model.type = e_model::MODEL_2B; break; default: model.type = e_model::MODEL_UNKNOWN; @@ -4279,8 +4281,7 @@ static bool llm_load_tensors( ctx_bufs.emplace_back(ctx, buf); } - // print memory requirements - { + if (llama_supports_gpu_offload()) { const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); @@ -4292,10 +4293,11 @@ static bool llm_load_tensors( const int max_offloadable_layers = hparams.n_layer + 1; LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + } - for (ggml_backend_buffer_t buf : model.bufs) { - LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0); - } + // print memory requirements + for (ggml_backend_buffer_t buf : model.bufs) { + LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0); } // populate tensors_by_name @@ -8890,7 +8892,7 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can // } const int64_t t_start_sample_us = ggml_time_us(); - + if (k <= 0) { k = candidates->size; }