diff --git a/common/arg.cpp b/common/arg.cpp index 98baac4c1..d3868018e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.ctx_shift = false; } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT")); + add_opt(common_arg( + {"--context-shift"}, + string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), + [](common_params & params) { + params.ctx_shift = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT")); add_opt(common_arg( {"--chunks"}, "N", string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks), diff --git a/common/common.h b/common/common.h index dfb63461b..920de7b50 100644 --- a/common/common.h +++ b/common/common.h @@ -375,7 +375,7 @@ struct common_params { bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention bool no_perf = false; // disable performance metrics - bool ctx_shift = true; // context shift on inifinite text generation + bool ctx_shift = false; // context shift on inifinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 1485de8ce..c7b3af048 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -5,7 +5,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index be3a0052c..adb6f2786 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -7,7 +7,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() @@ -229,7 +229,7 @@ def test_nocache_long_input_prompt(): "temperature": 1.0, "cache_prompt": False, }) - assert res.status_code == 200 + assert res.status_code == 400 def test_completion_with_tokens_input(): diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 2431ac708..8f51bc301 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -11,7 +11,7 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. """.strip() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() @@ -25,6 +25,7 @@ def test_ctx_shift_enabled(): # the prompt is truncated to keep the last 109 tokens # 64 tokens are generated thanks to shifting the context when it gets full global server + server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 64, @@ -42,7 +43,6 @@ def test_ctx_shift_enabled(): ]) def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): global server - server.disable_ctx_shift = True server.n_predict = -1 server.start() res = server.make_request("POST", "/completion", data={ @@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr def test_ctx_shift_disabled_long_prompt(): global server - server.disable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 64, @@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt(): def test_ctx_shift_disabled_stream(): global server - server.disable_ctx_shift = True server.start() res = server.make_stream_request("POST", "/v1/completions", data={ "n_predict": 256, diff --git a/tools/server/tests/unit/test_embedding.py b/tools/server/tests/unit/test_embedding.py index 0feb452cc..50601b839 100644 --- a/tools/server/tests/unit/test_embedding.py +++ b/tools/server/tests/unit/test_embedding.py @@ -8,7 +8,7 @@ server = ServerPreset.bert_bge_small() EPSILON = 1e-3 -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.bert_bge_small() diff --git a/tools/server/tests/unit/test_infill.py b/tools/server/tests/unit/test_infill.py index 10554db0f..73dacdae8 100644 --- a/tools/server/tests/unit/test_infill.py +++ b/tools/server/tests/unit/test_infill.py @@ -3,7 +3,7 @@ from utils import * server = ServerPreset.tinyllama_infill() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama_infill() diff --git a/tools/server/tests/unit/test_lora.py b/tools/server/tests/unit/test_lora.py index c1aa8be70..00b2f245f 100644 --- a/tools/server/tests/unit/test_lora.py +++ b/tools/server/tests/unit/test_lora.py @@ -5,7 +5,7 @@ server = ServerPreset.stories15m_moe() LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf" -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.stories15m_moe() diff --git a/tools/server/tests/unit/test_rerank.py b/tools/server/tests/unit/test_rerank.py index f4f570ad5..0b63c7821 100644 --- a/tools/server/tests/unit/test_rerank.py +++ b/tools/server/tests/unit/test_rerank.py @@ -4,7 +4,7 @@ from utils import * server = ServerPreset.jina_reranker_tiny() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.jina_reranker_tiny() diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 620b25376..0e1158055 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -6,7 +6,7 @@ server = ServerPreset.tinyllama2() TEST_API_KEY = "sk-this-is-the-secret-key" -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_slot_save.py b/tools/server/tests/unit/test_slot_save.py index 38704f5ec..1b428cc2a 100644 --- a/tools/server/tests/unit/test_slot_save.py +++ b/tools/server/tests/unit/test_slot_save.py @@ -3,7 +3,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py index 54db38cf3..38ca4325b 100644 --- a/tools/server/tests/unit/test_speculative.py +++ b/tools/server/tests/unit/test_speculative.py @@ -16,7 +16,7 @@ def create_server(): server.draft_max = 8 -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def fixture_create_server(): return create_server() @@ -91,6 +91,7 @@ def test_slot_ctx_not_exceeded(): def test_with_ctx_shift(): global server server.n_ctx = 64 + server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "prompt": "Hello " * 56, diff --git a/tools/server/tests/unit/test_tokenize.py b/tools/server/tests/unit/test_tokenize.py index 382457c9d..424cac5f3 100644 --- a/tools/server/tests/unit/test_tokenize.py +++ b/tools/server/tests/unit/test_tokenize.py @@ -4,7 +4,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index 20f048c6f..a3c3ccdf5 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -22,6 +22,8 @@ def create_server(): server.model_alias = "tinyllama-2-tool-call" server.server_port = 8081 server.n_slots = 1 + server.n_ctx = 8192 + server.n_batch = 2048 class CompletionMode(Enum): NORMAL = "normal" diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index bc547ca03..49277e600 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -79,7 +79,7 @@ class ServerProcess: draft: int | None = None api_key: str | None = None lora_files: List[str] | None = None - disable_ctx_shift: int | None = False + enable_ctx_shift: int | None = False draft_min: int | None = None draft_max: int | None = None no_webui: bool | None = None @@ -178,8 +178,8 @@ class ServerProcess: if self.lora_files: for lora_file in self.lora_files: server_args.extend(["--lora", lora_file]) - if self.disable_ctx_shift: - server_args.extend(["--no-context-shift"]) + if self.enable_ctx_shift: + server_args.append("--context-shift") if self.api_key: server_args.extend(["--api-key", self.api_key]) if self.draft_max: diff --git a/tools/tts/tts.cpp b/tools/tts/tts.cpp index a71e9bf5b..18f01a994 100644 --- a/tools/tts/tts.cpp +++ b/tools/tts/tts.cpp @@ -581,7 +581,6 @@ int main(int argc, char ** argv) { params.model = params.vocoder.model; params.embedding = true; - params.ctx_shift = false; // silence warning params.n_ubatch = params.n_batch; common_init_result llama_init_cts = common_init_from_params(params);