mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-15 01:29:42 +00:00
⚡ release v0.2.3
This commit is contained in:
parent
034a116365
commit
848fe8ab97
9 changed files with 233 additions and 8 deletions
|
@ -1,4 +1,18 @@
|
||||||
|
<!-- omit in toc -->
|
||||||
# FAQ
|
# FAQ
|
||||||
|
- [Install](#install)
|
||||||
|
- [Q: ImportError: /lib/x86\_64-linux-gnu/libstdc++.so.6: version GLIBCXX\_3.4.32' not found](#q-importerror-libx86_64-linux-gnulibstdcso6-version-glibcxx_3432-not-found)
|
||||||
|
- [Q: DeepSeek-R1 not outputting initial token](#q-deepseek-r1-not-outputting-initial--token)
|
||||||
|
- [Usage](#usage)
|
||||||
|
- [Q: If I got more VRAM than the model's requirement, how can I fully utilize it?](#q-if-i-got-more-vram-than-the-models-requirement-how-can-i-fully-utilize-it)
|
||||||
|
- [Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?](#q-if-i-dont-have-enough-vram-but-i-have-multiple-gpus-how-can-i-utilize-them)
|
||||||
|
- [Q: How to get the best performance?](#q-how-to-get-the-best-performance)
|
||||||
|
- [Q: My DeepSeek-R1 model is not thinking.](#q-my-deepseek-r1-model-is-not-thinking)
|
||||||
|
- [Q: Loading gguf error](#q-loading-gguf-error)
|
||||||
|
- [Q: Version \`GLIBCXX\_3.4.30' not found](#q-version-glibcxx_3430-not-found)
|
||||||
|
- [Q: When running the bfloat16 moe model, the data shows NaN](#q-when-running-the-bfloat16-moe-model-the-data-shows-nan)
|
||||||
|
- [Q: Using fp8 prefill very slow.](#q-using-fp8-prefill-very-slow)
|
||||||
|
- [Q: Possible ways to run graphics cards using volta and turing architectures](#q-possible-ways-to-run-graphics-cards-using-volta-and-turing-architectures)
|
||||||
## Install
|
## Install
|
||||||
### Q: ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version GLIBCXX_3.4.32' not found
|
### Q: ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version GLIBCXX_3.4.32' not found
|
||||||
```
|
```
|
||||||
|
@ -96,4 +110,58 @@ RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
|
||||||
|
|
||||||
### Q: Using fp8 prefill very slow.
|
### Q: Using fp8 prefill very slow.
|
||||||
|
|
||||||
The FP8 kernel is build by JIT, so the first run will be slow. The subsequent runs will be faster.
|
The FP8 kernel is build by JIT, so the first run will be slow. The subsequent runs will be faster.
|
||||||
|
|
||||||
|
### Q: Possible ways to run graphics cards using volta and turing architectures
|
||||||
|
|
||||||
|
From: https://github.com/kvcache-ai/ktransformers/issues/374
|
||||||
|
|
||||||
|
1. First, download the latest source code using git.
|
||||||
|
2. Then, modify the DeepSeek-V3-Chat-multi-gpu-4.yaml in the source code and all related yaml files, replacing all instances of KLinearMarlin with KLinearTorch.
|
||||||
|
3. Next, you need to compile from the ktransformer source code until it successfully compiles on your local machine.
|
||||||
|
4. Then, install flash-attn. It won't be used, but not installing it will cause an error.
|
||||||
|
5. Then, modify local_chat.py, replacing all instances of flash_attention_2 with eager.
|
||||||
|
6. Then, run local_chat.py. Be sure to follow the official tutorial's commands and adjust according to your local machine's parameters.
|
||||||
|
7. During the running process, check the memory usage. Observe its invocation through the top command. The memory capacity on a single CPU must be greater than the complete size of the model. (For multiple CPUs, it's just a copy.)
|
||||||
|
Finally, confirm that the model is fully loaded into memory and specific weight layers are fully loaded into the GPU memory. Then, try to input content in the chat interface and observe if there are any errors.
|
||||||
|
|
||||||
|
Attention, for better perfomance, you can check this [method](https://github.com/kvcache-ai/ktransformers/issues/374#issuecomment-2667520838) in the issue
|
||||||
|
>
|
||||||
|
>https://github.com/kvcache-ai/ktransformers/blob/89f8218a2ab7ff82fa54dbfe30df741c574317fc/ktransformers/operators/attention.py#L274-L279
|
||||||
|
>
|
||||||
|
>```diff
|
||||||
|
>+ original_dtype = query_states.dtype
|
||||||
|
>+ target_dtype = torch.half
|
||||||
|
>+ query_states = query_states.to(target_dtype)
|
||||||
|
>+ compressed_kv_with_k_pe = compressed_kv_with_k_pe.to(target_dtype)
|
||||||
|
>+ compressed_kv = compressed_kv.to(target_dtype)
|
||||||
|
>+ attn_output = attn_output.to(target_dtype)
|
||||||
|
>
|
||||||
|
>decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
|
||||||
|
> page_table,
|
||||||
|
> position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
|
||||||
|
> 4, #num_kv_splits # follow vLLM, fix it TODO
|
||||||
|
> self.softmax_scale,
|
||||||
|
> past_key_value.page_size)
|
||||||
|
>
|
||||||
|
>+ attn_output = attn_output.to(original_dtype)
|
||||||
|
>```
|
||||||
|
>
|
||||||
|
>https://github.com/kvcache-ai/ktransformers/blob/89f8218a2ab7ff82fa54dbfe30df741c574317fc/ktransformers/operators/attention.py#L320-L326
|
||||||
|
>
|
||||||
|
>```diff
|
||||||
|
>- attn_output = flash_attn_func(
|
||||||
|
>- query_states,
|
||||||
|
>- key_states,
|
||||||
|
>- value_states_padded,
|
||||||
|
>- softmax_scale=self.softmax_scale,
|
||||||
|
>- causal=True,
|
||||||
|
>- )
|
||||||
|
>+ attn_output = F.scaled_dot_product_attention(
|
||||||
|
>+ query_states.transpose(1, 2),
|
||||||
|
>+ key_states.transpose(1, 2),
|
||||||
|
>+ value_states_padded.transpose(1, 2),
|
||||||
|
>+ scale=self.softmax_scale,
|
||||||
|
>+ is_causal=True
|
||||||
|
>+ ).transpose(1, 2)
|
||||||
|
>```
|
|
@ -26,7 +26,7 @@ Given that we have only tested 1,000 cases, which provides only a preliminary ju
|
||||||
|
|
||||||
|
|
||||||
## The Result Table
|
## The Result Table
|
||||||
|
Uses DeepSeek-V3 model (Some specific cases are R1)
|
||||||
| | | | | | | | |
|
| | | | | | | | |
|
||||||
| ------------------------ | ----------------- | ---------- | ----------------- | ------- | ---------- | ------------------------------------------------------ | ------------ |
|
| ------------------------ | ----------------- | ---------- | ----------------- | ------- | ---------- | ------------------------------------------------------ | ------------ |
|
||||||
| DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM Kernel | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models)<br> | Ktrans Point |
|
| DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM Kernel | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models)<br> | Ktrans Point |
|
||||||
|
@ -37,9 +37,11 @@ Given that we have only tested 1,000 cases, which provides only a preliminary ju
|
||||||
| 4 | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 |
|
| 4 | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 |
|
||||||
| 5 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 |
|
| 5 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 |
|
||||||
| 6 | q4km | cpuinfer | fp8 | fp8gemm | triton | 81.6 | 81.5 |
|
| 6 | q4km | cpuinfer | fp8 | fp8gemm | triton | 81.6 | 81.5 |
|
||||||
| MMLU-pro | | | | | | | |
|
| 7 (DeepSeek-R1) | iq1 | cpuinfer | fp8 | fp8gemm | triton | 78.6 | 83.6 |
|
||||||
|
| MMLU-pro<br>(shuffle 1k) | | | | | | | |
|
||||||
| 1 | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 |
|
| 1 | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 |
|
||||||
| 2 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 |
|
| 2 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 |
|
||||||
|
| 3 (DeepSeek-R1) | iq1 | cpuinfer | fp8 | fp8gem | triton | 71.9 | tbd |
|
||||||
| HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
|
| HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
|
||||||
| GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
|
| GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
|
||||||
|
|
||||||
|
@ -54,6 +56,8 @@ By default, The MLA kernel uses triton in linux and torch in windows. But we nee
|
||||||
4. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You don't need to change the source code as they both use q4km. But note the yaml file [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L29) and [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L18), below these lines you need to add `num_bits: 8` (in other words: add this kwargs to all that use `KLinearMarlin`). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
|
4. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You don't need to change the source code as they both use q4km. But note the yaml file [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L29) and [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L18), below these lines you need to add `num_bits: 8` (in other words: add this kwargs to all that use `KLinearMarlin`). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
|
||||||
5. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
|
5. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
|
||||||
6. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
|
6. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
|
||||||
|
7. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
|
||||||
- MMLU-pro test
|
- MMLU-pro test
|
||||||
1. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
|
1. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
|
||||||
2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
|
2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)
|
||||||
|
3. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.
|
|
@ -1,5 +1,11 @@
|
||||||
|
<!-- omit in toc -->
|
||||||
# How to Run DeepSeek-R1
|
# How to Run DeepSeek-R1
|
||||||
|
- [Preparation](#preparation)
|
||||||
|
- [Installation](#installation)
|
||||||
|
- [Attention](#attention)
|
||||||
|
- [Supported models include:](#supported-models-include)
|
||||||
|
- [Support quantize format:](#support-quantize-format)
|
||||||
|
|
||||||
In this document, we will show you how to install and run KTransformers on your local machine. There are two versions:
|
In this document, we will show you how to install and run KTransformers on your local machine. There are two versions:
|
||||||
* V0.2 is the current main branch.
|
* V0.2 is the current main branch.
|
||||||
* V0.3 is a preview version only provides binary distribution for now.
|
* V0.3 is a preview version only provides binary distribution for now.
|
||||||
|
@ -56,6 +62,8 @@ Some preparation:
|
||||||
- At the same time, you should download and install the corresponding version of flash-attention from https://github.com/Dao-AILab/flash-attention/releases.
|
- At the same time, you should download and install the corresponding version of flash-attention from https://github.com/Dao-AILab/flash-attention/releases.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
### Attention
|
||||||
|
If you want to use numa support, not only do you need to set USE_NUMA=1, but you also need to make sure you have installed the libnuma-dev (`sudo apt-get install libnuma-dev` may help you).
|
||||||
|
|
||||||
<!-- 1. ~~Use a Docker image, see [documentation for Docker](./doc/en/Docker.md)~~
|
<!-- 1. ~~Use a Docker image, see [documentation for Docker](./doc/en/Docker.md)~~
|
||||||
|
|
||||||
|
|
|
@ -8,4 +8,4 @@ Version : 1.0.0
|
||||||
LastEditors : chenxl
|
LastEditors : chenxl
|
||||||
LastEditTime : 2025-02-15 03:53:02
|
LastEditTime : 2025-02-15 03:53:02
|
||||||
'''
|
'''
|
||||||
__version__ = "0.2.2rc1"
|
__version__ = "0.2.3"
|
2
ktransformers/tests/.gitignore
vendored
2
ktransformers/tests/.gitignore
vendored
|
@ -1 +1 @@
|
||||||
humaneval/results
|
results/
|
133
ktransformers/tests/AIME_2024/eval_api.py
Normal file
133
ktransformers/tests/AIME_2024/eval_api.py
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
# adapt from https://github.com/abacaj/code-eval?tab=readme-ov-file
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from evaluation import filter_answer
|
||||||
|
from prompts import instruct_prompt
|
||||||
|
import pandas as pd
|
||||||
|
from datasets import load_dataset
|
||||||
|
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||||
|
|
||||||
|
|
||||||
|
def generate_text(api_url,question , model_name, stream=False, auth_token=None):
|
||||||
|
headers = {
|
||||||
|
'accept': 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
# 添加 API Key
|
||||||
|
'Authorization' : 'Bearer ' + auth_token if auth_token else ''
|
||||||
|
}
|
||||||
|
question = instruct_prompt(question)
|
||||||
|
data = {
|
||||||
|
"messages": [{"content": question, "role": "user"}],
|
||||||
|
"model": model_name,
|
||||||
|
"stream": stream,
|
||||||
|
"temperature": 0.6,
|
||||||
|
"max_tokens": 10240,
|
||||||
|
}
|
||||||
|
print(f"content: {question}")
|
||||||
|
response = requests.post(api_url, headers=headers, json=data,verify=False)
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
results = result.get('choices', [{}])[0].get('message', {}).get('content', '')
|
||||||
|
return filter_answer(results)
|
||||||
|
else:
|
||||||
|
print(f"API Request failed with status code {response.status_code}")
|
||||||
|
return None
|
||||||
|
def load_data(file_path):
|
||||||
|
"""
|
||||||
|
Load data from a Parquet file into a list.
|
||||||
|
Each record in the Parquet file should represent an individual record.
|
||||||
|
"""
|
||||||
|
# 读取 Parquet 文件
|
||||||
|
# dataset = load_dataset('parquet', data_files=file_path)
|
||||||
|
data = []
|
||||||
|
ds = load_dataset(file_path)
|
||||||
|
df = pd.DataFrame(ds['train'])
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
data.append(row.to_dict())
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_score(pred, answer):
|
||||||
|
"""
|
||||||
|
Calculate scores between the prediction and the answer.
|
||||||
|
Uses ROUGE scores as the evaluation metric.
|
||||||
|
:param pred: The predicted string.
|
||||||
|
:param answer: The reference answer string.
|
||||||
|
:return: A dictionary containing ROUGE scores.
|
||||||
|
"""
|
||||||
|
if pred == answer:
|
||||||
|
return 1
|
||||||
|
# if we need to compare str with number, convert teh str to number
|
||||||
|
try:
|
||||||
|
pred = float(pred)
|
||||||
|
answer = float(answer)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if pred == answer:
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def run_eval_api(
|
||||||
|
api_url: str,
|
||||||
|
model_name: str,
|
||||||
|
out_path: str,
|
||||||
|
format_tabs: bool = False,
|
||||||
|
auth_token: str = None,
|
||||||
|
problem_file: str = None,
|
||||||
|
append: bool = False
|
||||||
|
):
|
||||||
|
|
||||||
|
data = load_data(problem_file)
|
||||||
|
pbar = tqdm.tqdm(total=len(data) * 1)
|
||||||
|
|
||||||
|
for i in range(len(data)):
|
||||||
|
data_item = data[i]
|
||||||
|
question = data_item['Problem']
|
||||||
|
# Start the timer for this evaluation
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
completion = generate_text(api_url, question, model_name, auth_token=auth_token)
|
||||||
|
if completion is None:
|
||||||
|
raise Exception(f"Failed to get prediction for {question}")
|
||||||
|
answer = data_item['Answer']
|
||||||
|
score = get_score(completion, answer)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
result = {
|
||||||
|
"question_id": data_item["ID"],
|
||||||
|
"answer": answer,
|
||||||
|
"prediction": completion,
|
||||||
|
"score": score,
|
||||||
|
"time": elapsed_time
|
||||||
|
}
|
||||||
|
with open(out_path, "a" if append else "w") as f:
|
||||||
|
f.write(json.dumps(result) + "\n")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to get prediction for {question}")
|
||||||
|
print(e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
|
||||||
|
def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append):
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="API Generate Tester")
|
||||||
|
parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
|
||||||
|
parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-R1", help="Model Name")
|
||||||
|
parser.add_argument("--out_path", type=str, default="results/api/eval_aime.jsonl", help="Output Path")
|
||||||
|
parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
|
||||||
|
parser.add_argument("--format_tabs", action="store_true", help="Format Tabs")
|
||||||
|
parser.add_argument("--problem_file", type=str, default="Maxwell-Jia/AIME_2024", help="Evalset File")
|
||||||
|
parser.add_argument("--no_append", action="store_false", help="Append to existing file")
|
||||||
|
args = parser.parse_args()
|
||||||
|
# api_url = "https://api.siliconflow.cn/v1/chat/completions"
|
||||||
|
main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append)
|
10
ktransformers/tests/AIME_2024/evaluation.py
Normal file
10
ktransformers/tests/AIME_2024/evaluation.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
# reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35
|
||||||
|
def filter_answer(completion: str) -> str:
|
||||||
|
# the answer is the last part of the completion, it's a int64 number
|
||||||
|
# get the last line
|
||||||
|
completion = completion.strip().split("\n")[-1]
|
||||||
|
# handle the $\\boxed{...}$ format
|
||||||
|
if "$\\boxed{" in completion:
|
||||||
|
return completion.split("}")[0].split("{")[-1]
|
||||||
|
return completion.split()[-1]
|
||||||
|
|
2
ktransformers/tests/AIME_2024/prompts.py
Normal file
2
ktransformers/tests/AIME_2024/prompts.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
def instruct_prompt(prompt: str) -> str:
|
||||||
|
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nSolve the following math problem without any tests or explanation only one answer surrounede by '$\\boxed{{}}$'\n{prompt}\n\n### Response:"""
|
|
@ -13,7 +13,7 @@ def generate_text(api_url,question , model_name, stream=False, auth_token=None):
|
||||||
'accept': 'application/json',
|
'accept': 'application/json',
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
# 添加 API Key
|
# 添加 API Key
|
||||||
'Authorization' : 'Bearer ' + auth_token
|
'Authorization' : 'Bearer ' + auth_token if auth_token else ''
|
||||||
}
|
}
|
||||||
question = instruct_prompt(question)
|
question = instruct_prompt(question)
|
||||||
data = {
|
data = {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue