server : disable context shift by default (#15416)

* server : disable context shift by default

ggml-ci

* server : make scope of test parameters local
Georgi Gerganov 2025-08-19 16:46:37 +03:00 committed by GitHub
parent a6d3cfe7fa
commit d2fcd91cf9
16 changed files with 27 additions and 20 deletions

@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ctx_shift = false;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    add_opt(common_arg(
+        {"--context-shift"},
+        string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.ctx_shift = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),

@@ -375,7 +375,7 @@ struct common_params {
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
     bool no_perf = false; // disable performance metrics
-    bool ctx_shift = true; // context shift on infinite text generation
+    bool ctx_shift = false; // context shift on infinite text generation
     bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified = false; // enable unified KV cache

@@ -5,7 +5,7 @@ from utils import *
 server = ServerPreset.tinyllama2()
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

@@ -7,7 +7,7 @@ from utils import *
 server = ServerPreset.tinyllama2()
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -229,7 +229,7 @@ def test_nocache_long_input_prompt():
         "temperature": 1.0,
         "cache_prompt": False,
     })
-    assert res.status_code == 200
+    assert res.status_code == 400
 def test_completion_with_tokens_input():

@@ -11,7 +11,7 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu
 Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
 """.strip()
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -25,6 +25,7 @@ def test_ctx_shift_enabled():
     # the prompt is truncated to keep the last 109 tokens
     # 64 tokens are generated thanks to shifting the context when it gets full
     global server
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -42,7 +43,6 @@ def test_ctx_shift_enabled():
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server
-    server.disable_ctx_shift = True
     server.n_predict = -1
     server.start()
     res = server.make_request("POST", "/completion", data={
@@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
 def test_ctx_shift_disabled_long_prompt():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt():
 def test_ctx_shift_disabled_stream():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_stream_request("POST", "/v1/completions", data={
         "n_predict": 256,

@@ -8,7 +8,7 @@ server = ServerPreset.bert_bge_small()
 EPSILON = 1e-3
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.bert_bge_small()

@@ -3,7 +3,7 @@ from utils import *
 server = ServerPreset.tinyllama_infill()
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama_infill()

@@ -5,7 +5,7 @@ server = ServerPreset.stories15m_moe()
 LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()

@@ -4,7 +4,7 @@ from utils import *
 server = ServerPreset.jina_reranker_tiny()
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.jina_reranker_tiny()

@@ -6,7 +6,7 @@ server = ServerPreset.tinyllama2()
 TEST_API_KEY = "sk-this-is-the-secret-key"
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

@@ -3,7 +3,7 @@ from utils import *
 server = ServerPreset.tinyllama2()
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

@@ -16,7 +16,7 @@ def create_server():
     server.draft_max = 8
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def fixture_create_server():
     return create_server()
@@ -91,6 +91,7 @@ def test_slot_ctx_not_exceeded():
 def test_with_ctx_shift():
     global server
     server.n_ctx = 64
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "Hello " * 56,

@@ -4,7 +4,7 @@ from utils import *
 server = ServerPreset.tinyllama2()
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()

@@ -22,6 +22,8 @@ def create_server():
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
     server.n_slots = 1
+    server.n_ctx = 8192
+    server.n_batch = 2048
 class CompletionMode(Enum):
     NORMAL = "normal"

@@ -79,7 +79,7 @@ class ServerProcess:
     draft: int | None = None
     api_key: str | None = None
     lora_files: List[str] | None = None
-    disable_ctx_shift: int | None = False
+    enable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
@@ -178,8 +178,8 @@ class ServerProcess:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
-        if self.disable_ctx_shift:
-            server_args.extend(["--no-context-shift"])
+        if self.enable_ctx_shift:
+            server_args.append("--context-shift")
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
         if self.draft_max:
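For test authors the renamed field flips the polarity of the harness: leaving enable_ctx_shift at its default False simply omits the flag, and only an explicit opt-in adds --context-shift. A sketch of that mapping in isolation, using a hypothetical build_ctx_shift_args helper rather than the real ServerProcess.start():

    def build_ctx_shift_args(enable_ctx_shift: bool) -> list[str]:
        # Hypothetical helper mirroring the argument-building branch above:
        # the harness now only emits the positive flag, since disabled is the
        # server default after this change.
        server_args: list[str] = []
        if enable_ctx_shift:
            server_args.append("--context-shift")
        return server_args

    assert build_ctx_shift_args(False) == []
    assert build_ctx_shift_args(True) == ["--context-shift"]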

@@ -581,7 +581,6 @@ int main(int argc, char ** argv) {
     params.model = params.vocoder.model;
     params.embedding = true;
-    params.ctx_shift = false; // silence warning
     params.n_ubatch = params.n_batch;
     common_init_result llama_init_cts = common_init_from_params(params);