temp checkpoint

Commit 697ca70115 by Concedo, 2024-11-30 12:13:20 +08:00
14 changed files with 321 additions and 60 deletions


@@ -43,10 +43,17 @@
 #include <cinttypes>
 #include <limits>

+#if defined(LLAVA_LOG_OFF)
+# define LOG_INF(...)
+# define LOG_WRN(...)
+# define LOG_ERR(...)
+# define LOG_DBG(...)
+#else // defined(LLAVA_LOG_OFF)
 # define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
 # define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
 # define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
-#define LOG_DBG(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
+# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
+#endif // defined(LLAVA_LOG_OFF)

 //#define CLIP_DEBUG_FUNCTIONS
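
This hunk (and the matching one in the next file) lets all of the LLAVA_LOG_* output be compiled out by defining LLAVA_LOG_OFF; otherwise LOG_INF/LOG_DBG go to stdout and LOG_WRN/LOG_ERR to stderr. A small standalone sketch of that behavior, not part of the commit, with made-up message strings:

// Standalone sketch (not from this commit) of the LLAVA_LOG_OFF switch.
// With LLAVA_LOG_OFF defined, the LOG_* calls below compile to nothing;
// otherwise LOG_INF prints to stdout and LOG_ERR to stderr, mirroring the macros above.
#include <cstdio>

//#define LLAVA_LOG_OFF   // uncomment to silence logging at compile time
#if defined(LLAVA_LOG_OFF)
# define LOG_INF(...)
# define LOG_ERR(...)
#else
# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
#endif

int main() {
    LOG_INF("encoded %d image patches\n", 576);        // stdout, or nothing
    LOG_ERR("failed to open %s\n", "mmproj.gguf");     // stderr, or nothing
    return 0;
}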


@@ -11,13 +11,17 @@
 #include <limits>
 #include <vector>

-#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
-#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
+#if defined(LLAVA_LOG_OFF)
+# define LOG_INF(...)
+# define LOG_WRN(...)
+# define LOG_ERR(...)
+# define LOG_DBG(...)
+#else // defined(LLAVA_LOG_OFF)
 # define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
 # define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
 # define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
 # define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
+#endif // defined(LLAVA_LOG_OFF)

 // RGB uint8 image
 struct clip_image_u8 {
@@ -498,10 +502,16 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long
     errno = 0;
     size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
     if (ferror(file)) {
-        die_fmt("read error: %s", strerror(errno));
+        LOG_ERR("read error: %s", strerror(errno));
+        free(buffer);
+        fclose(file);
+        return false;
     }
     if (ret != (size_t) fileSize) {
-        die("unexpectedly reached end of file");
+        LOG_ERR("unexpectedly reached end of file");
+        free(buffer);
+        fclose(file);
+        return false;
     }

     fclose(file); // Close the file


@@ -32,3 +32,17 @@ def test_server_models():
     assert res.status_code == 200
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias
+
+def test_load_split_model():
+    global server
+    server.model_hf_repo = "ggml-org/models"
+    server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
+    server.model_alias = "tinyllama-split"
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 16,
+        "prompt": "Hello",
+        "temperature": 0.0,
+    })
+    assert res.status_code == 200
+    assert match_regex("(little|girl)+", res.body["content"])


@@ -127,3 +127,22 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
     assert res.status_code != 200
     assert "error" in res.body

+
+@pytest.mark.parametrize("messages", [
+    None,
+    "string",
+    [123],
+    [{}],
+    [{"role": 123}],
+    [{"role": "system", "content": 123}],
+    # [{"content": "hello"}], # TODO: should not be a valid case
+    [{"role": "system", "content": "test"}, {}],
+])
+def test_invalid_chat_completion_req(messages):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "messages": messages,
+    })
+    assert res.status_code == 400 or res.status_code == 500
+    assert "error" in res.body


@@ -8,6 +8,7 @@ def create_server():
     global server
     server = ServerPreset.tinyllama_infill()

+
 def test_infill_without_input_extra():
     global server
     server.start()
@@ -19,6 +20,7 @@ def test_infill_without_input_extra():
     assert res.status_code == 200
     assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])

+
 def test_infill_with_input_extra():
     global server
     server.start()
@@ -33,3 +35,23 @@ def test_infill_with_input_extra():
     })
     assert res.status_code == 200
     assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
+
+
+@pytest.mark.parametrize("input_extra", [
+    {},
+    {"filename": "ok"},
+    {"filename": 123},
+    {"filename": 123, "text": "abc"},
+    {"filename": 123, "text": 456},
+])
+def test_invalid_input_extra_req(input_extra):
+    global server
+    server.start()
+    res = server.make_request("POST", "/infill", data={
+        "prompt": "Complete this",
+        "input_extra": [input_extra],
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
+        "input_suffix": "}\n",
+    })
+    assert res.status_code == 400
+    assert "error" in res.body


@@ -36,3 +36,20 @@ def test_rerank():
     assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
     assert most_relevant["index"] == 2
     assert least_relevant["index"] == 3
+
+
+@pytest.mark.parametrize("documents", [
+    [],
+    None,
+    123,
+    [1, 2, 3],
+])
+def test_invalid_rerank_req(documents):
+    global server
+    server.start()
+    res = server.make_request("POST", "/rerank", data={
+        "query": "Machine learning is",
+        "documents": documents,
+    })
+    assert res.status_code == 400
+    assert "error" in res.body


@@ -0,0 +1,103 @@
import pytest
from utils import *

# We use a F16 MOE gguf as main model, and q4_0 as draft model
server = ServerPreset.stories15m_moe()

MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"


def create_server():
    global server
    server = ServerPreset.stories15m_moe()
    # download draft model file if needed
    file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
    model_draft_file = f'../../../{file_name}'
    if not os.path.exists(model_draft_file):
        print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
        with open(model_draft_file, 'wb') as f:
            f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
        print(f"Done downloading draft model file")
    # set default values
    server.model_draft = model_draft_file
    server.draft_min = 4
    server.draft_max = 8


@pytest.fixture(scope="module", autouse=True)
def fixture_create_server():
    return create_server()


def test_with_and_without_draft():
    global server
    server.model_draft = None  # disable draft model
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "temperature": 0.0,
        "top_k": 1,
    })
    assert res.status_code == 200
    content_no_draft = res.body["content"]
    server.stop()

    # create new server with draft model
    create_server()
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "temperature": 0.0,
        "top_k": 1,
    })
    assert res.status_code == 200
    content_draft = res.body["content"]

    assert content_no_draft == content_draft


def test_different_draft_min_draft_max():
    global server
    test_values = [
        (1, 2),
        (1, 4),
        (4, 8),
        (4, 12),
        (8, 16),
    ]
    last_content = None
    for draft_min, draft_max in test_values:
        server.stop()
        server.draft_min = draft_min
        server.draft_max = draft_max
        server.start()
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "temperature": 0.0,
            "top_k": 1,
        })
        assert res.status_code == 200
        if last_content is not None:
            assert last_content == res.body["content"]
        last_content = res.body["content"]


@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 2),
    (2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
    global server
    server.n_slots = n_slots
    server.start()
    tasks = []
    for _ in range(n_requests):
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": "I believe the meaning of life is",
            "temperature": 0.0,
            "top_k": 1,
        })))
    results = parallel_function_calls(tasks)
    for res in results:
        assert res.status_code == 200
        assert match_regex("(wise|kind|owl|answer)+", res.body["content"])


@@ -46,6 +46,7 @@ class ServerProcess:
     model_alias: str | None = None
     model_url: str | None = None
     model_file: str | None = None
+    model_draft: str | None = None
     n_threads: int | None = None
     n_gpu_layer: int | None = None
     n_batch: int | None = None
@@ -68,6 +69,8 @@ class ServerProcess:
     response_format: str | None = None
     lora_files: List[str] | None = None
     disable_ctx_shift: int | None = False
+    draft_min: int | None = None
+    draft_max: int | None = None

     # session variables
     process: subprocess.Popen | None = None
@@ -102,6 +105,8 @@ class ServerProcess:
             server_args.extend(["--model", self.model_file])
         if self.model_url:
             server_args.extend(["--model-url", self.model_url])
+        if self.model_draft:
+            server_args.extend(["--model-draft", self.model_draft])
         if self.model_hf_repo:
             server_args.extend(["--hf-repo", self.model_hf_repo])
         if self.model_hf_file:
@@ -147,6 +152,10 @@
             server_args.extend(["--no-context-shift"])
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
+        if self.draft_max:
+            server_args.extend(["--draft-max", self.draft_max])
+        if self.draft_min:
+            server_args.extend(["--draft-min", self.draft_min])

         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
@@ -185,6 +194,7 @@ class ServerProcess:
         raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

     def stop(self) -> None:
-        server_instances.remove(self)
+        if self in server_instances:
+            server_instances.remove(self)
         if self.process:
             print(f"Stopping server with pid={self.process.pid}")


@@ -1739,7 +1739,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_ROPE: {
             // TODO: with ops-test v == 1
             float * ext_factor = (float*)((int32_t*)op->op_params + 7);
-            float * attn_factor = (float*)((int32_t*)op->op_params + 8);
             // TODO: n_dims <= ne0
             if (op->src[0]->ne[0] != op->op_params[1]) {
                 return false;
@@ -1748,17 +1747,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if (*ext_factor != 0) {
                 return false;
             }
-            // TODO: attn_factor != 1
-            if (*attn_factor != 1) {
-                return false;
-            }
-            //TODO: type == GGML_TYPE_F16
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                    return true;
-                default:
-                    return false;
-            }
+            return true;
         }
         case GGML_OP_UPSCALE: {
             // aclnnUpsampleNearest2dGetWorkspaceSize not support


@@ -1020,7 +1020,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
     float * res_ptr = s;

     for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);

         float32x4_t sumf = vdupq_n_f32(0);
         for (int l = 0; l < nb; l++) {
@@ -3507,7 +3507,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
     for (int y = 0; y < nr / 4; y++) {
         const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
         for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);

             float32x4_t sumf[4];
             for (int m = 0; m < 4; m++) {


@@ -1792,11 +1792,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         const int8x16_t y1_l = vld1q_s8(b_y1->qs);
         const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-        float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+        float32_t _scale[4] = {
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
             GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
             GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+        };
         float32x4_t scale = vld1q_f32(_scale);

         int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -2348,10 +2349,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         const block_q8_1 * restrict b_y0 = &vy0[i];
         const block_q8_1 * restrict b_y1 = &vy1[i];

-        float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+        float32_t summs_t[4] = {
+            GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
             GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
             GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
-            GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
+            GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
+        };
         summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));

         const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -2372,10 +2375,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

         // mmla into int32x4_t
-        float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
-            GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
-            GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
-            GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+        float32_t _scale[4] = {
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+        };
         float32x4_t scale = vld1q_f32(_scale);

         int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -2395,10 +2400,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
         sumv2 = vaddq_f32(sumv2, summs0);

         vst1_f32(s, vget_low_f32 (sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));

         return;
     }
 #endif
@@ -3375,10 +3382,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         const int8x16_t y1_l = vld1q_s8(b_y1->qs);
         const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-        float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+        float32_t _scale[4] = {
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
             GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
             GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+        };
         float32x4_t scale = vld1q_f32(_scale);

         int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3396,11 +3405,13 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                             l1, r1)), l2, r2)), l3, r3))), scale);
         }

         float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

         vst1_f32(s, vget_low_f32 (sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));

         return;
     }
 #endif


@@ -7677,8 +7677,8 @@ UseGgmlGemm2:;
     // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
     int64_t num_rows_per_vec_dot = vec_dot_num_rows;

-    // TODO: currently the mmla kernels support only even numbered rows/cols.
-    // this check can be removed once they are extended to support odd numbered rows/cols too
+    // these checks are needed to avoid crossing dim1 boundaries
+    // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
     if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
         num_rows_per_vec_dot = 1;
     }


@@ -3447,8 +3447,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;

     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // TODO: Refactor and cleanup of mul mat dispatching.
+        if (src0->ne[3] == 1 && src1->ne[3] == 1) {
             // KQ single-batch
+            // mmv p021 was specific for these dimensions
             ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
+        } else {
+            // The kernel from the if path is faster for that specific case, but does not support all mul mats.
+            ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
+        }
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
@@ -4486,7 +4493,7 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_
 static int64_t get_op_batch_size(const ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_GET_ROWS:
-            return op->ne[1]; // this will increse the speed of prefill in test
+            return 0;
         case GGML_OP_MUL_MAT:
             return op->ne[1];
         case GGML_OP_MUL_MAT_ID:


@@ -5672,6 +5672,48 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         } else {
             compute_ctx = ctx->compute_ctx.lock();
         }
+    } else {
+        switch (node->op) {
+        case GGML_OP_REPEAT:
+        case GGML_OP_ACC:
+        case GGML_OP_GET_ROWS:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_CLAMP:
+        case GGML_OP_PAD:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+        case GGML_OP_DUP:
+        case GGML_OP_NORM:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_UNARY:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_ROPE:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_IM2COL:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_POOL_2D:
+        case GGML_OP_LEAKY_RELU:
+            {
+                // These operations all go through ggml_vk_op_f32, so short-circuit and
+                // do the only thing needed for the dryrun.
+                vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
+                ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+                return false;
+            }
+        default:
+            break;
+        }
     }

     switch (node->op) {
@@ -6401,16 +6443,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     bool first_node_in_batch = true; // true if next node will be first node in a batch
     int submit_node_idx = 0; // index to first node in a batch

-    // submit work every submit_count node to overlap CPU cmdbuffer generation with GPU execution
-    constexpr int submit_count = 100;
+    // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
+    // Start with a smaller count to get work submitted right away, and increase it after each submit.
+    int nodes_per_submit = 20;

     int submitted_nodes = 0;
+    int submit_count = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (first_node_in_batch) {
             submit_node_idx = i;
         }

-        bool submit = (submitted_nodes >= submit_count) || (i == last_node);
+        bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);

         bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
@@ -6427,6 +6470,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         if (submit) {
             first_node_in_batch = true;
             submitted_nodes = 0;
+            switch (submit_count) {
+                case 0:
+                    nodes_per_submit = 50;
+                    break;
+                default:
+                    nodes_per_submit = 100;
+                    break;
+            }
+            submit_count++;
         }
     }
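
Taken together, the last two hunks replace the fixed 100-node submit batch with a ramp: the first batch is 20 nodes, the next 50, and every batch after that 100, so the GPU starts working sooner while later submits stay large enough to amortize per-submit overhead. A small standalone sketch of that ramp, not part of the commit, with the Vulkan submission replaced by a print:

// Standalone sketch (not the ggml-vulkan code) of the escalating submit-batch
// pattern introduced above: submit a small first batch so the GPU gets work
// quickly, then grow the batch size to amortize per-submit overhead.
#include <cstdio>

int main() {
    const int n_nodes = 400;          // pretend graph size
    int nodes_per_submit = 20;        // small first batch
    int submitted_nodes  = 0;
    int submit_count     = 0;

    for (int i = 0; i < n_nodes; i++) {
        submitted_nodes++;            // "build" one node
        bool submit = (submitted_nodes >= nodes_per_submit) || (i == n_nodes - 1);
        if (submit) {
            printf("submit #%d: %d nodes (through node %d)\n", submit_count, submitted_nodes, i);
            submitted_nodes = 0;
            // ramp up after each submit: 20 -> 50 -> 100
            nodes_per_submit = (submit_count == 0) ? 50 : 100;
            submit_count++;
        }
    }
    return 0;
}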