diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 11ff38762..abc4fa1c8 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -34,6 +34,7 @@ else() add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) + add_subdirectory(model-conversion) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/model-conversion/.gitignore b/examples/model-conversion/.gitignore new file mode 100644 index 000000000..451227547 --- /dev/null +++ b/examples/model-conversion/.gitignore @@ -0,0 +1,3 @@ +.model_name +data +ppl diff --git a/examples/model-conversion/CMakeLists.txt b/examples/model-conversion/CMakeLists.txt new file mode 100644 index 000000000..fc1746ce4 --- /dev/null +++ b/examples/model-conversion/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-logits) +add_executable(${TARGET} logits.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile new file mode 100644 index 000000000..27d95b4f2 --- /dev/null +++ b/examples/model-conversion/Makefile @@ -0,0 +1,163 @@ +# Validation functions: abort the calling target (passed as $(1)) with a usage hint when the required model path variable is empty +define validate_model_path + @if [ -z "$(MODEL_PATH)" ]; then \ + echo "Error: MODEL_PATH must be provided either as:"; \ + echo " 1. Environment variable: export MODEL_PATH=/path/to/model"; \ + echo " 2. Command line argument: make $(1) MODEL_PATH=/path/to/model"; \ + exit 1; \ + fi +endef + +define validate_embedding_model_path + @if [ -z "$(EMBEDDING_MODEL_PATH)" ]; then \ + echo "Error: EMBEDDING_MODEL_PATH must be provided either as:"; \ + echo " 1. Environment variable: export EMBEDDING_MODEL_PATH=/path/to/model"; \ + echo " 2. 
Command line argument: make $(1) EMBEDDING_MODEL_PATH=/path/to/model"; \ + exit 1; \ + fi +endef + +### +### Causal Model targets/recipes +### +causal-convert-model-bf16: OUTTYPE=bf16 +causal-convert-model-bf16: causal-convert-model + +causal-convert-model: + $(call validate_model_path,causal-convert-model) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/causal/convert-model.sh + +causal-run-original-model: + $(call validate_model_path,causal-run-original-model) + @MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py + +causal-run-converted-model: + @CONVERTED_MODEL="$(CONVERTED_MODEL)" ./scripts/causal/run-converted-model.sh + +causal-verify-logits: causal-run-original-model causal-run-converted-model + @./scripts/causal/compare-logits.py + @MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH} + +causal-run-original-embeddings: + @./scripts/causal/run-casual-gen-embeddings-org.sh + +causal-run-converted-embeddings: + @./scripts/causal/run-converted-model-embeddings-logits.sh + +causal-verify-embeddings: causal-run-original-embeddings causal-run-converted-embeddings + @./scripts/causal/compare-embeddings-logits.sh + +causal-inspect-original-model: + @./scripts/utils/inspect-org-model.py + +causal-inspect-converted-model: + @./scripts/utils/inspect-converted-model.sh + +causal-start-embedding-server: + @./scripts/utils/run-embedding-server.sh ${CONVERTED_MODEL} + +causal-curl-embedding-endpoint: causal-run-original-embeddings + @./scripts/utils/curl-embedding-server.sh | ./scripts/causal/compare-embeddings-logits.sh + +causal-quantize-Q8_0: QUANTIZED_TYPE = Q8_0 +causal-quantize-Q8_0: causal-quantize-model + +causal-quantize-Q4_0: QUANTIZED_TYPE = Q4_0 +causal-quantize-Q4_0: causal-quantize-model + +causal-quantize-model: + @CONVERTED_MODEL="$(CONVERTED_MODEL)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" ./scripts/utils/quantize.sh ${CONVERTED_MODEL} 
${QUANTIZED_TYPE} + @echo "Export the quantized model path to QUANTIZED_MODEL variable in your environment" + +causal-run-quantized-model: + @QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/causal/run-converted-model.sh ${QUANTIZED_MODEL} + + +### +### Embedding Model targets/recipes +### + +embedding-convert-model-bf16: OUTTYPE=bf16 +embedding-convert-model-bf16: embedding-convert-model + +embedding-convert-model: + $(call validate_embedding_model_path,embedding-convert-model) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/embedding/convert-model.sh + +embedding-run-original-model: + $(call validate_embedding_model_path,embedding-run-original-model) + @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/embedding/run-original-model.py + +embedding-run-converted-model: + @CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/embedding/run-converted-model.sh ${CONVERTED_EMBEDDING_MODEL} + +embedding-verify-logits: embedding-run-original-model embedding-run-converted-model + @./scripts/embedding/compare-embeddings-logits.sh + +embedding-inspect-original-model: + $(call validate_embedding_model_path,embedding-inspect-original-model) + @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} + +embedding-inspect-converted-model: + @CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/utils/inspect-converted-model.sh ${CONVERTED_EMBEDDING_MODEL} + +embedding-start-embedding-server: + @./scripts/utils/run-embedding-server.sh ${CONVERTED_EMBEDDING_MODEL} + +embedding-curl-embedding-endpoint: + @./scripts/utils/curl-embedding-server.sh | ./scripts/embedding/compare-embeddings-logits.sh + +embedding-quantize-Q8_0: QUANTIZED_TYPE = Q8_0 +embedding-quantize-Q8_0: embedding-quantize-model + +embedding-quantize-Q4_0: QUANTIZED_TYPE = Q4_0 +embedding-quantize-Q4_0: embedding-quantize-model + 
+embedding-quantize-model: + @./scripts/utils/quantize.sh ${CONVERTED_EMBEDDING_MODEL} ${QUANTIZED_TYPE} + @echo "Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment" + +embedding-run-quantized-model: + @./scripts/embedding/run-converted-model.sh ${QUANTIZED_EMBEDDING_MODEL} + +### +### Perplexity targets/recipes +### +perplexity-data-gen: + CONVERTED_MODEL="$(CONVERTED_MODEL)" ./scripts/utils/perplexity-gen.sh + +perplexity-run-full: + QUANTIZED_MODEL="$(QUANTIZED_MODEL)" LOGITS_FILE="$(LOGITS_FILE)" \ + ./scripts/utils/perplexity-run.sh + +perplexity-run: + QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/utils/perplexity-run-simple.sh + +### +### HuggingFace targets/recipes +### + +hf-create-model: + @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" + +hf-create-model-private: + @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p + +hf-upload-gguf-to-model: + @./scripts/utils/hf-upload-gguf-model.py -m "${MODEL_PATH}" -r "${REPO_ID}" -o "${NAME_IN_REPO}" + +hf-create-collection: + @./scripts/utils/hf-create-collection.py -n "${NAME}" -d "${DESCRIPTION}" -ns "${NAMESPACE}" + +hf-add-model-to-collection: + @./scripts/utils/hf-add-model-to-collection.py -c "${COLLECTION}" -m "${MODEL}" + + +.PHONY: clean +clean: + @${RM} -rf data .converted_embedding_model.txt .converted_model.txt .embedding_model_name.txt .model_name.txt + diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md new file mode 100644 index 000000000..c3c5001ea --- /dev/null +++ b/examples/model-conversion/README.md @@ -0,0 +1,335 @@ +# Model Conversion Example +This directory contains scripts and code to help in the process of converting +HuggingFace PyTorch models to GGUF format. 
+ +The motivation for having this is that the conversion process can often be an +iterative process, where the original model is inspected, converted, updates +made to llama.cpp, converted again, etc. Once the model has been converted it +needs to be verified against the original model, and then optionally quantized, +and in some cases the perplexity of the quantized model is checked. And finally the +model/models need to be uploaded to ggml-org on Hugging Face. This tool/example tries to +help with this process. + +### Overview +The idea is that the makefile targets and scripts here can be used in the +development/conversion process assisting with things like: + +* inspect/run the original model to figure out how it works +* convert the original model to GGUF format +* inspect/run the converted model +* verify the logits produced by the original model and the converted model +* quantize the model to GGUF format +* run perplexity evaluation to verify that the quantized model is performing + as expected +* upload the model to HuggingFace to make it available for others + +## Setup +Create a Python virtual environment +```console +$ python3.11 -m venv venv +$ source venv/bin/activate +(venv) $ pip install -r requirements.txt +``` + +## Causal Language Model Conversion +This section describes the steps to convert a causal language model to GGUF and +to verify that the conversion was successful. 
+ +### Download the original model +First, clone the original model to some local directory: +```console +$ mkdir models && cd models +$ git clone https://huggingface.co/user/model_name +$ cd model_name +$ git lfs install +$ git lfs pull +``` + +### Set the MODEL_PATH +The path to the downloaded model can be provided in two ways: + +**Option 1: Environment variable (recommended for iterative development)** +```console +export MODEL_PATH=~/work/ai/models/some_model +``` + +**Option 2: Command line argument (for one-off tasks)** +```console +make causal-convert-model MODEL_PATH=~/work/ai/models/some_model +``` + +Command line arguments take precedence over environment variables when both are provided. + +In cases where the transformer implementation for the model has not been released +yet it is possible to set the environment variable `UNRELEASED_MODEL_NAME` which +will then cause the transformer implementation to be loaded explicitly and not +use AutoModelForCausalLM: +``` +export UNRELEASED_MODEL_NAME=SomeNewModel +``` + +### Inspecting the original tensors +```console +# Using environment variable +(venv) $ make causal-inspect-original-model + +# Or using command line argument +(venv) $ make causal-inspect-original-model MODEL_PATH=~/work/ai/models/some_model +``` + +### Running the original model +This is mainly to verify that the original model works, and to compare the output +from the converted model. +```console +# Using environment variable +(venv) $ make causal-run-original-model + +# Or using command line argument +(venv) $ make causal-run-original-model MODEL_PATH=~/work/ai/models/some_model +``` +This command will save two files to the `data` directory, one is a binary file +containing logits which will be used for comparison with the converted model +later, and the other is a text file which allows for manual visual inspection. 
+ +### Model conversion +After updates have been made to [gguf-py](../../gguf-py) to add support for the +new model, the model can be converted to GGUF format using the following command: +```console +# Using environment variable +(venv) $ make causal-convert-model + +# Or using command line argument +(venv) $ make causal-convert-model MODEL_PATH=~/work/ai/models/some_model +``` + +### Inspecting the converted model +The converted model can be inspected using the following command: +```console +(venv) $ make causal-inspect-converted-model +``` + +### Running the converted model +```console +(venv) $ make causal-run-converted-model +``` + +### Model logits verification +The following target will run the original model and the converted model and +compare the logits: +```console +(venv) $ make causal-verify-logits +``` + +### Quantizing the model +The causal model can be quantized to GGUF format using the following command: +```console +(venv) $ make causal-quantize-Q8_0 +Quantized model saved to: /path/to/quantized/model-Q8_0.gguf +Export the quantized model path to QUANTIZED_MODEL variable in your environment +``` +This will show the path to the quantized model in the terminal, which can then +be used to set the `QUANTIZED_MODEL` environment variable: +```console +export QUANTIZED_MODEL=/path/to/quantized/model-Q8_0.gguf +``` +Then the quantized model can be run using the following command: +```console +(venv) $ make causal-run-quantized-model +``` + + +## Embedding Language Model Conversion + +### Download the original model +```console +$ mkdir models && cd models +$ git clone https://huggingface.co/user/model_name +$ cd model_name +$ git lfs install +$ git lfs pull +``` + +The path to the embedding model can be provided in two ways: + +**Option 1: Environment variable (recommended for iterative development)** +```console +export EMBEDDING_MODEL_PATH=~/path/to/embedding_model +``` + +**Option 2: Command line argument (for one-off tasks)** +```console +make embedding-convert-model 
EMBEDDING_MODEL_PATH=~/path/to/embedding_model +``` + +Command line arguments take precedence over environment variables when both are provided. + +### Running the original model +This is mainly to verify that the original model works and to compare the output +with the output from the converted model. +```console +# Using environment variable +(venv) $ make embedding-run-original-model + +# Or using command line argument +(venv) $ make embedding-run-original-model EMBEDDING_MODEL_PATH=~/path/to/embedding_model +``` +This command will save two files to the `data` directory, one is a binary +file containing logits which will be used for comparison with the converted +model, and the other is a text file which allows for manual visual inspection. + +### Model conversion +After updates have been made to [gguf-py](../../gguf-py) to add support for the +new model, the model can be converted to GGUF format using the following command: +```console +(venv) $ make embedding-convert-model +``` + +### Run the converted model +```console +(venv) $ make embedding-run-converted-model +``` + +### Model logits verification +The following target will run the original model and the converted model (which +was done manually in the previous steps) and compare the logits: +```console +(venv) $ make embedding-verify-logits +``` + +### llama-server verification +To verify that the converted model works with llama-server, the following +command can be used: +```console +(venv) $ make embedding-start-embedding-server +``` +Then open another terminal and set the `EMBEDDING_MODEL_PATH` environment +variable as this will not be inherited by the new terminal: +```console +(venv) $ make embedding-curl-embedding-endpoint +``` +This will call the `embedding` endpoint and the output will be piped into +the same verification script as used by the target `embedding-verify-logits`. 
+ +The causal model can also be used to produce embeddings and this can be verified +using the following commands: +```console +(venv) $ make causal-start-embedding-server +``` +Then open another terminal and set the `MODEL_PATH` environment +variable as this will not be inherited by the new terminal: +```console +(venv) $ make causal-curl-embedding-endpoint +``` + +### Quantizing the model +The embedding model can be quantized to GGUF format using the following command: +```console +(venv) $ make embedding-quantize-Q8_0 +Quantized model saved to: /path/to/quantized/model-Q8_0.gguf +Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment +``` +This will show the path to the quantized model in the terminal, which can then +be used to set the `QUANTIZED_EMBEDDING_MODEL` environment variable: +```console +export QUANTIZED_EMBEDDING_MODEL=/path/to/quantized/model-Q8_0.gguf +``` +Then the quantized model can be run using the following command: +```console +(venv) $ make embedding-run-quantized-model +``` + +## Perplexity Evaluation + +### Simple perplexity evaluation +This allows running the perplexity evaluation without having to generate a +token/logits file: +```console +(venv) $ make perplexity-run QUANTIZED_MODEL=~/path/to/quantized/model.gguf +``` +This will use the wikitext dataset to run the perplexity evaluation and +output the perplexity score to the terminal. This value can then be compared +with the perplexity score of the unquantized model. + +### Full perplexity evaluation +First use the converted, non-quantized, model to generate the perplexity evaluation +dataset using the following command: +```console +$ make perplexity-data-gen CONVERTED_MODEL=~/path/to/converted/model.gguf +``` +This will generate a file in the `data` directory named after the model and with +a `.kld` suffix which contains the tokens and the logits for the wikitext dataset. 
+ +After the dataset has been generated, the perplexity evaluation can be run using +the quantized model: +```console +$ make perplexity-run-full QUANTIZED_MODEL=~/path/to/quantized/model-Qxx.gguf LOGITS_FILE=data/model.gguf.ppl +``` + +> 📝 **Note:** The `LOGITS_FILE` is the file generated by the previous command and it +> can be very large, so make sure you have enough disk space available. + +## HuggingFace utilities +The following targets are useful for creating collections and model repositories +on Hugging Face in the ggml-org organization. These can be used when preparing a release +to script the process for new model releases. + +For the following targets a `HF_TOKEN` environment variable is required. + +> 📝 **Note:** Don't forget to logout from Hugging Face after running these +> commands, otherwise you might have issues pulling/cloning repositories as +> the token will still be in use: +> $ huggingface-cli logout +> $ unset HF_TOKEN + +### Create a new Hugging Face Model (model repository) +This will create a new model repository on Hugging Face with the specified +model name. +```console +(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" +Repository ID: danbev/TestModel-GGUF +Repository created: https://huggingface.co/danbev/TestModel-GGUF +``` +Note that we append a `-GGUF` suffix to the model name to ensure a consistent +naming convention for GGUF models. + +### Upload a GGUF model to model repository +The following target uploads a model to an existing Hugging Face model repository. +```console +(venv) $ make hf-upload-gguf-to-model MODEL_PATH=dummy-model1.gguf REPO_ID=danbev/TestModel-GGUF +📤 Uploading dummy-model1.gguf to danbev/TestModel-GGUF/dummy-model1.gguf +✅ Upload successful! +🔗 File available at: https://huggingface.co/danbev/TestModel-GGUF/blob/main/dummy-model1.gguf +``` +This command can also be used to update an existing model file in a repository. 
+ +### Create a new Collection +```console +(venv) $ make hf-create-collection NAME=TestCollection DESCRIPTION="Collection for testing scripts" NAMESPACE=danbev +🚀 Creating Hugging Face Collection +Title: TestCollection +Description: Collection for testing scripts +Namespace: danbev +Private: False +✅ Authenticated as: danbev +📚 Creating collection: 'TestCollection'... +✅ Collection created successfully! +📋 Collection slug: danbev/testcollection-68930fcf73eb3fc200b9956d +🔗 Collection URL: https://huggingface.co/collections/danbev/testcollection-68930fcf73eb3fc200b9956d + +🎉 Collection created successfully! +Use this slug to add models: danbev/testcollection-68930fcf73eb3fc200b9956d +``` + +### Add model to a Collection +```console +(venv) $ make hf-add-model-to-collection COLLECTION=danbev/testcollection-68930fcf73eb3fc200b9956d MODEL=danbev/TestModel-GGUF +✅ Authenticated as: danbev +🔍 Checking if model exists: danbev/TestModel-GGUF +✅ Model found: danbev/TestModel-GGUF +📚 Adding model to collection... +✅ Model added to collection successfully! +🔗 Collection URL: https://huggingface.co/collections/danbev/testcollection-68930fcf73eb3fc200b9956d + +🎉 Model added successfully! 
+ +``` diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp new file mode 100644 index 000000000..2cac6a3b3 --- /dev/null +++ b/examples/model-conversion/logits.cpp @@ -0,0 +1,209 @@ +#include "llama.h" +#include +#include +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]); + printf("\n"); +} + +int main(int argc, char ** argv) { + std::string model_path; + std::string prompt = "Hello, my name is"; + int ngl = 0; + bool embedding_mode = false; + + { + int i = 1; + for (; i < argc; i++) { + if (strcmp(argv[i], "-m") == 0) { + if (i + 1 < argc) { + model_path = argv[++i]; + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-ngl") == 0) { + if (i + 1 < argc) { + try { + ngl = std::stoi(argv[++i]); + } catch (...) { + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-embd-mode") == 0) { + if (i + 1 < argc) { + try { + embedding_mode = true; + } catch (...) { + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } else { + // prompt starts here + break; + } + } + + if (model_path.empty()) { + print_usage(argc, argv); + return 1; + } + + if (i < argc) { + prompt = argv[i++]; + for (; i < argc; i++) { + prompt += " "; + prompt += argv[i]; + } + } + } + + ggml_backend_load_all(); + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = ngl; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + // Extract basename from model_path + const char * basename = strrchr(model_path.c_str(), '/'); + basename = (basename == NULL) ? 
model_path.c_str() : basename + 1; + + char model_name[256]; + strncpy(model_name, basename, 255); + model_name[255] = '\0'; + + char * dot = strrchr(model_name, '.'); + if (dot != NULL && strcmp(dot, ".gguf") == 0) { + *dot = '\0'; + } + printf("Model name: %s\n", model_name); + + const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); + + std::vector prompt_tokens(n_prompt); + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { + fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); + return 1; + } + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_prompt; + ctx_params.n_batch = n_prompt; + ctx_params.no_perf = false; + if (embedding_mode) { + ctx_params.embeddings = true; + ctx_params.n_ubatch = ctx_params.n_batch; + } + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + printf("Input prompt: \"%s\"\n", prompt.c_str()); + printf("Tokenized prompt (%d tokens): ", n_prompt); + for (auto id : prompt_tokens) { + char buf[128]; + int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); + if (n < 0) { + fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); + return 1; + } + std::string s(buf, n); + printf("%s", s.c_str()); + } + printf("\n"); + + llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + + float * logits; + int n_logits; + const char * type; + + if (embedding_mode) { + logits = llama_get_embeddings(ctx); + n_logits = llama_model_n_embd(model) * batch.n_tokens; + type = "-embeddings"; + printf("Embeddings size: %d\n", 
n_logits); + } else { + logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + n_logits = llama_vocab_n_tokens(vocab); + type = ""; + printf("Vocab size: %d\n", n_logits); + } + + std::filesystem::create_directory("data"); + + // Save logits to binary file + char bin_filename[512]; + snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type); + printf("Saving logits to %s\n", bin_filename); + + FILE * f = fopen(bin_filename, "wb"); + if (f == NULL) { + fprintf(stderr, "%s: error: failed to open binary output file\n", __func__); + return 1; + } + fwrite(logits, sizeof(float), n_logits, f); + fclose(f); + + // Also save as text for debugging + char txt_filename[512]; + snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type); + f = fopen(txt_filename, "w"); + if (f == NULL) { + fprintf(stderr, "%s: error: failed to open text output file\n", __func__); + return 1; + } + for (int i = 0; i < n_logits; i++) { + fprintf(f, "%d: %.6f\n", i, logits[i]); // Added index and changed format + } + fclose(f); + + // Print first and last 10 logits for quick verification + printf("First 10 logits: "); + for (int i = 0; i < 10 && i < n_logits; i++) { + printf("%.6f ", logits[i]); + } + printf("\n"); + + printf("Last 10 logits: "); + for (int i = n_logits - 10; i < n_logits; i++) { + if (i >= 0) printf("%.6f ", logits[i]); + } + printf("\n\n"); + + printf("Logits saved to %s\n", bin_filename); + printf("Logits saved to %s\n", txt_filename); + + llama_free(ctx); + llama_model_free(model); + + return 0; +} diff --git a/examples/model-conversion/requirements.txt b/examples/model-conversion/requirements.txt new file mode 100644 index 000000000..e1aa259e9 --- /dev/null +++ b/examples/model-conversion/requirements.txt @@ -0,0 +1,4 @@ +torch~=2.6.0 +torchvision~=0.21.0 +transformers~=4.55.0 +huggingface-hub~=0.34.0 diff --git a/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh 
b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh new file mode 100755 index 000000000..287158f63 --- /dev/null +++ b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +set -e + +MODEL_PATH="${1:-"$MODEL_PATH"}" +MODEL_NAME="${2:-$(basename "$MODEL_PATH")}" + +if [ -t 0 ]; then + CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin" +else + # Process piped JSON data and convert to binary (matching logits.cpp format) + TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX) + trap 'rm -f "$TEMP_FILE"' EXIT + python3 -c " +import json +import sys +import struct + +data = json.load(sys.stdin) + +# Flatten all embeddings completely +flattened = [] +for item in data: + embedding = item['embedding'] + for token_embedding in embedding: + flattened.extend(token_embedding) + +print(f'Total embedding values: {len(flattened)}', file=sys.stderr) + +# Write as binary floats - matches logits.cpp fwrite format +with open('$TEMP_FILE', 'wb') as f: + for value in flattened: + f.write(struct.pack('f', value)) +" + CPP_EMBEDDINGS="$TEMP_FILE" +fi + +python scripts/utils/semantic_check.py --model-path "$MODEL_PATH" \ + --python-embeddings "data/pytorch-${MODEL_NAME}-embeddings.bin" \ + --cpp-embeddings "$CPP_EMBEDDINGS" \ + --prompt "Hello world today" \ + --causal + diff --git a/examples/model-conversion/scripts/causal/compare-logits.py b/examples/model-conversion/scripts/causal/compare-logits.py new file mode 100755 index 000000000..fb959f0d5 --- /dev/null +++ b/examples/model-conversion/scripts/causal/compare-logits.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 + +import numpy as np +import sys +import os +from pathlib import Path + +def quick_logits_check(pytorch_file, llamacpp_file): + """Lightweight sanity check before NMSE""" + + try: + pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32) + llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32) + except Exception as e: + print(f"❌ NOK: Failed to load files 
- {e}") + return False + + # Check shapes match + if pytorch_logits.shape != llamacpp_logits.shape: + print(f"āŒ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}") + return False + + # Calculate key metrics + diff = pytorch_logits - llamacpp_logits + abs_diff = np.abs(diff) + max_diff = np.max(abs_diff) + + # Get top 10 predictions from both models + pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1] + llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1] + print(f"Top 10 PyTorch logits: {pytorch_logits[pytorch_top10]}") + print(f"Top 10 llama.cpp logits: {llamacpp_logits[llamacpp_top10]}") + print(f"Max absolute difference: {max_diff:.4f}") + + if max_diff > 1.0: + print(f"āŒ NOK: Large differences detected - max diff: {max_diff:.4f}") + return False + + return True + +def main(): + model_path = os.getenv('MODEL_PATH') + if not model_path: + print("Error: MODEL_PATH environment variable not set") + sys.exit(1) + + if not os.path.exists(model_path): + print(f"Error: Model file not found: {model_path}") + sys.exit(1) + + model_name = os.path.splitext(os.path.basename(model_path))[0] + data_dir = Path("data") + + pytorch_file = data_dir / f"pytorch-{model_name}.bin" + llamacpp_file = data_dir / f"llamacpp-{model_name}.bin" + + if not pytorch_file.exists(): + print(f"Error: PyTorch logits file not found: {pytorch_file}") + print("Please run scripts/run-org-model.sh first to generate this file.") + sys.exit(1) + + if not llamacpp_file.exists(): + print(f"Error: llama.cpp logits file not found: {llamacpp_file}") + print("Please run scripts/run-converted-model.sh first to generate this file.") + sys.exit(1) + + print("Checked all required files were found. 
#!/usr/bin/env python3
"""Run the original (transformers) causal model and dump its per-token embeddings.

The model path is read from the MODEL_PATH environment variable, falling back
to --model-path/-m. The last-layer hidden state of every prompt token is
written to data/pytorch-<model_name>-embeddings.bin (raw float32) and
data/pytorch-<model_name>-embeddings.txt ("<token> <dim> <value>" per line).

If UNRELEASED_MODEL_NAME is set, the concrete <Name>ForCausalLM class is
imported from transformers.models.<name>.modular_<name> instead of going
through AutoModelForCausalLM.
"""

import argparse
import os
import importlib
import sys
import torch
import numpy as np

from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
from pathlib import Path

unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')

parser = argparse.ArgumentParser(description='Process model with specified path')
parser.add_argument('--model-path', '-m', help='Path to the model')
args = parser.parse_args()

# The environment variable wins over the command line argument.
model_path = os.environ.get('MODEL_PATH', args.model_path)
if model_path is None:
    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")

config = AutoConfig.from_pretrained(model_path)

print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")

    try:
        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
        model = model_class.from_pretrained(model_path)
    except (ImportError, AttributeError) as e:
        # Without a model nothing below can run; stop here instead of failing
        # later with a NameError on `model`.
        print(f"Failed to import or load model: {e}", file=sys.stderr)
        sys.exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(model_path)
print(f"Model class: {type(model)}")

# normpath drops a trailing slash so that "/path/to/model/" yields "model",
# the same name the shell `basename` in the companion scripts produces.
model_name = os.path.basename(os.path.normpath(model_path))
print(f"Model name: {model_name}")

prompt = "Hello world today"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

with torch.no_grad():
    outputs = model(input_ids, output_hidden_states=True)

    # Extract hidden states from the last layer
    # outputs.hidden_states is a tuple of (num_layers + 1) tensors
    # Index -1 gets the last layer, shape: [batch_size, seq_len, hidden_size]
    last_hidden_states = outputs.hidden_states[-1]

    # Get embeddings for all tokens
    token_embeddings = last_hidden_states[0].cpu().numpy()  # Remove batch dimension

    print(f"Hidden states shape: {last_hidden_states.shape}")
    print(f"Token embeddings shape: {token_embeddings.shape}")
    print(f"Hidden dimension: {token_embeddings.shape[-1]}")
    print(f"Number of tokens: {token_embeddings.shape[0]}")

    # Save raw token embeddings
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
    txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"

    # Save all token embeddings as binary
    print(token_embeddings)
    token_embeddings.astype(np.float32).tofile(bin_filename)

    # Save as text for inspection
    with open(txt_filename, "w") as f:
        for i, embedding in enumerate(token_embeddings):
            for j, val in enumerate(embedding):
                f.write(f"{i} {j} {val:.6f}\n")

    # Print embeddings per token in the requested format
    print("\nToken embeddings:")
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    for i, embedding in enumerate(token_embeddings):
        # Format: show first few values, ..., then last few values
        if len(embedding) > 10:
            # Show first 3 and last 3 values with ... in between
            first_vals = " ".join(f"{val:8.6f}" for val in embedding[:3])
            last_vals = " ".join(f"{val:8.6f}" for val in embedding[-3:])
            print(f"embedding {i}: {first_vals} ... {last_vals}")
        else:
            # If embedding is short, show all values
            vals = " ".join(f"{val:8.6f}" for val in embedding)
            print(f"embedding {i}: {vals}")

    # Also show token info for reference
    print(f"\nToken reference:")
    for i, token in enumerate(tokens):
        print(f" Token {i}: {repr(token)}")

    # These files hold embeddings, not logits; say so.
    print(f"Saved bin embeddings to: {bin_filename}")
    print(f"Saved txt embeddings to: {txt_filename}")
CONVERTED_MODEL environment variable" >&2 + exit 1 +fi + +echo $CONVERTED_MODEL + +cmake --build ../../build --target llama-logits -j8 + +../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is" diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py new file mode 100755 index 000000000..f6188ea6f --- /dev/null +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + +import argparse +import os +import importlib +from pathlib import Path + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch +import numpy as np + +unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') + +parser = argparse.ArgumentParser(description='Process model with specified path') +parser.add_argument('--model-path', '-m', help='Path to the model') +args = parser.parse_args() + +model_path = os.environ.get('MODEL_PATH', args.model_path) +if model_path is None: + parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + +config = AutoConfig.from_pretrained(model_path) + +print("Model type: ", config.model_type) +print("Vocab size: ", config.vocab_size) +print("Hidden size: ", config.hidden_size) +print("Number of layers: ", config.num_hidden_layers) +print("BOS token id: ", config.bos_token_id) +print("EOS token id: ", config.eos_token_id) + +print("Loading model and tokenizer using AutoTokenizer:", model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) +config = AutoConfig.from_pretrained(model_path) + +if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + class_name = f"{unreleased_model_name}ForCausalLM" + print(f"Importing unreleased model module: {unreleased_module_path}") + + try: + model_class = 
getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + exit(1) +else: + model = AutoModelForCausalLM.from_pretrained(model_path) + +model_name = os.path.basename(model_path) +# Printing the Model class to allow for easier debugging. This can be useful +# when working with models that have not been publicly released yet and this +# migth require that the concrete class is imported and used directly instead +# of using AutoModelForCausalLM. +print(f"Model class: {model.__class__.__name__}") + +prompt = "Hello, my name is" +input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +print(f"Input tokens: {input_ids}") +print(f"Input text: {repr(prompt)}") +print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}") + +with torch.no_grad(): + outputs = model(input_ids) + logits = outputs.logits + + # Extract logits for the last token (next token prediction) + last_logits = logits[0, -1, :].cpu().numpy() + + print(f"Logits shape: {logits.shape}") + print(f"Last token logits shape: {last_logits.shape}") + print(f"Vocab size: {len(last_logits)}") + + data_dir = Path("data") + data_dir.mkdir(exist_ok=True) + bin_filename = data_dir / f"pytorch-{model_name}.bin" + txt_filename = data_dir / f"pytorch-{model_name}.txt" + + # Save to file for comparison + last_logits.astype(np.float32).tofile(bin_filename) + + # Also save as text file for easy inspection + with open(txt_filename, "w") as f: + for i, logit in enumerate(last_logits): + f.write(f"{i}: {logit:.6f}\n") + + # Print some sample logits for quick verification + print(f"First 10 logits: {last_logits[:10]}") + print(f"Last 10 logits: {last_logits[-10:]}") + + # Show top 5 predicted tokens + top_indices = np.argsort(last_logits)[-5:][::-1] + print("Top 5 predictions:") + for idx in top_indices: + token = 
#!/bin/bash
#
# Compare the embeddings produced by the original model with the ones produced
# by llama.cpp, using scripts/utils/semantic_check.py.
#
# Usage: compare-embeddings-logits.sh [MODEL_PATH [MODEL_NAME]]
#   MODEL_PATH defaults to $EMBEDDING_MODEL_PATH.
#   MODEL_NAME defaults to the basename of MODEL_PATH.
#
# When stdin is a terminal, data/llamacpp-<MODEL_NAME>-embeddings.bin is used.
# Otherwise stdin is read as JSON (a list of items with an 'embedding' field)
# and converted to a temporary binary file of floats first.

set -e

MODEL_PATH="${1:-"$EMBEDDING_MODEL_PATH"}"

# Fail early with a clear message instead of invoking semantic_check.py with
# an empty --model-path.
if [ -z "$MODEL_PATH" ]; then
    echo "Error: Model path must be provided either as:" >&2
    echo " 1. Command line argument" >&2
    echo " 2. EMBEDDING_MODEL_PATH environment variable" >&2
    exit 1
fi

MODEL_NAME="${2:-$(basename "$MODEL_PATH")}"

if [ -t 0 ]; then
    CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
else
    # Process piped JSON data and convert to binary (matching logits.cpp format)
    # The template ends in X's: BSD/macOS mktemp only substitutes trailing X's.
    TEMP_FILE=$(mktemp "${TMPDIR:-/tmp}/llamacpp-embeddings.XXXXXX")
    # Install the cleanup before the conversion so a failure does not leak the file.
    trap 'rm -f "$TEMP_FILE"' EXIT
    # The output path is passed as argv[1] rather than spliced into the source.
    python3 -c '
import json
import sys
import struct

data = json.load(sys.stdin)

# Flatten all embeddings completely
flattened = []
for item in data:
    embedding = item["embedding"]
    for token_embedding in embedding:
        flattened.extend(token_embedding)

print(f"Total embedding values: {len(flattened)}", file=sys.stderr)

# Write as binary floats - matches logits.cpp fwrite format
with open(sys.argv[1], "wb") as f:
    for value in flattened:
        f.write(struct.pack("f", value))
' "$TEMP_FILE"
    CPP_EMBEDDINGS="$TEMP_FILE"
fi

python scripts/utils/semantic_check.py --model-path "$MODEL_PATH" \
    --python-embeddings "data/pytorch-${MODEL_NAME}-embeddings.bin" \
    --cpp-embeddings "$CPP_EMBEDDINGS" \
    --prompt "Hello world today"
#!/bin/bash
#
# Build llama-logits and run it in embedding mode on the converted embedding
# model with the fixed prompt "Hello world today".
#
# Usage: run-converted-model.sh [CONVERTED_MODEL]
#   CONVERTED_MODEL defaults to $CONVERTED_EMBEDDING_MODEL.

set -e

# First try command line argument, then environment variable
CONVERTED_MODEL="${1:-"$CONVERTED_EMBEDDING_MODEL"}"

# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
    echo "Error: Model path must be provided either as:" >&2
    echo " 1. Command line argument" >&2
    echo " 2. CONVERTED_EMBEDDING_MODEL environment variable" >&2
    exit 1
fi

# Quoted so a path with spaces or glob characters is printed verbatim.
echo "$CONVERTED_MODEL"

cmake --build ../../build --target llama-logits -j8

../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "Hello world today"
AutoConfig.from_pretrained(model_path) + +model_name = os.path.basename(model_path) + +texts = [ "Hello world today" ] + +encoded = tokenizer( + texts, + padding=True, + truncation=True, + return_tensors="pt" +) + +tokens = encoded['input_ids'][0] +token_strings = tokenizer.convert_ids_to_tokens(tokens) +for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): + print(f"{token_id:6d} -> '{token_str}'") + +with torch.no_grad(): + outputs = model(**encoded) + hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size] + + # Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior) + all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size] + + print(f"Hidden states shape: {hidden_states.shape}") + print(f"All embeddings shape: {all_embeddings.shape}") + print(f"Embedding dimension: {all_embeddings.shape[1]}") + + # Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE + n_embd = all_embeddings.shape[1] + n_embd_count = all_embeddings.shape[0] + + print() # Empty line to match C++ output + + for j in range(n_embd_count): + embedding = all_embeddings[j] + print(f"embedding {j}: ", end="") + + # Print first 3 values + for i in range(min(3, n_embd)): + print(f"{embedding[i]:9.6f} ", end="") + + print(" ... 
", end="") + + # Print last 3 values + for i in range(n_embd - 3, n_embd): + print(f"{embedding[i]:9.6f} ", end="") + + print() # New line + + print() # Final empty line to match C++ output + + data_dir = Path("data") + data_dir.mkdir(exist_ok=True) + bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin" + txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt" + + # Save all embeddings flattened (matching what embedding.cpp would save if it did) + flattened_embeddings = all_embeddings.flatten() + flattened_embeddings.astype(np.float32).tofile(bin_filename) + + with open(txt_filename, "w") as f: + f.write(f"# Model class: {model_name}\n") + f.write(f"# Tokens: {token_strings}\n") + f.write(f"# Shape: {all_embeddings.shape}\n") + f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n") + + for j in range(n_embd_count): + f.write(f"# Token {j} ({token_strings[j]}):\n") + for i, value in enumerate(all_embeddings[j]): + f.write(f"{j}_{i}: {value:.6f}\n") + f.write("\n") + print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens Ɨ {n_embd} dimensions)") + print("") + print(f"Saved bin embeddings to: {bin_filename}") + print(f"Saved txt embeddings to: {txt_filename}") diff --git a/examples/model-conversion/scripts/readme.md.template b/examples/model-conversion/scripts/readme.md.template new file mode 100644 index 000000000..87800a1b9 --- /dev/null +++ b/examples/model-conversion/scripts/readme.md.template @@ -0,0 +1,13 @@ +--- +base_model: +- {base_model} +--- +# {model_name} GGUF + +Recommended way to run this model: + +```sh +llama-server -hf {namespace}/{model_name}-GGUF -c 0 -fa +``` + +Then, access http://localhost:8080 diff --git a/examples/model-conversion/scripts/utils/check-nmse.py b/examples/model-conversion/scripts/utils/check-nmse.py new file mode 100755 index 000000000..196a6210f --- /dev/null +++ b/examples/model-conversion/scripts/utils/check-nmse.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 + +import numpy as np 
def calculate_nmse(reference, test):
    """Compute the normalized mean squared error of `test` against `reference`.

    NMSE is the mean squared error divided by the variance of the reference.

    Args:
        reference: NumPy array of ground-truth values.
        test: NumPy array of values to evaluate, same shape as `reference`.

    Returns:
        Tuple (nmse, mse, ref_var). When the reference has zero variance the
        ratio is undefined, so nmse is 0.0 if the arrays match exactly
        (mse == 0) and infinity otherwise.
    """
    mse = np.mean((test - reference) ** 2)
    ref_var = np.var(reference)
    if ref_var == 0:
        # Return the computed nmse here, not mse: the caller uses the first
        # element for its pass/fail decision.
        nmse = float('inf') if mse > 0 else 0.0
        return nmse, mse, ref_var

    nmse = mse / ref_var

    return nmse, mse, ref_var
+ print(f"PyTorch logits file: {pytorch_file}") + print(f"llama.cpp logits file: {llamacpp_file}") + + reference_file = pytorch_file + test_file = llamacpp_file + + print("šŸ“Š NMSE Check for Model Comparison") + print("=" * 50) + print(f"Reference (ground truth): {reference_file}") + print(f"Test (to evaluate): {test_file}") + print() + + try: + print("Loading reference logits...") + reference = load_logits(reference_file) + print(f" Shape: {reference.shape}, Type: {reference.dtype}") + + print("Loading test logits...") + test = load_logits(test_file) + print(f" Shape: {test.shape}, Type: {test.dtype}") + + # Check shapes match + if reference.shape != test.shape: + print(f"\nāŒ Error: Shape mismatch!") + print(f" Reference: {reference.shape}") + print(f" Test: {test.shape}") + sys.exit(1) + + print(f"\nāœ… Shapes match: {reference.shape}") + + nmse, mse, ref_var = calculate_nmse(reference, test) + + # Additional metrics + max_abs_error = np.max(np.abs(test - reference)) + mean_abs_error = np.mean(np.abs(test - reference)) + + # Results + print(f"\nšŸ“ˆ METRICS") + print("=" * 30) + print(f"MSE (Mean Squared Error): {mse:.6e}") + print(f"Reference Variance: {ref_var:.6e}") + print(f"NMSE: {nmse:.6e}") + print(f"Max Absolute Error: {max_abs_error:.6f}") + print(f"Mean Absolute Error: {mean_abs_error:.6f}") + + # NMSE in dB (common in signal processing) + if nmse > 0: + nmse_db = 10 * np.log10(nmse) + print(f"NMSE (dB): {nmse_db:.2f} dB") + + # Interpretation + interpretation, emoji = interpret_nmse(nmse) + print(f"\nšŸŽÆ INTERPRETATION") + print("=" * 30) + print(f"{emoji} {interpretation}") + + # Detailed guidance + print(f"\nšŸ“‹ GUIDANCE") + print("=" * 30) + if nmse < 1e-3: + print("āœ… EXCELLENT: Your GGML conversion is working very well!") + print(" The differences are negligible for practical use.") + elif nmse < 1e-2: + print("šŸ‘ GOOD: Your GGML conversion is working well.") + print(" Small differences are likely due to precision/quantization.") + elif 
#!/bin/bash
#
# Example: create a Hugging Face collection and add a model to it, using the
# helper scripts that live next to this file.
#
# Required environment variables:
#   COLLECTION_NAME         title of the collection to create
#   COLLECTION_DESCRIPTION  description of the collection
#   HF_NAMESPACE            namespace that will own the collection
# Optional:
#   MODEL_ID                model repository to add (defaults to the
#                           placeholder "username/my-model")
#
# HF_TOKEN (or HUGGINGFACE_HUB_TOKEN) must also be set for hf-create-collection.py.

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

: "${COLLECTION_NAME:?COLLECTION_NAME must be set}"
: "${COLLECTION_DESCRIPTION:?COLLECTION_DESCRIPTION must be set}"
: "${HF_NAMESPACE:?HF_NAMESPACE must be set}"
MODEL_ID="${MODEL_ID:-username/my-model}"

# --return-slug makes the script print only the slug on success; on failure it
# exits non-zero and `set -e` stops this script.
COLLECTION_SLUG=$(python "$SCRIPT_DIR/hf-create-collection.py" \
    --name "$COLLECTION_NAME" \
    --description "$COLLECTION_DESCRIPTION" \
    --namespace "$HF_NAMESPACE" \
    --return-slug)
echo "Created collection: $COLLECTION_SLUG"

# Use it in the next command
python "$SCRIPT_DIR/hf-add-model-to-collection.py" \
    --collection "$COLLECTION_SLUG" \
    --model "$MODEL_ID"
add_model_to_collection(collection_slug, model_id, note=""): + """ + Add a model to an existing collection + + Args: + collection_slug: The slug of the collection (e.g., "username/collection-name-12345") + model_id: The model repository ID (e.g., "username/model-name") + note: Optional note about the model + + Returns: + True if successful, False if failed + """ + + # Initialize API + api = HfApi() + + try: + user_info = api.whoami() + print(f"āœ… Authenticated as: {user_info['name']}") + + # Verify the model exists + print(f"šŸ” Checking if model exists: {model_id}") + try: + model_info = api.model_info(model_id) + except Exception as e: + print(f"āŒ Model not found or not accessible: {model_id}") + print(f"Error: {e}") + return False + + print(f"šŸ“š Adding model to collection...") + api.add_collection_item( + collection_slug=collection_slug, + item_id=model_id, + item_type="model", + note=note + ) + + print(f"āœ… Model added to collection successfully!") + print(f"šŸ”— Collection URL: https://huggingface.co/collections/{collection_slug}") + + return True + + except Exception as e: + print(f"āŒ Error adding model to collection: {e}") + return False + +def main(): + # This script requires that the environment variable HF_TOKEN is set with your + # Hugging Face API token. 
+ api = HfApi() + + parser = argparse.ArgumentParser(description='Add model to a Huggingface Collection') + parser.add_argument('--collection', '-c', help='The collection slug username/collection-hash', required=True) + parser.add_argument('--model', '-m', help='The model to add to the Collection', required=True) + parser.add_argument('--note', '-n', help='An optional note/description', required=False) + args = parser.parse_args() + + collection = args.collection + model = args.model + note = args.note + + success = add_model_to_collection( + collection_slug=collection, + model_id=model, + note=note + ) + + if success: + print("\nšŸŽ‰ Model added successfully!") + else: + print("\nāŒ Failed to add model to collection") + sys.exit(1) +if __name__ == "__main__": + main() diff --git a/examples/model-conversion/scripts/utils/hf-create-collection.py b/examples/model-conversion/scripts/utils/hf-create-collection.py new file mode 100755 index 000000000..e0fa60af1 --- /dev/null +++ b/examples/model-conversion/scripts/utils/hf-create-collection.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +from huggingface_hub import HfApi +import argparse +import os +import sys + + +def create_collection(title, description, private=False, namespace=None, return_slug=False): + """ + Create a new collection on Hugging Face + + Args: + title: Collection title + description: Collection description + private: Whether the collection should be private (default: False) + namespace: Optional namespace (defaults to your username) + + Returns: + Collection object if successful, None if failed + """ + + # Check if HF_TOKEN is available + token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") + if not token: + print("āŒ No HF_TOKEN or HUGGINGFACE_HUB_TOKEN found in environment variables") + print("Please set your Hugging Face token as an environment variable") + return None + + # Initialize API + api = HfApi() + + try: + # Test authentication first + user_info = api.whoami() + if not 
return_slug: + print(f"āœ… Authenticated as: {user_info['name']}") + + # Create the collection + if not return_slug: + print(f"šŸ“š Creating collection: '{title}'...") + collection = api.create_collection( + title=title, + description=description, + private=private, + namespace=namespace + ) + + if not return_slug: + print(f"āœ… Collection created successfully!") + print(f"šŸ“‹ Collection slug: {collection.slug}") + print(f"šŸ”— Collection URL: https://huggingface.co/collections/{collection.slug}") + + return collection + + except Exception as e: + print(f"āŒ Error creating collection: {e}") + return None + +def main(): + # This script requires that the environment variable HF_TOKEN is set with your + # Hugging Face API token. + api = HfApi() + + parser = argparse.ArgumentParser(description='Create a Huggingface Collection') + parser.add_argument('--name', '-n', help='The name/title of the Collection', required=True) + parser.add_argument('--description', '-d', help='The description for the Collection', required=True) + parser.add_argument('--namespace', '-ns', help='The namespace to add the Collection to', required=True) + parser.add_argument('--private', '-p', help='Create a private Collection', action='store_true') # Fixed + parser.add_argument('--return-slug', '-s', help='Only output the collection slug', action='store_true') # Fixed + + args = parser.parse_args() + + name = args.name + description = args.description + private = args.private + namespace = args.namespace + return_slug = args.return_slug + + if not return_slug: + print("šŸš€ Creating Hugging Face Collection") + print(f"Title: {name}") + print(f"Description: {description}") + print(f"Namespace: {namespace}") + print(f"Private: {private}") + + collection = create_collection( + title=name, + description=description, + private=private, + namespace=namespace, + return_slug=return_slug + ) + + if collection: + if return_slug: + print(collection.slug) + else: + print("\nšŸŽ‰ Collection created 
def load_template_and_substitute(template_path, **kwargs):
    """Read a text template and fill its `{name}` placeholders via str.format.

    Args:
        template_path: Path of the UTF-8 template file.
        **kwargs: Values for the placeholders used in the template.

    Returns:
        The rendered text, or None (after printing a message) when the file
        does not exist or the template uses a placeholder that was not given.
    """
    rendered = None
    try:
        with open(template_path, 'r', encoding='utf-8') as template_file:
            raw_text = template_file.read()
        rendered = raw_text.format(**kwargs)
    except FileNotFoundError:
        print(f"Template file '{template_path}' not found!")
    except KeyError as missing:
        print(f"Missing template variable: {missing}")
    return rendered
namespace=args.namespace, + base_model=args.org_base_model, + ) + + if model_card_content: + api.upload_file( + path_or_fileobj=model_card_content.encode('utf-8'), + path_in_repo="README.md", + repo_id=repo_id + ) + print("Model card created successfully.") + else: + print("Failed to create model card.") + +print(f"Repository created: {repo_url}") + + diff --git a/examples/model-conversion/scripts/utils/hf-upload-gguf-model.py b/examples/model-conversion/scripts/utils/hf-upload-gguf-model.py new file mode 100755 index 000000000..15ccb1150 --- /dev/null +++ b/examples/model-conversion/scripts/utils/hf-upload-gguf-model.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +from huggingface_hub import HfApi +import argparse +import os + +def upload_gguf_file(local_file_path, repo_id, filename_in_repo=None): + """ + Upload a GGUF file to a Hugging Face model repository + + Args: + local_file_path: Path to your local GGUF file + repo_id: Your repository ID (e.g., "username/model-name") + filename_in_repo: Optional custom name for the file in the repo + """ + + if not os.path.exists(local_file_path): + print(f"āŒ File not found: {local_file_path}") + return False + + if filename_in_repo is None: + filename_in_repo = os.path.basename(local_file_path) + + if filename_in_repo is None or filename_in_repo == "": + filename_in_repo = os.path.basename(local_file_path) + + print(f"šŸ“¤ Uploading {local_file_path} to {repo_id}/{filename_in_repo}") + + api = HfApi() + + try: + api.upload_file( + path_or_fileobj=local_file_path, + path_in_repo=filename_in_repo, + repo_id=repo_id, + repo_type="model", + commit_message=f"Upload {filename_in_repo}" + ) + + print("āœ… Upload successful!") + print(f"šŸ”— File available at: https://huggingface.co/{repo_id}/blob/main/{filename_in_repo}") + return True + + except Exception as e: + print(f"āŒ Upload failed: {e}") + return False + +# This script requires that the environment variable HF_TOKEN is set with your +# Hugging Face API token. 
+api = HfApi() + +parser = argparse.ArgumentParser(description='Upload a GGUF model to a Huggingface model repository') +parser.add_argument('--gguf-model-path', '-m', help='The GGUF model file to upload', required=True) +parser.add_argument('--repo-id', '-r', help='The repository to upload to', required=True) +parser.add_argument('--name', '-o', help='The name in the model repository', required=False) +args = parser.parse_args() + +upload_gguf_file(args.gguf_model_path, args.repo_id, args.name) diff --git a/examples/model-conversion/scripts/utils/inspect-converted-model.sh b/examples/model-conversion/scripts/utils/inspect-converted-model.sh new file mode 100755 index 000000000..e5b932454 --- /dev/null +++ b/examples/model-conversion/scripts/utils/inspect-converted-model.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# First try command line argument, then environment variable, then file +CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" + +# Final check if we have a model path +if [ -z "$CONVERTED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. 
#!/usr/bin/env python3
"""List every tensor (name, shape, dtype) stored in a safetensors model directory.

The directory may hold a single `model.safetensors` file or a sharded model
described by `model.safetensors.index.json`.
"""

import argparse
import json
import os
import sys
from collections import defaultdict


def resolve_model_path(cli_value, environ=None):
    """Pick the model directory to inspect.

    Args:
        cli_value: Value of --model-path, or None when it was not given.
        environ: Mapping to read MODEL_PATH from; defaults to os.environ.

    Returns:
        The command-line value when one was given, otherwise MODEL_PATH from
        the environment, otherwise None.
    """
    if environ is None:
        environ = os.environ
    # An explicit argument wins; the environment variable is only a fallback.
    return cli_value or environ.get('MODEL_PATH')


def group_tensors_by_file(weight_map):
    """Turn a tensor_name -> file_name mapping into file_name -> [tensor_name, ...]."""
    file_tensors = defaultdict(list)
    for tensor_name, file_name in weight_map.items():
        file_tensors[file_name].append(tensor_name)
    return file_tensors


def print_tensors(file_path, tensor_names=None):
    """Print one line per tensor found in a safetensors file.

    Args:
        file_path: Path to the .safetensors file.
        tensor_names: Names to print; when None, every key in the file is printed.
    """
    # Imported here so that argument handling and the "model not found" report
    # do not depend on safetensors being importable.
    from safetensors import safe_open

    with safe_open(file_path, framework="pt") as f:
        names = f.keys() if tensor_names is None else tensor_names
        for name in sorted(names):
            tensor = f.get_tensor(name)
            print(f"- {name} : shape = {tensor.shape}, dtype = {tensor.dtype}")


def main(argv=None):
    """Run the inspection.

    Args:
        argv: Argument list to parse; None means sys.argv[1:].

    Returns:
        0 when the tensors were listed, 1 when no safetensors model was found
        in the given directory.
    """
    parser = argparse.ArgumentParser(description='Process model with specified path')
    parser.add_argument('--model-path', '-m', help='Path to the model')
    args = parser.parse_args(argv)

    model_path = resolve_model_path(args.model_path)
    if not model_path:
        parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")

    # Check if there's an index file (multi-file model)
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    single_file_path = os.path.join(model_path, "model.safetensors")

    if os.path.exists(index_path):
        # Multi-file model
        print("Multi-file model detected")

        with open(index_path, 'r', encoding='utf-8') as f:
            index_data = json.load(f)

        # Get the weight map (tensor_name -> file_name)
        weight_map = index_data.get("weight_map", {})

        print("Tensors in model:")

        # Group tensors by file so that each shard is opened only once
        for file_name, tensor_names in group_tensors_by_file(weight_map).items():
            print(f"\n--- From {file_name} ---")
            print_tensors(os.path.join(model_path, file_name), tensor_names)
        return 0

    if os.path.exists(single_file_path):
        # Single file model
        print("Single-file model detected")
        print("Tensors in model:")
        print_tensors(single_file_path)
        return 0

    # Nothing usable found: report on stderr and list what is there to help
    # the user spot a wrong path.
    print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}", file=sys.stderr)
    print("Available files:", file=sys.stderr)
    if os.path.exists(model_path):
        for item in sorted(os.listdir(model_path)):
            print(f" {item}", file=sys.stderr)
    else:
        print(f" Directory {model_path} does not exist", file=sys.stderr)
    return 1


if __name__ == "__main__":
    sys.exit(main())
#!/bin/bash
#
# Run llama-perplexity in KL-divergence mode for a quantized model, using a
# previously generated base logits file.
#
# Usage: perplexity-run.sh [QUANTIZED_MODEL [LOGITS_FILE]]
#
# Each positional argument falls back to the environment variable of the same
# name when it is not given.

set -e

QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
# The logits file is the second argument; reading $1 here would make the
# model path double as the logits path whenever an argument is passed.
LOGITS_FILE="${2:-"$LOGITS_FILE"}"

if [ -z "$QUANTIZED_MODEL" ]; then
    echo "Error: Model path must be provided either as:" >&2
    echo " 1. Command line argument" >&2
    echo " 2. QUANTIZED_MODEL environment variable" >&2
    exit 1
fi

# Quoted on purpose: with an empty LOGITS_FILE an unquoted test collapses to
# `[ ! -f ]`, which is always false, so the missing file would go unnoticed.
if [ ! -f "${LOGITS_FILE}" ]; then
    echo "Error: logits file '${LOGITS_FILE}' was not found" >&2
    echo "Did you run the perplexity-gen.sh script?" >&2
    exit 1
fi

echo "Model: $QUANTIZED_MODEL"
echo "Data file: $LOGITS_FILE"

cmake --build ../../build --target llama-perplexity -j8

# Paths are quoted so that file names containing spaces reach the tool intact.
../.././build/bin/llama-perplexity -m "$QUANTIZED_MODEL" \
    --kl-divergence-base "$LOGITS_FILE" \
    --kl-divergence
+CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" + +# Final check if we have a model path +if [ -z "$CONVERTED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. CONVERTED_MODEL environment variable" >&2 + exit 1 +fi + +echo $CONVERTED_MODEL + +cmake --build ../../build --target llama-server + +../../build/bin/llama-server -m $CONVERTED_MODEL \ + --embedding \ + --pooling none diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py new file mode 100644 index 000000000..d21104809 --- /dev/null +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 + +import numpy as np +import argparse +import os +import importlib + +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel + +unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') + +def cosine_similarity(a, b=None): + a = np.asarray(a) + if b is None: + b = a + else: + b = np.asarray(b) + + if a.ndim == 1: + a = a.reshape(1, -1) + if b.ndim == 1: + b = b.reshape(1, -1) + + a_norms = np.linalg.norm(a, axis=1, keepdims=True) + b_norms = np.linalg.norm(b, axis=1, keepdims=True) + + a_norms = np.where(a_norms == 0, 1e-8, a_norms) + b_norms = np.where(b_norms == 0, 1e-8, b_norms) + + a_normalized = a / a_norms + b_normalized = b / b_norms + + # Compute cosine similarity + return np.dot(a_normalized, b_normalized.T) + +def load_embeddings_from_file(filename, n_tokens, n_embd): + embeddings = np.fromfile(filename, dtype=np.float32) + return embeddings.reshape(n_tokens, n_embd) + +def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt): + np.set_printoptions(suppress=True, precision=6) + print("pytorch embeddings:"); + print(python_emb) + print("llama.cpp embeddings:"); + print(cpp_emb) + print(f"\n=== Prompt: '{prompt}' ===") + print(f"Tokens: {tokens}") + print(f"Embeddings shape: 
Python {python_emb.shape}, llama.cpp {cpp_emb.shape}") + + n_tokens = len(tokens) + + # 1. Direct embedding comparison + print(f"\n1. Raw Embedding Magnitude Comparison:") + # Check if the distance of each token embedding from the origin and compare + # if the vectors are on the same "sphere". This does not tell us about + # direction (meaning of the token embedding), just magnitude. + for i in range(n_tokens): + py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings + cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings + ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf') + print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") + + # 2. Cosine similarity between tokens within each model + # Here we check the direction of token embeddings to see if the have the + # same meaning (similarity). This is done by calculating cosine similarity + # of a pair of token embeddings within each model. + print(f"\n2. Within-Model Token Similarities:") + print(" Python model:") + for i in range(n_tokens): + for j in range(i+1, n_tokens): + sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0] + print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") + + print(" llama.cpp model:") + for i in range(n_tokens): + for j in range(i+1, n_tokens): + sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0] + print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") + + # 3. Cross-model similarity (same token position) + print(f"\n3. Cross-Model Same-Token Similarities:") + for i in range(n_tokens): + sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] + print(f" Token {i} ({tokens[i]}): {sim:.4f}") + + # 4. Similarity matrix comparison + print(f"\n4. 
Similarity Matrix Differences:") + py_sim_matrix = cosine_similarity(python_emb) + cpp_sim_matrix = cosine_similarity(cpp_emb) + diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix) + + print(f" Max difference: {np.max(diff_matrix):.4f}") + print(f" Mean difference: {np.mean(diff_matrix):.4f}") + print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}") + + return { + 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)], + 'similarity_matrix_diff': diff_matrix, + 'max_diff': np.max(diff_matrix), + 'mean_diff': np.mean(diff_matrix), + 'rms_diff': np.sqrt(np.mean(diff_matrix**2)) + } + +def main(): + parser = argparse.ArgumentParser(description='Test semantic similarity between Python and llama.cpp embeddings') + parser.add_argument('--model-path', '-m', required=True, help='Path to the original Python model') + parser.add_argument('--python-embeddings', '-pe', help='Path to pytorch embeddings "logits" binary file') + parser.add_argument('--cpp-embeddings', '-ce', help='Path to llama.cpp embeddings "logits" binary file') + parser.add_argument('--causal', '-c', default=False, help='if the model is causal (default: false)', action='store_true') + parser.add_argument('--prompt', '-p', default='Hello world today', help='Test prompt') + + args = parser.parse_args() + + print("Semantic Similarity Test Between Python and llama.cpp Embedding Models") + print("=" * 70) + + # Single prompt detailed comparison + print(f"\nTesting with prompt: '{args.prompt}'") + + # Load the python model to get configuration information and also to load the tokenizer. 
+ print("Loading model and tokenizer using AutoTokenizer:", args.model_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + config = AutoConfig.from_pretrained(args.model_path) + + if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + if args.causal: + class_name = f"{unreleased_model_name}ForCausalLM" + else: + class_name = f"{unreleased_model_name}Model" + print(f"Model class: {class_name}") + print(f"Importing unreleased model module: {unreleased_module_path}") + + try: + model_class = getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained(args.model_path) + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + exit(1) + else: + if args.causal: + model = AutoModelForCausalLM.from_pretrained(args.model_path) + else: + model = AutoModel.from_pretrained(args.model_path) + + encoded = tokenizer(args.prompt, return_tensors="pt") + tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0]) + n_tokens = len(tokens) + print(f"n_tokens: {n_tokens}"); + print(f"hidden_size: {model.config.hidden_size}") + + # Load binary embeddings from data directory. 
+ llamacpp_embeddings = load_embeddings_from_file(args.cpp_embeddings, n_tokens, model.config.hidden_size) + python_embeddings = load_embeddings_from_file(args.python_embeddings, n_tokens, model.config.hidden_size) + + # Run comparison + results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, args.prompt) + + # Summary + print(f"\n=== SUMMARY ===") + avg_cross_sim = np.mean(results['cross_model_similarities']) + print(f"Average cross-model similarity: {avg_cross_sim:.4f}") + print(f"Similarity matrix RMS difference: {results['rms_diff']:.4f}") + + # Quality assessment + if avg_cross_sim > 0.95: + print("āœ… EXCELLENT: Models are highly similar") + elif avg_cross_sim > 0.90: + print("āœ… VERY GOOD: Models are very similar") + elif avg_cross_sim > 0.80: + print("āš ļø GOOD: Models are reasonably similar") + elif avg_cross_sim > 0.70: + print("āš ļø FAIR: Models have some differences") + else: + print("āŒ POOR: Models are significantly different") + +if __name__ == "__main__": + main()