Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 09:04:36 +00:00)
server : disable context shift by default (#15416)
* server : disable context shift by default

ggml-ci

* server : make scope of test parameters local
Parent: a6d3cfe7fa
Commit: d2fcd91cf9
16 changed files with 27 additions and 20 deletions
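The practical effect of this commit: the server no longer shifts the context window when it fills up unless shifting is explicitly requested via the new --context-shift flag (or the LLAMA_ARG_CONTEXT_SHIFT environment variable registered in the first hunk below). A minimal sketch of opting back in; the binary name and model path are placeholders, not taken from the diff.

# Sketch only: launch the server with the opt-in flag added by this commit.
import subprocess
import time

proc = subprocess.Popen([
    "./llama-server",            # placeholder binary built from this tree
    "-m", "models/model.gguf",   # placeholder model file
    "--context-shift",           # re-enable shifting; it is now off by default
])
try:
    time.sleep(30)               # stand-in for real requests against the server
finally:
    proc.terminate()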
@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ctx_shift = false;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    add_opt(common_arg(
+        {"--context-shift"},
+        string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
+        [](common_params & params) {
+            params.ctx_shift = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -375,7 +375,7 @@ struct common_params {
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
     bool no_perf = false; // disable performance metrics
-    bool ctx_shift = true; // context shift on inifinite text generation
+    bool ctx_shift = false; // context shift on inifinite text generation
     bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified = false; // enable unified KV cache
@@ -5,7 +5,7 @@ from utils import *
 server = ServerPreset.tinyllama2()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -7,7 +7,7 @@ from utils import *
 server = ServerPreset.tinyllama2()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -229,7 +229,7 @@ def test_nocache_long_input_prompt():
         "temperature": 1.0,
         "cache_prompt": False,
     })
-    assert res.status_code == 200
+    assert res.status_code == 400


 def test_completion_with_tokens_input():
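A plausible reading of the assertion change above (an assumption, not stated in the diff): with context shift now off by default, a prompt that does not fit in the context window can no longer be shifted into place, so the server rejects the request instead of accepting it with truncation. A client-side sketch using the test helpers that appear elsewhere in this diff; the prompt length is arbitrary, just "longer than the tiny test context".

from utils import ServerPreset

server = ServerPreset.tinyllama2()
server.start()
res = server.make_request("POST", "/completion", data={
    "prompt": "Hello " * 1000,   # far beyond the test model's context window
    "cache_prompt": False,
})
assert res.status_code == 400    # rejected rather than shifted/truncated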
@@ -11,7 +11,7 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu
 Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
 """.strip()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -25,6 +25,7 @@ def test_ctx_shift_enabled():
     # the prompt is truncated to keep the last 109 tokens
     # 64 tokens are generated thanks to shifting the context when it gets full
     global server
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -42,7 +43,6 @@ def test_ctx_shift_enabled():
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server
-    server.disable_ctx_shift = True
     server.n_predict = -1
     server.start()
     res = server.make_request("POST", "/completion", data={
@@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr

 def test_ctx_shift_disabled_long_prompt():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt():

 def test_ctx_shift_disabled_stream():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_stream_request("POST", "/v1/completions", data={
         "n_predict": 256,
@@ -8,7 +8,7 @@ server = ServerPreset.bert_bge_small()

 EPSILON = 1e-3

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.bert_bge_small()
@@ -3,7 +3,7 @@ from utils import *

 server = ServerPreset.tinyllama_infill()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama_infill()
@@ -5,7 +5,7 @@ server = ServerPreset.stories15m_moe()

 LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
@@ -4,7 +4,7 @@ from utils import *
 server = ServerPreset.jina_reranker_tiny()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.jina_reranker_tiny()
@@ -6,7 +6,7 @@ server = ServerPreset.tinyllama2()

 TEST_API_KEY = "sk-this-is-the-secret-key"

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -3,7 +3,7 @@ from utils import *

 server = ServerPreset.tinyllama2()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -16,7 +16,7 @@ def create_server():
     server.draft_max = 8


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def fixture_create_server():
     return create_server()

@@ -91,6 +91,7 @@ def test_slot_ctx_not_exceeded():
 def test_with_ctx_shift():
     global server
     server.n_ctx = 64
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "Hello " * 56,
@@ -4,7 +4,7 @@ from utils import *
 server = ServerPreset.tinyllama2()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -22,6 +22,8 @@ def create_server():
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
     server.n_slots = 1
+    server.n_ctx = 8192
+    server.n_batch = 2048

 class CompletionMode(Enum):
     NORMAL = "normal"
@@ -79,7 +79,7 @@ class ServerProcess:
     draft: int | None = None
     api_key: str | None = None
     lora_files: List[str] | None = None
-    disable_ctx_shift: int | None = False
+    enable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
@@ -178,8 +178,8 @@ class ServerProcess:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
-        if self.disable_ctx_shift:
-            server_args.extend(["--no-context-shift"])
+        if self.enable_ctx_shift:
+            server_args.append("--context-shift")
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
         if self.draft_max:
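For test authors, the ServerProcess change above inverts the opt-in direction: the old disable_ctx_shift field is gone and shifting is off unless enable_ctx_shift is set, which start() maps to the new --context-shift flag. A short sketch of the new pattern, using only names that appear in this diff; treat it as illustrative rather than a verbatim test from the repo.

# Previously a test opted OUT of shifting:
#     server.disable_ctx_shift = True      # old field, removed by this commit
# Now a test opts IN:
from utils import ServerPreset

server = ServerPreset.tinyllama2()
server.enable_ctx_shift = True             # start() appends --context-shift
server.start()
res = server.make_request("POST", "/completion", data={
    "prompt": "Hello " * 56,
    "n_predict": 64,
})
assert res.status_code == 200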
@@ -581,7 +581,6 @@ int main(int argc, char ** argv) {

     params.model = params.vocoder.model;
     params.embedding = true;
-    params.ctx_shift = false; // silence warning
     params.n_ubatch = params.n_batch;

     common_init_result llama_init_cts = common_init_from_params(params);