diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 58c8ea447..33fd8cf0e 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -43,10 +43,17 @@ #include #include -#define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -#define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -#define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -#define LOG_DBG(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +#if defined(LLAVA_LOG_OFF) +# define LOG_INF(...) +# define LOG_WRN(...) +# define LOG_ERR(...) +# define LOG_DBG(...) +#else // defined(LLAVA_LOG_OFF) +# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +#endif // defined(LLAVA_LOG_OFF) //#define CLIP_DEBUG_FUNCTIONS diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index be6988540..4ca53a0b8 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -11,13 +11,17 @@ #include #include -#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) -#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) - -#define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -#define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -#define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -#define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +#if defined(LLAVA_LOG_OFF) +# define LOG_INF(...) +# define LOG_WRN(...) +# define LOG_ERR(...) +# define LOG_DBG(...) +#else // defined(LLAVA_LOG_OFF) +# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +#endif // defined(LLAVA_LOG_OFF) // RGB uint8 image struct clip_image_u8 { @@ -498,10 +502,16 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long errno = 0; size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer if (ferror(file)) { - die_fmt("read error: %s", strerror(errno)); + LOG_ERR("read error: %s", strerror(errno)); + free(buffer); + fclose(file); + return false; } if (ret != (size_t) fileSize) { - die("unexpectedly reached end of file"); + LOG_ERR("unexpectedly reached end of file"); + free(buffer); + fclose(file); + return false; } fclose(file); // Close the file diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py index 84db5ca1c..d82d54a5a 100644 --- a/examples/server/tests/unit/test_basic.py +++ b/examples/server/tests/unit/test_basic.py @@ -32,3 +32,17 @@ def test_server_models(): assert res.status_code == 200 assert len(res.body["data"]) == 1 assert res.body["data"][0]["id"] == server.model_alias + +def test_load_split_model(): + global server + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" + server.model_alias = "tinyllama-split" + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 16, + "prompt": "Hello", + "temperature": 0.0, + }) + assert res.status_code == 200 + assert match_regex("(little|girl)+", res.body["content"]) diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index d7aeb288d..1048d6fca 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -127,3 +127,22 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int assert res.status_code != 200 assert "error" in res.body + +@pytest.mark.parametrize("messages", [ + None, + "string", + [123], + [{}], + [{"role": 123}], + [{"role": "system", "content": 123}], + # [{"content": "hello"}], # TODO: should not be a valid case + [{"role": "system", "content": "test"}, {}], +]) +def test_invalid_chat_completion_req(messages): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "messages": messages, + }) + assert res.status_code == 400 or res.status_code == 500 + assert "error" in res.body diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py index 38ce6c429..6a6d40a1c 100644 --- a/examples/server/tests/unit/test_infill.py +++ b/examples/server/tests/unit/test_infill.py @@ -8,6 +8,7 @@ def create_server(): global server server = ServerPreset.tinyllama_infill() + def test_infill_without_input_extra(): global server server.start() @@ -19,6 +20,7 @@ def test_infill_without_input_extra(): assert res.status_code == 200 assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"]) + def test_infill_with_input_extra(): global server server.start() @@ -33,3 +35,23 @@ def test_infill_with_input_extra(): }) assert res.status_code == 200 assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"]) + + +@pytest.mark.parametrize("input_extra", [ + {}, + {"filename": "ok"}, + {"filename": 123}, + {"filename": 123, "text": "abc"}, + {"filename": 123, "text": 456}, +]) +def test_invalid_input_extra_req(input_extra): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "prompt": "Complete this", + "input_extra": [input_extra], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 400 + assert "error" in res.body diff --git a/examples/server/tests/unit/test_rerank.py b/examples/server/tests/unit/test_rerank.py index 3a49fd3ac..189bc4c96 100644 --- a/examples/server/tests/unit/test_rerank.py +++ b/examples/server/tests/unit/test_rerank.py @@ -36,3 +36,20 @@ def test_rerank(): assert most_relevant["relevance_score"] > least_relevant["relevance_score"] assert most_relevant["index"] == 2 assert least_relevant["index"] == 3 + + +@pytest.mark.parametrize("documents", [ + [], + None, + 123, + [1, 2, 3], +]) +def test_invalid_rerank_req(documents): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "documents": documents, + }) + assert res.status_code == 400 + assert "error" in res.body diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py new file mode 100644 index 000000000..982d6abb4 --- /dev/null +++ b/examples/server/tests/unit/test_speculative.py @@ -0,0 +1,103 @@ +import pytest +from utils import * + +# We use a F16 MOE gguf as main model, and q4_0 as draft model + +server = ServerPreset.stories15m_moe() + +MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" + +def create_server(): + global server + server = ServerPreset.stories15m_moe() + # download draft model file if needed + file_name = MODEL_DRAFT_FILE_URL.split('/').pop() + model_draft_file = f'../../../{file_name}' + if not os.path.exists(model_draft_file): + print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}") + with open(model_draft_file, 'wb') as f: + f.write(requests.get(MODEL_DRAFT_FILE_URL).content) + print(f"Done downloading draft model file") + # set default values + server.model_draft = model_draft_file + server.draft_min = 4 + server.draft_max = 8 + + +@pytest.fixture(scope="module", autouse=True) +def fixture_create_server(): + return create_server() + + +def test_with_and_without_draft(): + global server + server.model_draft = None # disable draft model + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_no_draft = res.body["content"] + server.stop() + + # create new server with draft model + create_server() + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_draft = res.body["content"] + + assert content_no_draft == content_draft + + +def test_different_draft_min_draft_max(): + global server + test_values = [ + (1, 2), + (1, 4), + (4, 8), + (4, 12), + (8, 16), + ] + last_content = None + for draft_min, draft_max in test_values: + server.stop() + server.draft_min = draft_min + server.draft_max = draft_max + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + if last_content is not None: + assert last_content == res.body["content"] + last_content = res.body["content"] + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (1, 2), + (2, 2), +]) +def test_multi_requests_parallel(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.start() + tasks = [] + for _ in range(n_requests): + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }))) + results = parallel_function_calls(tasks) + for res in results: + assert res.status_code == 200 + assert match_regex("(wise|kind|owl|answer)+", res.body["content"]) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index a831f113f..e17a05ff6 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -46,6 +46,7 @@ class ServerProcess: model_alias: str | None = None model_url: str | None = None model_file: str | None = None + model_draft: str | None = None n_threads: int | None = None n_gpu_layer: int | None = None n_batch: int | None = None @@ -68,6 +69,8 @@ class ServerProcess: response_format: str | None = None lora_files: List[str] | None = None disable_ctx_shift: int | None = False + draft_min: int | None = None + draft_max: int | None = None # session variables process: subprocess.Popen | None = None @@ -102,6 +105,8 @@ class ServerProcess: server_args.extend(["--model", self.model_file]) if self.model_url: server_args.extend(["--model-url", self.model_url]) + if self.model_draft: + server_args.extend(["--model-draft", self.model_draft]) if self.model_hf_repo: server_args.extend(["--hf-repo", self.model_hf_repo]) if self.model_hf_file: @@ -147,6 +152,10 @@ class ServerProcess: server_args.extend(["--no-context-shift"]) if self.api_key: server_args.extend(["--api-key", self.api_key]) + if self.draft_max: + server_args.extend(["--draft-max", self.draft_max]) + if self.draft_min: + server_args.extend(["--draft-min", self.draft_min]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") @@ -185,7 +194,8 @@ class ServerProcess: raise TimeoutError(f"Server did not start within {timeout_seconds} seconds") def stop(self) -> None: - server_instances.remove(self) + if self in server_instances: + server_instances.remove(self) if self.process: print(f"Stopping server with pid={self.process.pid}") self.process.kill() diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index bcb54e444..04e25b8ab 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1739,7 +1739,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_ROPE: { // TODO: with ops-test v == 1 float * ext_factor = (float*)((int32_t*)op->op_params + 7); - float * attn_factor = (float*)((int32_t*)op->op_params + 8); // TODO: n_dims <= ne0 if (op->src[0]->ne[0] != op->op_params[1]) { return false; @@ -1748,17 +1747,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, if (*ext_factor != 0) { return false; } - // TODO: attn_factor != 1 - if (*attn_factor != 1) { - return false; - } - //TODO: type == GGML_TYPE_F16 - switch (op->src[0]->type) { - case GGML_TYPE_F32: - return true; - default: - return false; - } + return true; } case GGML_OP_UPSCALE: { // aclnnUpsampleNearest2dGetWorkspaceSize not support diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.c b/ggml/src/ggml-cpu/ggml-cpu-aarch64.c index 69d3d327d..14a1f00eb 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.c @@ -1020,7 +1020,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void float * res_ptr = s; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); float32x4_t sumf = vdupq_n_f32(0); for (int l = 0; l < nb; l++) { @@ -3507,7 +3507,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); float32x4_t sumf[4]; for (int m = 0; m < 4; m++) { diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index e53aae7e3..33c455c73 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -1792,11 +1792,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r const int8x16_t y1_l = vld1q_s8(b_y1->qs); const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; - + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; float32x4_t scale = vld1q_f32(_scale); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); @@ -1812,7 +1813,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); + l1, r1)), l2, r2)), l3, r3))), scale); } float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); @@ -2348,10 +2349,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r const block_q8_1 * restrict b_y0 = &vy0[i]; const block_q8_1 * restrict b_y1 = &vy1[i]; - float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), - GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), - GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), - GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)}; + float32_t summs_t[4] = { + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s) + }; summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -2372,10 +2375,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); // mmla into int32x4_t - float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, - GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, - GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, - GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; float32x4_t scale = vld1q_f32(_scale); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); @@ -2390,15 +2395,17 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); + l1, r1)), l2, r2)), l3, r3))), scale); } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + sumv2 = vaddq_f32(sumv2, summs0); vst1_f32(s, vget_low_f32 (sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); + return; } #endif @@ -3375,10 +3382,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r const int8x16_t y1_l = vld1q_s8(b_y1->qs); const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; float32x4_t scale = vld1q_f32(_scale); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); @@ -3394,13 +3403,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); + l1, r1)), l2, r2)), l3, r3))), scale); } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s, vget_low_f32 (sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); + return; } #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b160c9d34..b7a6009d9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7677,8 +7677,8 @@ UseGgmlGemm2:; // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols int64_t num_rows_per_vec_dot = vec_dot_num_rows; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too + // these checks are needed to avoid crossing dim1 boundaries + // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { num_rows_per_vec_dot = 1; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b6392ed8d..808f74fa0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3447,8 +3447,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // KQ single-batch - ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); + // TODO: Refactor and cleanup of mul mat dispatching. + if (src0->ne[3] == 1 && src1->ne[3] == 1) { + // KQ single-batch + // mmv p021 was specific for these dimensions + ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); + } else { + // The kernel from the if path is faster for that specific case, but does not support all mul mats. + ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); + } } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); @@ -4486,7 +4493,7 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_ static int64_t get_op_batch_size(const ggml_tensor * op) { switch (op->op) { case GGML_OP_GET_ROWS: - return op->ne[1]; // this will increse the speed of prefill in test + return 0; case GGML_OP_MUL_MAT: return op->ne[1]; case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1b443c556..9539ee96b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5672,6 +5672,48 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } else { compute_ctx = ctx->compute_ctx.lock(); } + } else { + switch (node->op) { + case GGML_OP_REPEAT: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_UNARY: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_LEAKY_RELU: + { + // These operations all go through ggml_vk_op_f32, so short-circuit and + // do the only thing needed for the dryrun. + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return false; + } + default: + break; + } } switch (node->op) { @@ -6401,16 +6443,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch - // submit work every submit_count node to overlap CPU cmdbuffer generation with GPU execution - constexpr int submit_count = 100; + // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. + // Start with a smaller count to get work submitted right away, and increase it after each submit. + int nodes_per_submit = 20; int submitted_nodes = 0; + int submit_count = 0; for (int i = 0; i < cgraph->n_nodes; i++) { if (first_node_in_batch) { submit_node_idx = i; } - bool submit = (submitted_nodes >= submit_count) || (i == last_node); - + bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); @@ -6427,6 +6470,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (submit) { first_node_in_batch = true; submitted_nodes = 0; + switch (submit_count) { + case 0: + nodes_per_submit = 50; + break; + default: + nodes_per_submit = 100; + break; + } + submit_count++; } }