mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
temp checkpoint
This commit is contained in:
commit
697ca70115
14 changed files with 321 additions and 60 deletions
|
@ -43,10 +43,17 @@
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
#if defined(LLAVA_LOG_OFF)
|
||||||
#define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
# define LOG_INF(...)
|
||||||
#define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
# define LOG_WRN(...)
|
||||||
#define LOG_DBG(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
# define LOG_ERR(...)
|
||||||
|
# define LOG_DBG(...)
|
||||||
|
#else // defined(LLAVA_LOG_OFF)
|
||||||
|
# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
||||||
|
# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
||||||
|
# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
||||||
|
# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
||||||
|
#endif // defined(LLAVA_LOG_OFF)
|
||||||
|
|
||||||
//#define CLIP_DEBUG_FUNCTIONS
|
//#define CLIP_DEBUG_FUNCTIONS
|
||||||
|
|
||||||
|
|
|
@ -11,13 +11,17 @@
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
#if defined(LLAVA_LOG_OFF)
|
||||||
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
# define LOG_INF(...)
|
||||||
|
# define LOG_WRN(...)
|
||||||
#define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
# define LOG_ERR(...)
|
||||||
#define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
# define LOG_DBG(...)
|
||||||
#define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
#else // defined(LLAVA_LOG_OFF)
|
||||||
#define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
||||||
|
# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
||||||
|
# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
|
||||||
|
# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
|
||||||
|
#endif // defined(LLAVA_LOG_OFF)
|
||||||
|
|
||||||
// RGB uint8 image
|
// RGB uint8 image
|
||||||
struct clip_image_u8 {
|
struct clip_image_u8 {
|
||||||
|
@ -498,10 +502,16 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long
|
||||||
errno = 0;
|
errno = 0;
|
||||||
size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
|
size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
|
||||||
if (ferror(file)) {
|
if (ferror(file)) {
|
||||||
die_fmt("read error: %s", strerror(errno));
|
LOG_ERR("read error: %s", strerror(errno));
|
||||||
|
free(buffer);
|
||||||
|
fclose(file);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
if (ret != (size_t) fileSize) {
|
if (ret != (size_t) fileSize) {
|
||||||
die("unexpectedly reached end of file");
|
LOG_ERR("unexpectedly reached end of file");
|
||||||
|
free(buffer);
|
||||||
|
fclose(file);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
fclose(file); // Close the file
|
fclose(file); // Close the file
|
||||||
|
|
||||||
|
|
|
@ -32,3 +32,17 @@ def test_server_models():
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert len(res.body["data"]) == 1
|
assert len(res.body["data"]) == 1
|
||||||
assert res.body["data"][0]["id"] == server.model_alias
|
assert res.body["data"][0]["id"] == server.model_alias
|
||||||
|
|
||||||
|
def test_load_split_model():
|
||||||
|
global server
|
||||||
|
server.model_hf_repo = "ggml-org/models"
|
||||||
|
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
|
||||||
|
server.model_alias = "tinyllama-split"
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"n_predict": 16,
|
||||||
|
"prompt": "Hello",
|
||||||
|
"temperature": 0.0,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert match_regex("(little|girl)+", res.body["content"])
|
||||||
|
|
|
@ -127,3 +127,22 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
|
||||||
assert res.status_code != 200
|
assert res.status_code != 200
|
||||||
assert "error" in res.body
|
assert "error" in res.body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("messages", [
|
||||||
|
None,
|
||||||
|
"string",
|
||||||
|
[123],
|
||||||
|
[{}],
|
||||||
|
[{"role": 123}],
|
||||||
|
[{"role": "system", "content": 123}],
|
||||||
|
# [{"content": "hello"}], # TODO: should not be a valid case
|
||||||
|
[{"role": "system", "content": "test"}, {}],
|
||||||
|
])
|
||||||
|
def test_invalid_chat_completion_req(messages):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
"messages": messages,
|
||||||
|
})
|
||||||
|
assert res.status_code == 400 or res.status_code == 500
|
||||||
|
assert "error" in res.body
|
||||||
|
|
|
@ -8,6 +8,7 @@ def create_server():
|
||||||
global server
|
global server
|
||||||
server = ServerPreset.tinyllama_infill()
|
server = ServerPreset.tinyllama_infill()
|
||||||
|
|
||||||
|
|
||||||
def test_infill_without_input_extra():
|
def test_infill_without_input_extra():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -19,6 +20,7 @@ def test_infill_without_input_extra():
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
|
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
def test_infill_with_input_extra():
|
def test_infill_with_input_extra():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -33,3 +35,23 @@ def test_infill_with_input_extra():
|
||||||
})
|
})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
|
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input_extra", [
|
||||||
|
{},
|
||||||
|
{"filename": "ok"},
|
||||||
|
{"filename": 123},
|
||||||
|
{"filename": 123, "text": "abc"},
|
||||||
|
{"filename": 123, "text": 456},
|
||||||
|
])
|
||||||
|
def test_invalid_input_extra_req(input_extra):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/infill", data={
|
||||||
|
"prompt": "Complete this",
|
||||||
|
"input_extra": [input_extra],
|
||||||
|
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
||||||
|
"input_suffix": "}\n",
|
||||||
|
})
|
||||||
|
assert res.status_code == 400
|
||||||
|
assert "error" in res.body
|
||||||
|
|
|
@ -36,3 +36,20 @@ def test_rerank():
|
||||||
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
|
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
|
||||||
assert most_relevant["index"] == 2
|
assert most_relevant["index"] == 2
|
||||||
assert least_relevant["index"] == 3
|
assert least_relevant["index"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("documents", [
|
||||||
|
[],
|
||||||
|
None,
|
||||||
|
123,
|
||||||
|
[1, 2, 3],
|
||||||
|
])
|
||||||
|
def test_invalid_rerank_req(documents):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/rerank", data={
|
||||||
|
"query": "Machine learning is",
|
||||||
|
"documents": documents,
|
||||||
|
})
|
||||||
|
assert res.status_code == 400
|
||||||
|
assert "error" in res.body
|
||||||
|
|
103
examples/server/tests/unit/test_speculative.py
Normal file
103
examples/server/tests/unit/test_speculative.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
import pytest
|
||||||
|
from utils import *
|
||||||
|
|
||||||
|
# We use a F16 MOE gguf as main model, and q4_0 as draft model
|
||||||
|
|
||||||
|
server = ServerPreset.stories15m_moe()
|
||||||
|
|
||||||
|
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
|
||||||
|
|
||||||
|
def create_server():
|
||||||
|
global server
|
||||||
|
server = ServerPreset.stories15m_moe()
|
||||||
|
# download draft model file if needed
|
||||||
|
file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
|
||||||
|
model_draft_file = f'../../../{file_name}'
|
||||||
|
if not os.path.exists(model_draft_file):
|
||||||
|
print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
|
||||||
|
with open(model_draft_file, 'wb') as f:
|
||||||
|
f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
|
||||||
|
print(f"Done downloading draft model file")
|
||||||
|
# set default values
|
||||||
|
server.model_draft = model_draft_file
|
||||||
|
server.draft_min = 4
|
||||||
|
server.draft_max = 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
|
||||||
|
def fixture_create_server():
|
||||||
|
return create_server()
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_and_without_draft():
|
||||||
|
global server
|
||||||
|
server.model_draft = None # disable draft model
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"prompt": "I believe the meaning of life is",
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_k": 1,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
content_no_draft = res.body["content"]
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
# create new server with draft model
|
||||||
|
create_server()
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"prompt": "I believe the meaning of life is",
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_k": 1,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
content_draft = res.body["content"]
|
||||||
|
|
||||||
|
assert content_no_draft == content_draft
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_draft_min_draft_max():
|
||||||
|
global server
|
||||||
|
test_values = [
|
||||||
|
(1, 2),
|
||||||
|
(1, 4),
|
||||||
|
(4, 8),
|
||||||
|
(4, 12),
|
||||||
|
(8, 16),
|
||||||
|
]
|
||||||
|
last_content = None
|
||||||
|
for draft_min, draft_max in test_values:
|
||||||
|
server.stop()
|
||||||
|
server.draft_min = draft_min
|
||||||
|
server.draft_max = draft_max
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"prompt": "I believe the meaning of life is",
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_k": 1,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
if last_content is not None:
|
||||||
|
assert last_content == res.body["content"]
|
||||||
|
last_content = res.body["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_slots,n_requests", [
|
||||||
|
(1, 2),
|
||||||
|
(2, 2),
|
||||||
|
])
|
||||||
|
def test_multi_requests_parallel(n_slots: int, n_requests: int):
|
||||||
|
global server
|
||||||
|
server.n_slots = n_slots
|
||||||
|
server.start()
|
||||||
|
tasks = []
|
||||||
|
for _ in range(n_requests):
|
||||||
|
tasks.append((server.make_request, ("POST", "/completion", {
|
||||||
|
"prompt": "I believe the meaning of life is",
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_k": 1,
|
||||||
|
})))
|
||||||
|
results = parallel_function_calls(tasks)
|
||||||
|
for res in results:
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert match_regex("(wise|kind|owl|answer)+", res.body["content"])
|
|
@ -46,6 +46,7 @@ class ServerProcess:
|
||||||
model_alias: str | None = None
|
model_alias: str | None = None
|
||||||
model_url: str | None = None
|
model_url: str | None = None
|
||||||
model_file: str | None = None
|
model_file: str | None = None
|
||||||
|
model_draft: str | None = None
|
||||||
n_threads: int | None = None
|
n_threads: int | None = None
|
||||||
n_gpu_layer: int | None = None
|
n_gpu_layer: int | None = None
|
||||||
n_batch: int | None = None
|
n_batch: int | None = None
|
||||||
|
@ -68,6 +69,8 @@ class ServerProcess:
|
||||||
response_format: str | None = None
|
response_format: str | None = None
|
||||||
lora_files: List[str] | None = None
|
lora_files: List[str] | None = None
|
||||||
disable_ctx_shift: int | None = False
|
disable_ctx_shift: int | None = False
|
||||||
|
draft_min: int | None = None
|
||||||
|
draft_max: int | None = None
|
||||||
|
|
||||||
# session variables
|
# session variables
|
||||||
process: subprocess.Popen | None = None
|
process: subprocess.Popen | None = None
|
||||||
|
@ -102,6 +105,8 @@ class ServerProcess:
|
||||||
server_args.extend(["--model", self.model_file])
|
server_args.extend(["--model", self.model_file])
|
||||||
if self.model_url:
|
if self.model_url:
|
||||||
server_args.extend(["--model-url", self.model_url])
|
server_args.extend(["--model-url", self.model_url])
|
||||||
|
if self.model_draft:
|
||||||
|
server_args.extend(["--model-draft", self.model_draft])
|
||||||
if self.model_hf_repo:
|
if self.model_hf_repo:
|
||||||
server_args.extend(["--hf-repo", self.model_hf_repo])
|
server_args.extend(["--hf-repo", self.model_hf_repo])
|
||||||
if self.model_hf_file:
|
if self.model_hf_file:
|
||||||
|
@ -147,6 +152,10 @@ class ServerProcess:
|
||||||
server_args.extend(["--no-context-shift"])
|
server_args.extend(["--no-context-shift"])
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
server_args.extend(["--api-key", self.api_key])
|
server_args.extend(["--api-key", self.api_key])
|
||||||
|
if self.draft_max:
|
||||||
|
server_args.extend(["--draft-max", self.draft_max])
|
||||||
|
if self.draft_min:
|
||||||
|
server_args.extend(["--draft-min", self.draft_min])
|
||||||
|
|
||||||
args = [str(arg) for arg in [server_path, *server_args]]
|
args = [str(arg) for arg in [server_path, *server_args]]
|
||||||
print(f"bench: starting server with: {' '.join(args)}")
|
print(f"bench: starting server with: {' '.join(args)}")
|
||||||
|
@ -185,7 +194,8 @@ class ServerProcess:
|
||||||
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
|
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
server_instances.remove(self)
|
if self in server_instances:
|
||||||
|
server_instances.remove(self)
|
||||||
if self.process:
|
if self.process:
|
||||||
print(f"Stopping server with pid={self.process.pid}")
|
print(f"Stopping server with pid={self.process.pid}")
|
||||||
self.process.kill()
|
self.process.kill()
|
||||||
|
|
|
@ -1739,7 +1739,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||||
case GGML_OP_ROPE: {
|
case GGML_OP_ROPE: {
|
||||||
// TODO: with ops-test v == 1
|
// TODO: with ops-test v == 1
|
||||||
float * ext_factor = (float*)((int32_t*)op->op_params + 7);
|
float * ext_factor = (float*)((int32_t*)op->op_params + 7);
|
||||||
float * attn_factor = (float*)((int32_t*)op->op_params + 8);
|
|
||||||
// TODO: n_dims <= ne0
|
// TODO: n_dims <= ne0
|
||||||
if (op->src[0]->ne[0] != op->op_params[1]) {
|
if (op->src[0]->ne[0] != op->op_params[1]) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -1748,17 +1747,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||||
if (*ext_factor != 0) {
|
if (*ext_factor != 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// TODO: attn_factor != 1
|
return true;
|
||||||
if (*attn_factor != 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
//TODO: type == GGML_TYPE_F16
|
|
||||||
switch (op->src[0]->type) {
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
return true;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case GGML_OP_UPSCALE: {
|
case GGML_OP_UPSCALE: {
|
||||||
// aclnnUpsampleNearest2dGetWorkspaceSize not support
|
// aclnnUpsampleNearest2dGetWorkspaceSize not support
|
||||||
|
|
|
@ -1020,7 +1020,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
|
||||||
float * res_ptr = s;
|
float * res_ptr = s;
|
||||||
|
|
||||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||||
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
|
const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
|
||||||
|
|
||||||
float32x4_t sumf = vdupq_n_f32(0);
|
float32x4_t sumf = vdupq_n_f32(0);
|
||||||
for (int l = 0; l < nb; l++) {
|
for (int l = 0; l < nb; l++) {
|
||||||
|
@ -3507,7 +3507,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
|
||||||
for (int y = 0; y < nr / 4; y++) {
|
for (int y = 0; y < nr / 4; y++) {
|
||||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||||
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
|
const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
|
||||||
|
|
||||||
float32x4_t sumf[4];
|
float32x4_t sumf[4];
|
||||||
for (int m = 0; m < 4; m++) {
|
for (int m = 0; m < 4; m++) {
|
||||||
|
|
|
@ -1792,11 +1792,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||||
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
||||||
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
||||||
|
|
||||||
float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
float32_t _scale[4] = {
|
||||||
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
||||||
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
||||||
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
|
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
||||||
|
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
|
||||||
|
};
|
||||||
float32x4_t scale = vld1q_f32(_scale);
|
float32x4_t scale = vld1q_f32(_scale);
|
||||||
|
|
||||||
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
||||||
|
@ -1812,7 +1813,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||||
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
||||||
|
|
||||||
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
||||||
l1, r1)), l2, r2)), l3, r3))), scale);
|
l1, r1)), l2, r2)), l3, r3))), scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
||||||
|
@ -2348,10 +2349,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
||||||
const block_q8_1 * restrict b_y0 = &vy0[i];
|
const block_q8_1 * restrict b_y0 = &vy0[i];
|
||||||
const block_q8_1 * restrict b_y1 = &vy1[i];
|
const block_q8_1 * restrict b_y1 = &vy1[i];
|
||||||
|
|
||||||
float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
|
float32_t summs_t[4] = {
|
||||||
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
|
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
|
||||||
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
|
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
|
||||||
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
|
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
|
||||||
|
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
|
||||||
|
};
|
||||||
summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
|
summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
|
||||||
|
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
||||||
|
@ -2372,10 +2375,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
||||||
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
||||||
|
|
||||||
// mmla into int32x4_t
|
// mmla into int32x4_t
|
||||||
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
|
float32_t _scale[4] = {
|
||||||
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
|
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
||||||
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
|
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
||||||
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
|
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
||||||
|
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
|
||||||
|
};
|
||||||
float32x4_t scale = vld1q_f32(_scale);
|
float32x4_t scale = vld1q_f32(_scale);
|
||||||
|
|
||||||
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
||||||
|
@ -2390,15 +2395,17 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
||||||
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
||||||
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
||||||
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
||||||
l1, r1)), l2, r2)), l3, r3))), scale);
|
l1, r1)), l2, r2)), l3, r3))), scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
|
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
||||||
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
||||||
|
|
||||||
sumv2 = vaddq_f32(sumv2, summs0);
|
sumv2 = vaddq_f32(sumv2, summs0);
|
||||||
|
|
||||||
vst1_f32(s, vget_low_f32 (sumv2));
|
vst1_f32(s, vget_low_f32 (sumv2));
|
||||||
vst1_f32(s + bs, vget_high_f32(sumv2));
|
vst1_f32(s + bs, vget_high_f32(sumv2));
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -3375,10 +3382,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||||
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
||||||
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
||||||
|
|
||||||
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
float32_t _scale[4] = {
|
||||||
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
||||||
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
||||||
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
|
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
||||||
|
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
|
||||||
|
};
|
||||||
float32x4_t scale = vld1q_f32(_scale);
|
float32x4_t scale = vld1q_f32(_scale);
|
||||||
|
|
||||||
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
||||||
|
@ -3394,13 +3403,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||||
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
||||||
|
|
||||||
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
||||||
l1, r1)), l2, r2)), l3, r3))), scale);
|
l1, r1)), l2, r2)), l3, r3))), scale);
|
||||||
}
|
}
|
||||||
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
|
|
||||||
|
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
||||||
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
||||||
|
|
||||||
vst1_f32(s, vget_low_f32(sumv2));
|
vst1_f32(s, vget_low_f32 (sumv2));
|
||||||
vst1_f32(s + bs, vget_high_f32(sumv2));
|
vst1_f32(s + bs, vget_high_f32(sumv2));
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -7677,8 +7677,8 @@ UseGgmlGemm2:;
|
||||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||||
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
||||||
|
|
||||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
// these checks are needed to avoid crossing dim1 boundaries
|
||||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
// can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
|
||||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
|
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
|
||||||
num_rows_per_vec_dot = 1;
|
num_rows_per_vec_dot = 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3447,8 +3447,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||||
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||||
|
|
||||||
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||||
// KQ single-batch
|
// TODO: Refactor and cleanup of mul mat dispatching.
|
||||||
ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
|
if (src0->ne[3] == 1 && src1->ne[3] == 1) {
|
||||||
|
// KQ single-batch
|
||||||
|
// mmv p021 was specific for these dimensions
|
||||||
|
ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
|
||||||
|
} else {
|
||||||
|
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
|
||||||
|
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
||||||
|
}
|
||||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||||
// KQV single-batch
|
// KQV single-batch
|
||||||
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
|
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
|
||||||
|
@ -4486,7 +4493,7 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_
|
||||||
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
||||||
switch (op->op) {
|
switch (op->op) {
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
return op->ne[1]; // this will increse the speed of prefill in test
|
return 0;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
return op->ne[1];
|
return op->ne[1];
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
|
|
@ -5672,6 +5672,48 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
} else {
|
} else {
|
||||||
compute_ctx = ctx->compute_ctx.lock();
|
compute_ctx = ctx->compute_ctx.lock();
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
switch (node->op) {
|
||||||
|
case GGML_OP_REPEAT:
|
||||||
|
case GGML_OP_ACC:
|
||||||
|
case GGML_OP_GET_ROWS:
|
||||||
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
|
case GGML_OP_CONCAT:
|
||||||
|
case GGML_OP_UPSCALE:
|
||||||
|
case GGML_OP_SCALE:
|
||||||
|
case GGML_OP_SQR:
|
||||||
|
case GGML_OP_SIN:
|
||||||
|
case GGML_OP_COS:
|
||||||
|
case GGML_OP_CLAMP:
|
||||||
|
case GGML_OP_PAD:
|
||||||
|
case GGML_OP_CPY:
|
||||||
|
case GGML_OP_CONT:
|
||||||
|
case GGML_OP_DUP:
|
||||||
|
case GGML_OP_NORM:
|
||||||
|
case GGML_OP_GROUP_NORM:
|
||||||
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_UNARY:
|
||||||
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
|
case GGML_OP_SOFT_MAX:
|
||||||
|
case GGML_OP_ROPE:
|
||||||
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
case GGML_OP_POOL_2D:
|
||||||
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
{
|
||||||
|
// These operations all go through ggml_vk_op_f32, so short-circuit and
|
||||||
|
// do the only thing needed for the dryrun.
|
||||||
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (node->op) {
|
switch (node->op) {
|
||||||
|
@ -6401,16 +6443,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
bool first_node_in_batch = true; // true if next node will be first node in a batch
|
bool first_node_in_batch = true; // true if next node will be first node in a batch
|
||||||
int submit_node_idx = 0; // index to first node in a batch
|
int submit_node_idx = 0; // index to first node in a batch
|
||||||
|
|
||||||
// submit work every submit_count node to overlap CPU cmdbuffer generation with GPU execution
|
// Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
|
||||||
constexpr int submit_count = 100;
|
// Start with a smaller count to get work submitted right away, and increase it after each submit.
|
||||||
|
int nodes_per_submit = 20;
|
||||||
int submitted_nodes = 0;
|
int submitted_nodes = 0;
|
||||||
|
int submit_count = 0;
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
if (first_node_in_batch) {
|
if (first_node_in_batch) {
|
||||||
submit_node_idx = i;
|
submit_node_idx = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool submit = (submitted_nodes >= submit_count) || (i == last_node);
|
bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
|
||||||
|
|
||||||
|
|
||||||
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
|
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
|
||||||
|
|
||||||
|
@ -6427,6 +6470,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
if (submit) {
|
if (submit) {
|
||||||
first_node_in_batch = true;
|
first_node_in_batch = true;
|
||||||
submitted_nodes = 0;
|
submitted_nodes = 0;
|
||||||
|
switch (submit_count) {
|
||||||
|
case 0:
|
||||||
|
nodes_per_submit = 50;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
nodes_per_submit = 100;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
submit_count++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue