mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-26 10:31:03 +00:00
Revert "[FIX] Vllm guided decoding params (#3662)"
This reverts commit fb4f0fdf56.
This commit is contained in:
parent
fb4f0fdf56
commit
ba2897a318
51 changed files with 2649 additions and 2698 deletions
|
|
@ -22,10 +22,10 @@ def safe_remove_directory(path):
|
|||
|
||||
print("🔥 Loading the 16-bit merged model from disk...")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./gpt-oss-finetuned-merged",
|
||||
max_seq_length=1024,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
model_name = "./gpt-oss-finetuned-merged",
|
||||
max_seq_length = 1024,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
print("✅ Merged model loaded successfully.")
|
||||
|
||||
|
|
@ -36,14 +36,14 @@ messages = [
|
|||
]
|
||||
inputs = merged_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
reasoning_effort="low", # **NEW!** Set reasoning effort to low, medium or high
|
||||
add_generation_prompt = True,
|
||||
return_tensors = "pt",
|
||||
return_dict = True,
|
||||
reasoning_effort = "low", # **NEW!** Set reasoning effort to low, medium or high
|
||||
).to(merged_model.device)
|
||||
|
||||
_ = merged_model.generate(
|
||||
**inputs, max_new_tokens=512, streamer=TextStreamer(merged_tokenizer)
|
||||
**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer)
|
||||
)
|
||||
print("\n✅ Inference complete.")
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def formatting_prompts_func(examples):
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
|
|
@ -40,17 +40,17 @@ def formatting_prompts_func(examples):
|
|||
print("Loading 4-bit Mxfp4 gpt-oss model for training...")
|
||||
max_seq_length = 1024
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
"unsloth/gpt-oss-20b", max_seq_length=max_seq_length, load_in_4bit=True
|
||||
"unsloth/gpt-oss-20b", max_seq_length = max_seq_length, load_in_4bit = True
|
||||
)
|
||||
|
||||
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:50]").map(
|
||||
formatting_prompts_func, batched=True
|
||||
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split = "train[:50]").map(
|
||||
formatting_prompts_func, batched = True
|
||||
)
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=8,
|
||||
target_modules=[
|
||||
r = 8,
|
||||
target_modules = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
|
|
@ -59,22 +59,22 @@ model = FastLanguageModel.get_peft_model(
|
|||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
lora_alpha = 16,
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset,
|
||||
args=SFTConfig(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
max_steps=10,
|
||||
learning_rate=2e-4,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset,
|
||||
args = SFTConfig(
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4,
|
||||
max_steps = 10,
|
||||
learning_rate = 2e-4,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ print("Fine-tuning complete.")
|
|||
# --- Merge and Save ---
|
||||
print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
|
||||
model.save_pretrained_merged(
|
||||
save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer
|
||||
save_directory = "./gpt-oss-finetuned-merged", tokenizer = tokenizer
|
||||
)
|
||||
print("✅ Model merged and saved.")
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def formatting_prompts_func(examples):
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
|
|
@ -36,25 +36,25 @@ else:
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Llama-3.1-8B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/Llama-3.1-8B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
# Load small dataset for quick training
|
||||
dataset_train = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="train[:100]"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "train[:100]"
|
||||
)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
|
||||
print("✅ Base model loaded successfully!")
|
||||
|
||||
|
|
@ -64,8 +64,8 @@ print(f"{'='*80}")
|
|||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -74,40 +74,40 @@ model = FastLanguageModel.get_peft_model(
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=10, # Very short training for test
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=5,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 10, # Very short training for test
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 5,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -119,9 +119,9 @@ print("🔍 PHASE 3: Save with Forced 4bit Merge")
|
|||
print(f"{'='*80}")
|
||||
|
||||
model.save_pretrained_merged(
|
||||
save_directory="./test_4bit_model",
|
||||
tokenizer=tokenizer,
|
||||
save_method="forced_merged_4bit",
|
||||
save_directory = "./test_4bit_model",
|
||||
tokenizer = tokenizer,
|
||||
save_method = "forced_merged_4bit",
|
||||
)
|
||||
|
||||
print("✅ Model saved with forced 4bit merge!")
|
||||
|
|
@ -137,15 +137,15 @@ torch.cuda.empty_cache()
|
|||
|
||||
# Load the 4bit merged model
|
||||
model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
|
||||
model_name="./test_4bit_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
model_name = "./test_4bit_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
tokenizer_4bit = get_chat_template(
|
||||
tokenizer_4bit,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
print("✅ 4bit model loaded successfully!")
|
||||
|
|
@ -153,8 +153,8 @@ print("✅ 4bit model loaded successfully!")
|
|||
# Add LoRA adapters to the 4bit model
|
||||
model_4bit = FastLanguageModel.get_peft_model(
|
||||
model_4bit,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -163,39 +163,39 @@ model_4bit = FastLanguageModel.get_peft_model(
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
# Second fine-tuning
|
||||
trainer_4bit = SFTTrainer(
|
||||
model=model_4bit,
|
||||
tokenizer=tokenizer_4bit,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_4bit),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=10, # Very short training for test
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=5,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs_4bit",
|
||||
report_to="none",
|
||||
model = model_4bit,
|
||||
tokenizer = tokenizer_4bit,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer_4bit),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 10, # Very short training for test
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 5,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs_4bit",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -208,8 +208,8 @@ print(f"{'='*80}")
|
|||
|
||||
try:
|
||||
model_4bit.save_pretrained_merged(
|
||||
save_directory="./test_should_fail",
|
||||
tokenizer=tokenizer_4bit,
|
||||
save_directory = "./test_should_fail",
|
||||
tokenizer = tokenizer_4bit,
|
||||
# No save_method specified, should default to regular merge
|
||||
)
|
||||
assert False, "Expected TypeError but merge succeeded!"
|
||||
|
|
@ -225,9 +225,9 @@ print(f"{'='*80}")
|
|||
|
||||
try:
|
||||
model_4bit.save_pretrained_merged(
|
||||
save_directory="./test_4bit_second",
|
||||
tokenizer=tokenizer_4bit,
|
||||
save_method="forced_merged_4bit",
|
||||
save_directory = "./test_4bit_second",
|
||||
tokenizer = tokenizer_4bit,
|
||||
save_method = "forced_merged_4bit",
|
||||
)
|
||||
print("✅ Successfully saved 4bit model with forced 4bit method!")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -36,14 +36,14 @@ def formatting_prompts_func(examples):
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
|
||||
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
|
||||
"""Load model and compute perplexity in subprocess"""
|
||||
from unsloth import FastLanguageModel
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
|
|
@ -51,20 +51,20 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Load model
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
model_name = "./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
)
|
||||
# Set up tokenizer
|
||||
merged_tokenizer = get_chat_template(
|
||||
merged_tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
# Load dataset fresh in subprocess
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
# Format the dataset
|
||||
|
|
@ -72,13 +72,13 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
merged_tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
# Compute perplexity using the passed dataset
|
||||
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
|
||||
|
|
@ -104,7 +104,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Main execution code should be wrapped in this guard
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn", force=True)
|
||||
mp.set_start_method("spawn", force = True)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
compute_dtype = torch.bfloat16
|
||||
|
|
@ -114,38 +114,38 @@ if __name__ == "__main__":
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Llama-3.2-3B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/Llama-3.2-3B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
from unsloth.chat_templates import standardize_sharegpt
|
||||
|
||||
dataset_train = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="train"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "train"
|
||||
)
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -154,40 +154,40 @@ if __name__ == "__main__":
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=10,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=50,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 10,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 50,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -195,8 +195,8 @@ if __name__ == "__main__":
|
|||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
)
|
||||
|
||||
# run training
|
||||
|
|
@ -207,7 +207,7 @@ if __name__ == "__main__":
|
|||
# saving and merging the model to local disk
|
||||
print("merge and save to local disk")
|
||||
model.save_pretrained_merged(
|
||||
save_directory="./unsloth_out/merged_llama_text_model", tokenizer=tokenizer
|
||||
save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
|
||||
)
|
||||
|
||||
# print("cleaning")
|
||||
|
|
@ -219,10 +219,10 @@ if __name__ == "__main__":
|
|||
# load model from local disk and test
|
||||
print("Loading merged model in 4 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
@ -231,7 +231,7 @@ if __name__ == "__main__":
|
|||
|
||||
print("Computing 8-bit model perplexity in subprocess...")
|
||||
result_queue = mp.Queue()
|
||||
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
|
||||
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -240,10 +240,10 @@ if __name__ == "__main__":
|
|||
|
||||
print("Loading merged model in 16 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=False,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = False,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
|
|||
|
|
@ -30,17 +30,17 @@ from tests.utils.perplexity_eval import (
|
|||
)
|
||||
|
||||
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
|
||||
"""Load model and compute perplexity in subprocess"""
|
||||
from unsloth import FastLanguageModel
|
||||
from tests.utils.perplexity_eval import ppl_model
|
||||
|
||||
# Load model
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_mistral_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
model_name = "./unsloth_out/merged_mistral_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
)
|
||||
# Set up tokenizer
|
||||
# merged_tokenizer = get_chat_template(
|
||||
|
|
@ -50,7 +50,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Load dataset fresh in subprocess
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
|
|
@ -103,7 +103,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
"text": texts,
|
||||
}
|
||||
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
# Compute perplexity using the passed dataset
|
||||
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
|
||||
|
|
@ -129,7 +129,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Main execution code should be wrapped in this guard
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn", force=True)
|
||||
mp.set_start_method("spawn", force = True)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
compute_dtype = torch.bfloat16
|
||||
|
|
@ -139,13 +139,13 @@ if __name__ == "__main__":
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/mistral-7b-v0.3",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/mistral-7b-v0.3",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
EOS_TOKEN = tokenizer.eos_token
|
||||
|
|
@ -200,21 +200,21 @@ if __name__ == "__main__":
|
|||
}
|
||||
|
||||
dataset_train = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="train"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "train"
|
||||
)
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -223,39 +223,39 @@ if __name__ == "__main__":
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=200,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=50,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 200,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 50,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -267,7 +267,7 @@ if __name__ == "__main__":
|
|||
# saving and merging the model to local disk
|
||||
print("merge and save to local disk")
|
||||
model.save_pretrained_merged(
|
||||
save_directory="./unsloth_out/merged_mistral_text_model", tokenizer=tokenizer
|
||||
save_directory = "./unsloth_out/merged_mistral_text_model", tokenizer = tokenizer
|
||||
)
|
||||
|
||||
# print("cleaning")
|
||||
|
|
@ -279,10 +279,10 @@ if __name__ == "__main__":
|
|||
# load model from local disk and test
|
||||
print("Loading merged model in 4 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_mistral_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_mistral_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
@ -291,7 +291,7 @@ if __name__ == "__main__":
|
|||
|
||||
print("Computing 8-bit model perplexity in subprocess...")
|
||||
result_queue = mp.Queue()
|
||||
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
|
||||
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -300,10 +300,10 @@ if __name__ == "__main__":
|
|||
|
||||
print("Loading merged model in 16 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_mistral_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=False,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_mistral_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = False,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def formatting_prompts_func(examples):
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
|
|
@ -45,7 +45,7 @@ def formatting_prompts_func(examples):
|
|||
}
|
||||
|
||||
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
|
||||
"""Load model and compute perplexity in subprocess"""
|
||||
from unsloth import FastLanguageModel
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
|
|
@ -53,20 +53,20 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Load model
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_phi4_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
model_name = "./unsloth_out/merged_phi4_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
)
|
||||
# Set up tokenizer
|
||||
merged_tokenizer = get_chat_template(
|
||||
merged_tokenizer,
|
||||
chat_template="phi-4",
|
||||
chat_template = "phi-4",
|
||||
)
|
||||
|
||||
# Load dataset fresh in subprocess
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
# Format the dataset
|
||||
|
|
@ -74,13 +74,13 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
merged_tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
# Compute perplexity using the passed dataset
|
||||
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
|
||||
|
|
@ -106,7 +106,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Main execution code should be wrapped in this guard
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn", force=True)
|
||||
mp.set_start_method("spawn", force = True)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
compute_dtype = torch.bfloat16
|
||||
|
|
@ -116,36 +116,36 @@ if __name__ == "__main__":
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Phi-4",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/Phi-4",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="phi-4",
|
||||
chat_template = "phi-4",
|
||||
)
|
||||
|
||||
dataset_train = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="train"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "train"
|
||||
)
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -154,40 +154,40 @@ if __name__ == "__main__":
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=200,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=50,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 200,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 50,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -195,8 +195,8 @@ if __name__ == "__main__":
|
|||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part="<|im_start|>user<|im_sep|>\n\n",
|
||||
response_part="<|im_start|>assistant<|im_sep|>\n\n",
|
||||
instruction_part = "<|im_start|>user<|im_sep|>\n\n",
|
||||
response_part = "<|im_start|>assistant<|im_sep|>\n\n",
|
||||
)
|
||||
|
||||
# run training
|
||||
|
|
@ -207,7 +207,7 @@ if __name__ == "__main__":
|
|||
# saving and merging the model to local disk
|
||||
print("merge and save to local disk")
|
||||
model.save_pretrained_merged(
|
||||
save_directory="./unsloth_out/merged_phi4_text_model", tokenizer=tokenizer
|
||||
save_directory = "./unsloth_out/merged_phi4_text_model", tokenizer = tokenizer
|
||||
)
|
||||
|
||||
# print("cleaning")
|
||||
|
|
@ -219,10 +219,10 @@ if __name__ == "__main__":
|
|||
# load model from local disk and test
|
||||
print("Loading merged model in 4 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_phi4_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_phi4_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
@ -231,7 +231,7 @@ if __name__ == "__main__":
|
|||
|
||||
print("Computing 8-bit model perplexity in subprocess...")
|
||||
result_queue = mp.Queue()
|
||||
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
|
||||
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -240,10 +240,10 @@ if __name__ == "__main__":
|
|||
|
||||
print("Loading merged model in 16 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_phi4_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=False,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_phi4_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = False,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
|
|||
|
|
@ -35,14 +35,14 @@ def formatting_prompts_func(examples):
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
|
||||
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
|
||||
"""Load model and compute perplexity in subprocess"""
|
||||
from unsloth import FastLanguageModel
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
|
|
@ -50,20 +50,20 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Load model
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
model_name = "./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
)
|
||||
# Set up tokenizer
|
||||
merged_tokenizer = get_chat_template(
|
||||
merged_tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
# Load dataset fresh in subprocess
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
# Format the dataset
|
||||
|
|
@ -71,13 +71,13 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
merged_tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
# Compute perplexity using the passed dataset
|
||||
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
|
||||
|
|
@ -103,7 +103,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Main execution code should be wrapped in this guard
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn", force=True)
|
||||
mp.set_start_method("spawn", force = True)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
compute_dtype = torch.bfloat16
|
||||
|
|
@ -113,31 +113,31 @@ if __name__ == "__main__":
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Llama-3.1-8B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/Llama-3.1-8B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
from unsloth.chat_templates import standardize_sharegpt
|
||||
|
||||
dataset_train = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="train"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "train"
|
||||
)
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
print("\n dataset sample [0]")
|
||||
print(dataset_train[0])
|
||||
|
|
@ -146,8 +146,8 @@ if __name__ == "__main__":
|
|||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -156,40 +156,40 @@ if __name__ == "__main__":
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=200,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=50,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 200,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 50,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -197,8 +197,8 @@ if __name__ == "__main__":
|
|||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
)
|
||||
|
||||
tokenizer.decode(trainer.train_dataset[0]["input_ids"])
|
||||
|
|
@ -211,7 +211,7 @@ if __name__ == "__main__":
|
|||
# saving and merging the model to local disk
|
||||
print("merge and save to local disk")
|
||||
model.save_pretrained_merged(
|
||||
save_directory="./unsloth_out/merged_llama_text_model", tokenizer=tokenizer
|
||||
save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
|
||||
)
|
||||
|
||||
# print("cleaning")
|
||||
|
|
@ -223,10 +223,10 @@ if __name__ == "__main__":
|
|||
# load model from local disk and test
|
||||
print("Loading merged model in 4 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
@ -235,7 +235,7 @@ if __name__ == "__main__":
|
|||
|
||||
print("Computing 8-bit model perplexity in subprocess...")
|
||||
result_queue = mp.Queue()
|
||||
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
|
||||
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -244,10 +244,10 @@ if __name__ == "__main__":
|
|||
|
||||
print("Loading merged model in 16 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=False,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_llama_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = False,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ print("🔍 PHASE 1: Loading Base Model")
|
|||
print(f"{'='*80}")
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/mistral-7b-v0.3",
|
||||
max_seq_length=2048,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
model_name = "unsloth/mistral-7b-v0.3",
|
||||
max_seq_length = 2048,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ print(f"\n{'='*80}")
|
|||
print("🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)")
|
||||
print(f"{'='*80}")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with warnings.catch_warnings(record = True) as w:
|
||||
warnings.simplefilter("always")
|
||||
model.save_pretrained_merged("test_output", tokenizer)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ print("🔍 PHASE 1: Loading Base Model")
|
|||
print(f"{'='*80}")
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name="unsloth/whisper-large-v3",
|
||||
dtype=None, # Leave as None for auto detection
|
||||
load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
|
||||
auto_model=WhisperForConditionalGeneration,
|
||||
whisper_language="English",
|
||||
whisper_task="transcribe",
|
||||
model_name = "unsloth/whisper-large-v3",
|
||||
dtype = None, # Leave as None for auto detection
|
||||
load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
|
||||
auto_model = WhisperForConditionalGeneration,
|
||||
whisper_language = "English",
|
||||
whisper_task = "transcribe",
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ print(f"\n{'='*80}")
|
|||
print("🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)")
|
||||
print(f"{'='*80}")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with warnings.catch_warnings(record = True) as w:
|
||||
warnings.simplefilter("always")
|
||||
model.save_pretrained_merged("test_output", tokenizer)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,10 +30,10 @@ print(f"{'='*80}")
|
|||
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/orpheus-3b-0.1-ft",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
dtype=None, # Select None for auto detection
|
||||
load_in_4bit=False, # Select True for 4bit which reduces memory usage
|
||||
model_name = "unsloth/orpheus-3b-0.1-ft",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
dtype = None, # Select None for auto detection
|
||||
load_in_4bit = False, # Select True for 4bit which reduces memory usage
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
|
|
@ -42,8 +42,8 @@ base_model_class = model.__class__.__name__
|
|||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules=[
|
||||
r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
|
|
@ -52,14 +52,14 @@ model = FastLanguageModel.get_peft_model(
|
|||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
lora_alpha=64,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
lora_alpha = 64,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
print("✅ Model and LoRA adapters loaded successfully!")
|
||||
|
||||
|
|
@ -112,10 +112,10 @@ print(f"{'='*80}")
|
|||
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/orpheus-3b-0.1-ft",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
dtype=None, # Select None for auto detection
|
||||
load_in_4bit=False, # Select True for 4bit which reduces memory usage
|
||||
model_name = "unsloth/orpheus-3b-0.1-ft",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
dtype = None, # Select None for auto detection
|
||||
load_in_4bit = False, # Select True for 4bit which reduces memory usage
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
|
|
@ -148,18 +148,18 @@ prompts_ = [(f"{chosen_voice}: " + p) if chosen_voice else p for p in prompts]
|
|||
all_input_ids = []
|
||||
|
||||
for prompt in prompts_:
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
input_ids = tokenizer(prompt, return_tensors = "pt").input_ids
|
||||
all_input_ids.append(input_ids)
|
||||
|
||||
start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
|
||||
start_token = torch.tensor([[128259]], dtype = torch.int64) # Start of human
|
||||
end_tokens = torch.tensor(
|
||||
[[128009, 128260]], dtype=torch.int64
|
||||
[[128009, 128260]], dtype = torch.int64
|
||||
) # End of text, End of human
|
||||
|
||||
all_modified_input_ids = []
|
||||
for input_ids in all_input_ids:
|
||||
modified_input_ids = torch.cat(
|
||||
[start_token, input_ids, end_tokens], dim=1
|
||||
[start_token, input_ids, end_tokens], dim = 1
|
||||
) # SOH SOT Text EOT EOH
|
||||
all_modified_input_ids.append(modified_input_ids)
|
||||
|
||||
|
|
@ -171,39 +171,39 @@ max_length = max(
|
|||
for modified_input_ids in all_modified_input_ids:
|
||||
padding = max_length - modified_input_ids.shape[1]
|
||||
padded_tensor = torch.cat(
|
||||
[torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1
|
||||
[torch.full((1, padding), 128263, dtype = torch.int64), modified_input_ids], dim = 1
|
||||
)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
torch.zeros((1, padding), dtype=torch.int64),
|
||||
torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64),
|
||||
torch.zeros((1, padding), dtype = torch.int64),
|
||||
torch.ones((1, modified_input_ids.shape[1]), dtype = torch.int64),
|
||||
],
|
||||
dim=1,
|
||||
dim = 1,
|
||||
)
|
||||
all_padded_tensors.append(padded_tensor)
|
||||
all_attention_masks.append(attention_mask)
|
||||
|
||||
all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
|
||||
all_attention_masks = torch.cat(all_attention_masks, dim=0)
|
||||
all_padded_tensors = torch.cat(all_padded_tensors, dim = 0)
|
||||
all_attention_masks = torch.cat(all_attention_masks, dim = 0)
|
||||
|
||||
input_ids = all_padded_tensors.to("cuda")
|
||||
attention_mask = all_attention_masks.to("cuda")
|
||||
generated_ids = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=1200,
|
||||
do_sample=True,
|
||||
temperature=0.6,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1,
|
||||
eos_token_id=128258,
|
||||
use_cache=True,
|
||||
input_ids = input_ids,
|
||||
attention_mask = attention_mask,
|
||||
max_new_tokens = 1200,
|
||||
do_sample = True,
|
||||
temperature = 0.6,
|
||||
top_p = 0.95,
|
||||
repetition_penalty = 1.1,
|
||||
num_return_sequences = 1,
|
||||
eos_token_id = 128258,
|
||||
use_cache = True,
|
||||
)
|
||||
token_to_find = 128257
|
||||
token_to_remove = 128258
|
||||
|
||||
token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
|
||||
token_indices = (generated_ids == token_to_find).nonzero(as_tuple = True)
|
||||
|
||||
if len(token_indices[1]) > 0:
|
||||
last_occurrence_idx = token_indices[1][-1].item()
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from tests.utils.ocr_eval import OCRModelEvaluator
|
|||
## Dataset Preparation
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
|
||||
# To select the first 2000 examples
|
||||
train_dataset = dataset.select(range(2000))
|
||||
|
||||
|
|
@ -81,39 +81,39 @@ model_comparison_results = {}
|
|||
# Load Base Model
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name="unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
load_in_4bit=True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning=False, # [NEW!] We have full finetuning now!
|
||||
model_name = "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
)
|
||||
|
||||
# benchmark base model performance
|
||||
model_name = "Unsloth Base model"
|
||||
FastVisionModel.for_inference(model)
|
||||
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
||||
model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results"
|
||||
model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
## Lora Finetuning
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers=True, # Turn off for just text!
|
||||
finetune_language_layers=True, # Should leave on!
|
||||
finetune_attention_modules=True, # Attention good for GRPO
|
||||
finetune_mlp_modules=True, # SHould leave on always!
|
||||
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
finetune_vision_layers = True, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # SHould leave on always!
|
||||
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
# target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
# "gate_proj", "up_proj", "down_proj",],
|
||||
lora_alpha=32,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
lora_alpha = 32,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
|
||||
from unsloth import is_bf16_supported
|
||||
|
|
@ -124,40 +124,40 @@ model.config.use_cache = False
|
|||
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset=train_dataset,
|
||||
args=SFTConfig(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset = train_dataset,
|
||||
args = SFTConfig(
|
||||
# per_device_train_batch_size = 4,
|
||||
# gradient_accumulation_steps = 8,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
gradient_checkpointing = True,
|
||||
gradient_checkpointing_kwargs = {
|
||||
"use_reentrant": False
|
||||
}, # use reentrant checkpointing
|
||||
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio=0.03,
|
||||
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio = 0.03,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
max_steps=60,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bf16_supported(),
|
||||
bf16=is_bf16_supported(),
|
||||
logging_steps=5,
|
||||
save_strategy="epoch",
|
||||
optim="adamw_torch_fused",
|
||||
weight_decay=0.01,
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="unsloth-qwen2.5-vl-32b-french-ocr-checkpoints",
|
||||
report_to="none", # For Weights and Biases
|
||||
max_steps = 60,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bf16_supported(),
|
||||
bf16 = is_bf16_supported(),
|
||||
logging_steps = 5,
|
||||
save_strategy = "epoch",
|
||||
optim = "adamw_torch_fused",
|
||||
weight_decay = 0.01,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "unsloth-qwen2.5-vl-32b-french-ocr-checkpoints",
|
||||
report_to = "none", # For Weights and Biases
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns=False,
|
||||
dataset_text_field="",
|
||||
dataset_kwargs={"skip_prepare_dataset": True},
|
||||
dataset_num_proc=4,
|
||||
max_seq_length=2048,
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
dataset_num_proc = 4,
|
||||
max_seq_length = 2048,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ tokenizer.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter")
|
|||
model_name = "Unsloth lora adapter model"
|
||||
FastVisionModel.for_inference(model)
|
||||
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
||||
model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results"
|
||||
model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
|
|
@ -195,7 +195,7 @@ print((base.__class__.__name__))
|
|||
|
||||
# merge default 16 bits
|
||||
model.save_pretrained_merged(
|
||||
save_directory="qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer
|
||||
save_directory = "qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ model.save_pretrained_merged(
|
|||
### 16 bits merged model
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=False
|
||||
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
|
||||
)
|
||||
|
||||
# benchmark 4bit loaded, 16bits merged model performance
|
||||
|
|
@ -215,13 +215,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
|||
model,
|
||||
tokenizer,
|
||||
eval_dataset,
|
||||
output_dir="unsloth_16bits_merged_model_load_16bits_results",
|
||||
output_dir = "unsloth_16bits_merged_model_load_16bits_results",
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
# load 16bits-merged model in 4 bits
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=True, load_in_8bit=False
|
||||
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
|
||||
)
|
||||
|
||||
# benchmark 4bit loaded, 16bits merged model performance
|
||||
|
|
@ -232,13 +232,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
|||
model,
|
||||
tokenizer,
|
||||
eval_dataset,
|
||||
output_dir="unsloth_16bits_merged_model_load_4bits_results",
|
||||
output_dir = "unsloth_16bits_merged_model_load_4bits_results",
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
# load model in 8 bits
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=True
|
||||
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
|
||||
)
|
||||
|
||||
# benchmark 4bit loaded, 16bits merged model performance
|
||||
|
|
@ -247,7 +247,7 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
|||
model,
|
||||
tokenizer,
|
||||
eval_dataset,
|
||||
output_dir="unsloth_16bits_merged_model_load_8bits_results",
|
||||
output_dir = "unsloth_16bits_merged_model_load_8bits_results",
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from tests.utils.ocr_eval import OCRModelEvaluator
|
|||
## Dataset Preparation
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
|
||||
# To select the first 2000 examples
|
||||
train_dataset = dataset.select(range(2000))
|
||||
|
||||
|
|
@ -81,39 +81,39 @@ model_comparison_results = {}
|
|||
# Load Base Model
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name="unsloth/Qwen2-VL-7B-Instruct",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
load_in_4bit=True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning=False, # [NEW!] We have full finetuning now!
|
||||
model_name = "unsloth/Qwen2-VL-7B-Instruct",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
)
|
||||
|
||||
# benchmark base model performance
|
||||
model_name = "Unsloth Base model"
|
||||
FastVisionModel.for_inference(model)
|
||||
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
||||
model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results"
|
||||
model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
## Lora Finetuning
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers=True, # Turn off for just text!
|
||||
finetune_language_layers=True, # Should leave on!
|
||||
finetune_attention_modules=True, # Attention good for GRPO
|
||||
finetune_mlp_modules=True, # SHould leave on always!
|
||||
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
finetune_vision_layers = True, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # SHould leave on always!
|
||||
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
# target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
# "gate_proj", "up_proj", "down_proj",],
|
||||
lora_alpha=32,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
lora_alpha = 32,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
|
||||
from unsloth import is_bf16_supported
|
||||
|
|
@ -124,40 +124,40 @@ model.config.use_cache = False
|
|||
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset=train_dataset,
|
||||
args=SFTConfig(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset = train_dataset,
|
||||
args = SFTConfig(
|
||||
# per_device_train_batch_size = 4,
|
||||
# gradient_accumulation_steps = 8,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
gradient_checkpointing = True,
|
||||
gradient_checkpointing_kwargs = {
|
||||
"use_reentrant": False
|
||||
}, # use reentrant checkpointing
|
||||
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio=0.03,
|
||||
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio = 0.03,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
max_steps=60,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bf16_supported(),
|
||||
bf16=is_bf16_supported(),
|
||||
logging_steps=5,
|
||||
save_strategy="epoch",
|
||||
optim="adamw_torch_fused",
|
||||
weight_decay=0.01,
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="unsloth-qwen2-7vl-french-ocr-checkpoints",
|
||||
report_to="none", # For Weights and Biases
|
||||
max_steps = 60,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bf16_supported(),
|
||||
bf16 = is_bf16_supported(),
|
||||
logging_steps = 5,
|
||||
save_strategy = "epoch",
|
||||
optim = "adamw_torch_fused",
|
||||
weight_decay = 0.01,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "unsloth-qwen2-7vl-french-ocr-checkpoints",
|
||||
report_to = "none", # For Weights and Biases
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns=False,
|
||||
dataset_text_field="",
|
||||
dataset_kwargs={"skip_prepare_dataset": True},
|
||||
dataset_num_proc=4,
|
||||
max_seq_length=2048,
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
dataset_num_proc = 4,
|
||||
max_seq_length = 2048,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ tokenizer.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter")
|
|||
model_name = "Unsloth lora adapter model"
|
||||
FastVisionModel.for_inference(model)
|
||||
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
||||
model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results"
|
||||
model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
|
|
@ -195,7 +195,7 @@ print((base.__class__.__name__))
|
|||
|
||||
# merge default 16 bits
|
||||
model.save_pretrained_merged(
|
||||
save_directory="qwen2-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer
|
||||
save_directory = "qwen2-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ model.save_pretrained_merged(
|
|||
### 16 bits merged model
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=False
|
||||
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
|
||||
)
|
||||
|
||||
# benchmark 4bit loaded, 16bits merged model performance
|
||||
|
|
@ -215,13 +215,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
|||
model,
|
||||
tokenizer,
|
||||
eval_dataset,
|
||||
output_dir="unsloth_16bits_merged_model_load_16bits_results",
|
||||
output_dir = "unsloth_16bits_merged_model_load_16bits_results",
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
# load 16bits-merged model in 4 bits
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit=True, load_in_8bit=False
|
||||
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
|
||||
)
|
||||
|
||||
# benchmark 4bit loaded, 16bits merged model performance
|
||||
|
|
@ -232,13 +232,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
|||
model,
|
||||
tokenizer,
|
||||
eval_dataset,
|
||||
output_dir="unsloth_16bits_merged_model_load_4bits_results",
|
||||
output_dir = "unsloth_16bits_merged_model_load_4bits_results",
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
# load model in 8 bits
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=True
|
||||
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
|
||||
)
|
||||
|
||||
# benchmark 4bit loaded, 16bits merged model performance
|
||||
|
|
@ -247,7 +247,7 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
|
|||
model,
|
||||
tokenizer,
|
||||
eval_dataset,
|
||||
output_dir="unsloth_16bits_merged_model_load_8bits_results",
|
||||
output_dir = "unsloth_16bits_merged_model_load_8bits_results",
|
||||
)
|
||||
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
|
||||
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ TestParams = [
|
|||
|
||||
|
||||
# Test that model registration methods register respective models
|
||||
@pytest.mark.parametrize("model_test_param", TestParams, ids=lambda param: param.name)
|
||||
@pytest.mark.parametrize("model_test_param", TestParams, ids = lambda param: param.name)
|
||||
def test_model_registration(model_test_param: ModelTestParam):
|
||||
MODEL_REGISTRY.clear()
|
||||
registration_method = model_test_param.register_models
|
||||
|
|
@ -86,7 +86,7 @@ def test_all_model_registration():
|
|||
def test_quant_type():
|
||||
# Test that the quant_type is correctly set for model paths
|
||||
# NOTE: for models registered under org="unsloth" with QuantType.NONE aliases QuantType.UNSLOTH
|
||||
dynamic_quant_models = search_models(quant_types=[QuantType.UNSLOTH])
|
||||
dynamic_quant_models = search_models(quant_types = [QuantType.UNSLOTH])
|
||||
assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models)
|
||||
quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH]
|
||||
assert all(quant_tag in m.model_path for m in dynamic_quant_models)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def timer(name):
|
|||
|
||||
|
||||
@contextmanager
|
||||
def header_footer_context(title: str, char="-"):
|
||||
def header_footer_context(title: str, char = "-"):
|
||||
print()
|
||||
print(f"{char}" * 50 + f" {title} " + f"{char}" * 50)
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import sys
|
|||
import warnings
|
||||
|
||||
|
||||
def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
|
||||
def clear_memory(variables_to_clear = None, verbose = False, clear_all_caches = True):
|
||||
"""
|
||||
Comprehensive memory clearing for persistent memory leaks.
|
||||
|
||||
|
|
@ -104,7 +104,7 @@ def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
|
|||
logger.setLevel(level)
|
||||
|
||||
|
||||
def clear_all_lru_caches(verbose=True):
|
||||
def clear_all_lru_caches(verbose = True):
|
||||
"""Clear all LRU caches in loaded modules."""
|
||||
cleared_caches = []
|
||||
|
||||
|
|
@ -210,7 +210,7 @@ def monitor_cache_sizes():
|
|||
except:
|
||||
pass
|
||||
|
||||
return sorted(cache_info, key=lambda x: x["size"], reverse=True)
|
||||
return sorted(cache_info, key = lambda x: x["size"], reverse = True)
|
||||
|
||||
|
||||
def safe_remove_directory(path):
|
||||
|
|
|
|||
|
|
@ -32,10 +32,10 @@ def create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = N
|
|||
dataset = create_instruction_dataset(messages)
|
||||
|
||||
def _apply_chat_template(example):
|
||||
chat = tokenizer.apply_chat_template(example["messages"], tokenize=False)
|
||||
chat = tokenizer.apply_chat_template(example["messages"], tokenize = False)
|
||||
return {"text": chat}
|
||||
|
||||
dataset = dataset.map(_apply_chat_template, remove_columns="messages")
|
||||
dataset = dataset.map(_apply_chat_template, remove_columns = "messages")
|
||||
if num_examples is not None:
|
||||
if len(dataset) < num_examples:
|
||||
num_repeats = num_examples // len(dataset) + 1
|
||||
|
|
@ -139,11 +139,11 @@ def get_peft_weights(model):
|
|||
|
||||
def describe_peft_weights(model):
|
||||
for name, param in get_peft_weights(model).items():
|
||||
yield name, describe_param(param, as_str=True)
|
||||
yield name, describe_param(param, as_str = True)
|
||||
|
||||
|
||||
def check_responses(responses: list[str], answer: str, prompt: str = None) -> bool:
|
||||
for i, response in enumerate(responses, start=1):
|
||||
for i, response in enumerate(responses, start = 1):
|
||||
if answer in response:
|
||||
print(f"\u2713 response {i} contains answer")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -41,14 +41,14 @@ class OCRModelEvaluator:
|
|||
Evaluate a model on an OCR dataset.
|
||||
"""
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(output_dir, exist_ok = True)
|
||||
|
||||
# Initialize results storage
|
||||
results = []
|
||||
|
||||
# Process each sample in the dataset
|
||||
for i, sample in enumerate(
|
||||
tqdm(dataset, desc="Evaluating OCR performance", disable=not verbose)
|
||||
tqdm(dataset, desc = "Evaluating OCR performance", disable = not verbose)
|
||||
):
|
||||
try:
|
||||
# Extract components from sample
|
||||
|
|
@ -187,7 +187,7 @@ class OCRModelEvaluator:
|
|||
|
||||
# Preparation for inference using Qwen's specific processing
|
||||
text = processor.apply_chat_template(
|
||||
input_messages, tokenize=False, add_generation_prompt=True
|
||||
input_messages, tokenize = False, add_generation_prompt = True
|
||||
)
|
||||
|
||||
# Process vision info (images/videos) from messages
|
||||
|
|
@ -195,11 +195,11 @@ class OCRModelEvaluator:
|
|||
|
||||
# Create model inputs
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
text = [text],
|
||||
images = image_inputs,
|
||||
videos = video_inputs,
|
||||
padding = True,
|
||||
return_tensors = "pt",
|
||||
)
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
|
|
@ -207,10 +207,10 @@ class OCRModelEvaluator:
|
|||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
min_p=min_p,
|
||||
use_cache=True,
|
||||
max_new_tokens = max_new_tokens,
|
||||
temperature = temperature,
|
||||
min_p = min_p,
|
||||
use_cache = True,
|
||||
)
|
||||
|
||||
# Extract only the generated part (not the input)
|
||||
|
|
@ -222,8 +222,8 @@ class OCRModelEvaluator:
|
|||
# Decode the generated text
|
||||
generated_response = processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens = True,
|
||||
clean_up_tokenization_spaces = False,
|
||||
)[0]
|
||||
|
||||
return generated_response
|
||||
|
|
@ -240,7 +240,7 @@ class OCRModelEvaluator:
|
|||
):
|
||||
"""Save individual sample result to file."""
|
||||
output_file = os.path.join(output_dir, f"sample_{sample_idx}.txt")
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
with open(output_file, "w", encoding = "utf-8") as f:
|
||||
f.write(f"Sample {sample_idx}\n")
|
||||
f.write(f"Question: {question}\n\n")
|
||||
f.write(f"Model output:\n{generated_response.strip()}\n\n")
|
||||
|
|
@ -268,7 +268,7 @@ class OCRModelEvaluator:
|
|||
f.write(f"Average CER: {avg_cer:.4f}\n")
|
||||
|
||||
# Save detailed results
|
||||
df.to_csv(os.path.join(output_dir, "detailed_results.csv"), index=False)
|
||||
df.to_csv(os.path.join(output_dir, "detailed_results.csv"), index = False)
|
||||
|
||||
if verbose:
|
||||
print("\nResults Summary:")
|
||||
|
|
@ -310,12 +310,12 @@ class OCRModelEvaluator:
|
|||
|
||||
# Display the comparison table
|
||||
print("\nComparison Table (sorted by WER):")
|
||||
print(comparison_df.to_string(index=False))
|
||||
print(comparison_df.to_string(index = False))
|
||||
|
||||
# Save the comparison table
|
||||
if save_csv:
|
||||
comparison_file = "model_comparison_results.csv"
|
||||
comparison_df.to_csv(comparison_file, index=False)
|
||||
comparison_df.to_csv(comparison_file, index = False)
|
||||
print(f"\nComparison table saved to {comparison_file}")
|
||||
|
||||
# Generate a bar chart visualization
|
||||
|
|
@ -326,23 +326,23 @@ class OCRModelEvaluator:
|
|||
|
||||
def _create_comparison_plot(self, comparison_df: pd.DataFrame):
|
||||
"""Create and save comparison plot."""
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.figure(figsize = (12, 6))
|
||||
|
||||
# Plot WER
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.bar(comparison_df["Model"], comparison_df["WER"], color="skyblue")
|
||||
plt.bar(comparison_df["Model"], comparison_df["WER"], color = "skyblue")
|
||||
plt.title("Word Error Rate Comparison")
|
||||
plt.ylabel("WER (lower is better)")
|
||||
plt.ylim(bottom=0)
|
||||
plt.xticks(rotation=45, ha="right")
|
||||
plt.ylim(bottom = 0)
|
||||
plt.xticks(rotation = 45, ha = "right")
|
||||
|
||||
# Plot CER
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.bar(comparison_df["Model"], comparison_df["CER"], color="lightgreen")
|
||||
plt.bar(comparison_df["Model"], comparison_df["CER"], color = "lightgreen")
|
||||
plt.title("Character Error Rate Comparison")
|
||||
plt.ylabel("CER (lower is better)")
|
||||
plt.ylim(bottom=0)
|
||||
plt.xticks(rotation=45, ha="right")
|
||||
plt.ylim(bottom = 0)
|
||||
plt.xticks(rotation = 45, ha = "right")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig("ocr_model_comparison.png")
|
||||
|
|
@ -360,7 +360,7 @@ class OCRModelEvaluator:
|
|||
|
||||
|
||||
def evaluate_ocr_model(
|
||||
model, processor, dataset, output_dir="ocr_evaluation_results", **kwargs
|
||||
model, processor, dataset, output_dir = "ocr_evaluation_results", **kwargs
|
||||
):
|
||||
"""
|
||||
Convenience function that maintains backward compatibility with the original function.
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def detect_package_manager():
|
|||
return None
|
||||
|
||||
|
||||
def check_package_installed(package_name, package_manager=None):
|
||||
def check_package_installed(package_name, package_manager = None):
|
||||
"""Check if a package is installed using the system package manager"""
|
||||
|
||||
if package_manager is None:
|
||||
|
|
@ -35,26 +35,26 @@ def check_package_installed(package_name, package_manager=None):
|
|||
if package_manager == "apt":
|
||||
# Check with dpkg
|
||||
result = subprocess.run(
|
||||
["dpkg", "-l", package_name], capture_output=True, text=True
|
||||
["dpkg", "-l", package_name], capture_output = True, text = True
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
elif package_manager in ["yum", "dnf"]:
|
||||
# Check with rpm
|
||||
result = subprocess.run(
|
||||
["rpm", "-q", package_name], capture_output=True, text=True
|
||||
["rpm", "-q", package_name], capture_output = True, text = True
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
elif package_manager == "pacman":
|
||||
result = subprocess.run(
|
||||
["pacman", "-Q", package_name], capture_output=True, text=True
|
||||
["pacman", "-Q", package_name], capture_output = True, text = True
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
elif package_manager == "zypper":
|
||||
result = subprocess.run(
|
||||
["zypper", "se", "-i", package_name], capture_output=True, text=True
|
||||
["zypper", "se", "-i", package_name], capture_output = True, text = True
|
||||
)
|
||||
return package_name in result.stdout
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ def check_package_installed(package_name, package_manager=None):
|
|||
return None
|
||||
|
||||
|
||||
def require_package(package_name, executable_name=None):
|
||||
def require_package(package_name, executable_name = None):
|
||||
"""Require a package to be installed, exit if not found"""
|
||||
|
||||
# First check if executable is in PATH (most reliable)
|
||||
|
|
@ -109,7 +109,7 @@ def require_package(package_name, executable_name=None):
|
|||
# require_package("ffmpeg", "ffmpeg")
|
||||
|
||||
|
||||
def require_python_package(package_name, import_name=None, pip_name=None):
|
||||
def require_python_package(package_name, import_name = None, pip_name = None):
|
||||
"""Require a Python package to be installed, exit if not found"""
|
||||
if import_name is None:
|
||||
import_name = package_name
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ def ppl_model(model, tokenizer, dataset):
|
|||
max_length = 2048
|
||||
stride = 512
|
||||
for s in tqdm(range(len(dataset["text"]))):
|
||||
encodings = tokenizer(dataset["text"][s], return_tensors="pt")
|
||||
encodings = tokenizer(dataset["text"][s], return_tensors = "pt")
|
||||
seq_len = encodings.input_ids.size(1)
|
||||
prev_end_loc = 0
|
||||
for begin_loc in range(0, seq_len, stride):
|
||||
|
|
@ -28,7 +28,7 @@ def ppl_model(model, tokenizer, dataset):
|
|||
attention_mask = (input_ids != pad_token_id).long()
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
input_ids, labels=target_ids, attention_mask=attention_mask
|
||||
input_ids, labels = target_ids, attention_mask = attention_mask
|
||||
)
|
||||
neg_log_likelihood = outputs.loss
|
||||
nlls.append(neg_log_likelihood)
|
||||
|
|
@ -78,4 +78,4 @@ def print_model_comparison():
|
|||
|
||||
# Display the comparison table
|
||||
print("\nComparison Table:")
|
||||
print(comparison_df.to_string(index=False))
|
||||
print(comparison_df.to_string(index = False))
|
||||
|
|
|
|||
|
|
@ -32,15 +32,15 @@ def _get_model(qat_scheme: str, full_finetuning: bool):
|
|||
to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.
|
||||
"""
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Qwen3-1.7B",
|
||||
load_in_4bit=False,
|
||||
full_finetuning=full_finetuning,
|
||||
qat_scheme=qat_scheme if full_finetuning else None,
|
||||
model_name = "unsloth/Qwen3-1.7B",
|
||||
load_in_4bit = False,
|
||||
full_finetuning = full_finetuning,
|
||||
qat_scheme = qat_scheme if full_finetuning else None,
|
||||
)
|
||||
if not full_finetuning:
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
qat_scheme=qat_scheme,
|
||||
qat_scheme = qat_scheme,
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
|
@ -140,7 +140,7 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
|
|||
_test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
|
||||
_test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
|
||||
_test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
|
||||
inputs = tokenizer("How are you?", return_tensors="pt")
|
||||
inputs = tokenizer("How are you?", return_tensors = "pt")
|
||||
_test_fake_quantizers_are_called(model, inputs, full_finetuning)
|
||||
|
||||
|
||||
|
|
@ -148,9 +148,9 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
|
|||
# how to disable model caching before re-enabling this test
|
||||
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
|
||||
def _test_full_model_fake_quantize(qat_scheme: bool):
|
||||
_test_model_fake_quantize(qat_scheme, full_finetuning=True)
|
||||
_test_model_fake_quantize(qat_scheme, full_finetuning = True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
|
||||
def test_lora_model_fake_quantize(qat_scheme: bool):
|
||||
_test_model_fake_quantize(qat_scheme, full_finetuning=False)
|
||||
_test_model_fake_quantize(qat_scheme, full_finetuning = False)
|
||||
|
|
|
|||
254
unsloth-cli.py
254
unsloth-cli.py
|
|
@ -47,17 +47,17 @@ def run(args):
|
|||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=args.model_name,
|
||||
max_seq_length=args.max_seq_length,
|
||||
dtype=args.dtype,
|
||||
load_in_4bit=args.load_in_4bit,
|
||||
model_name = args.model_name,
|
||||
max_seq_length = args.max_seq_length,
|
||||
dtype = args.dtype,
|
||||
load_in_4bit = args.load_in_4bit,
|
||||
)
|
||||
|
||||
# Configure PEFT model
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=args.r,
|
||||
target_modules=[
|
||||
r = args.r,
|
||||
target_modules = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
|
|
@ -66,13 +66,13 @@ def run(args):
|
|||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
bias=args.bias,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
random_state=args.random_state,
|
||||
use_rslora=args.use_rslora,
|
||||
loftq_config=args.loftq_config,
|
||||
lora_alpha = args.lora_alpha,
|
||||
lora_dropout = args.lora_dropout,
|
||||
bias = args.bias,
|
||||
use_gradient_checkpointing = args.use_gradient_checkpointing,
|
||||
random_state = args.random_state,
|
||||
use_rslora = args.use_rslora,
|
||||
loftq_config = args.loftq_config,
|
||||
)
|
||||
|
||||
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
|
|
@ -102,40 +102,40 @@ def run(args):
|
|||
if use_modelscope:
|
||||
from modelscope import MsDataset
|
||||
|
||||
dataset = MsDataset.load(args.dataset, split="train")
|
||||
dataset = MsDataset.load(args.dataset, split = "train")
|
||||
else:
|
||||
# Load and format dataset
|
||||
dataset = load_dataset(args.dataset, split="train")
|
||||
dataset = dataset.map(formatting_prompts_func, batched=True)
|
||||
dataset = load_dataset(args.dataset, split = "train")
|
||||
dataset = dataset.map(formatting_prompts_func, batched = True)
|
||||
print("Data is formatted and ready!")
|
||||
|
||||
# Configure training arguments
|
||||
training_args = SFTConfig(
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
warmup_steps=args.warmup_steps,
|
||||
max_steps=args.max_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=args.logging_steps,
|
||||
optim=args.optim,
|
||||
weight_decay=args.weight_decay,
|
||||
lr_scheduler_type=args.lr_scheduler_type,
|
||||
seed=args.seed,
|
||||
output_dir=args.output_dir,
|
||||
report_to=args.report_to,
|
||||
max_length=args.max_seq_length,
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
per_device_train_batch_size = args.per_device_train_batch_size,
|
||||
gradient_accumulation_steps = args.gradient_accumulation_steps,
|
||||
warmup_steps = args.warmup_steps,
|
||||
max_steps = args.max_steps,
|
||||
learning_rate = args.learning_rate,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = args.logging_steps,
|
||||
optim = args.optim,
|
||||
weight_decay = args.weight_decay,
|
||||
lr_scheduler_type = args.lr_scheduler_type,
|
||||
seed = args.seed,
|
||||
output_dir = args.output_dir,
|
||||
report_to = args.report_to,
|
||||
max_length = args.max_seq_length,
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset,
|
||||
args=training_args,
|
||||
model = model,
|
||||
processing_class = tokenizer,
|
||||
train_dataset = dataset,
|
||||
args = training_args,
|
||||
)
|
||||
|
||||
# Train model
|
||||
|
|
@ -153,24 +153,24 @@ def run(args):
|
|||
model.save_pretrained_gguf(
|
||||
args.save_path,
|
||||
tokenizer,
|
||||
quantization_method=quantization_method,
|
||||
quantization_method = quantization_method,
|
||||
)
|
||||
if args.push_model:
|
||||
model.push_to_hub_gguf(
|
||||
hub_path=args.hub_path,
|
||||
hub_token=args.hub_token,
|
||||
quantization_method=quantization_method,
|
||||
hub_path = args.hub_path,
|
||||
hub_token = args.hub_token,
|
||||
quantization_method = quantization_method,
|
||||
)
|
||||
else:
|
||||
print(f"Saving model with quantization method: {args.quantization}")
|
||||
model.save_pretrained_gguf(
|
||||
args.save_path, tokenizer, quantization_method=args.quantization
|
||||
args.save_path, tokenizer, quantization_method = args.quantization
|
||||
)
|
||||
if args.push_model:
|
||||
model.push_to_hub_gguf(
|
||||
hub_path=args.hub_path,
|
||||
hub_token=args.hub_token,
|
||||
quantization_method=quantization_method,
|
||||
hub_path = args.hub_path,
|
||||
hub_token = args.hub_token,
|
||||
quantization_method = quantization_method,
|
||||
)
|
||||
else:
|
||||
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
|
||||
|
|
@ -183,38 +183,38 @@ def run(args):
|
|||
if __name__ == "__main__":
|
||||
# Define argument parser
|
||||
parser = argparse.ArgumentParser(
|
||||
description="🦥 Fine-tune your llm faster using unsloth!"
|
||||
description = "🦥 Fine-tune your llm faster using unsloth!"
|
||||
)
|
||||
|
||||
model_group = parser.add_argument_group("🤖 Model Options")
|
||||
model_group.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="unsloth/llama-3-8b",
|
||||
help="Model name to load",
|
||||
type = str,
|
||||
default = "unsloth/llama-3-8b",
|
||||
help = "Model name to load",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max_seq_length",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!",
|
||||
type = int,
|
||||
default = 2048,
|
||||
help = "Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Data type for model (None for auto detection)",
|
||||
type = str,
|
||||
default = None,
|
||||
help = "Data type for model (None for auto detection)",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--load_in_4bit",
|
||||
action="store_true",
|
||||
help="Use 4bit quantization to reduce memory usage",
|
||||
action = "store_true",
|
||||
help = "Use 4bit quantization to reduce memory usage",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="yahma/alpaca-cleaned",
|
||||
help="Huggingface dataset to use for training",
|
||||
type = str,
|
||||
default = "yahma/alpaca-cleaned",
|
||||
help = "Huggingface dataset to use for training",
|
||||
)
|
||||
|
||||
lora_group = parser.add_argument_group(
|
||||
|
|
@ -222,101 +222,101 @@ if __name__ == "__main__":
|
|||
)
|
||||
lora_group.add_argument(
|
||||
"--r",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)",
|
||||
type = int,
|
||||
default = 16,
|
||||
help = "Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)",
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=16,
|
||||
help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)",
|
||||
type = int,
|
||||
default = 16,
|
||||
help = "LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)",
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--lora_dropout",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="LoRA dropout rate, default is 0.0 which is optimized.",
|
||||
type = float,
|
||||
default = 0.0,
|
||||
help = "LoRA dropout rate, default is 0.0 which is optimized.",
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--bias", type=str, default="none", help="Bias setting for LoRA"
|
||||
"--bias", type = str, default = "none", help = "Bias setting for LoRA"
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--use_gradient_checkpointing",
|
||||
type=str,
|
||||
default="unsloth",
|
||||
help="Use gradient checkpointing",
|
||||
type = str,
|
||||
default = "unsloth",
|
||||
help = "Use gradient checkpointing",
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--random_state",
|
||||
type=int,
|
||||
default=3407,
|
||||
help="Random state for reproducibility, default is 3407.",
|
||||
type = int,
|
||||
default = 3407,
|
||||
help = "Random state for reproducibility, default is 3407.",
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--use_rslora", action="store_true", help="Use rank stabilized LoRA"
|
||||
"--use_rslora", action = "store_true", help = "Use rank stabilized LoRA"
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--loftq_config", type=str, default=None, help="Configuration for LoftQ"
|
||||
"--loftq_config", type = str, default = None, help = "Configuration for LoftQ"
|
||||
)
|
||||
|
||||
training_group = parser.add_argument_group("🎓 Training Options")
|
||||
training_group.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Batch size per device during training, default is 2.",
|
||||
type = int,
|
||||
default = 2,
|
||||
help = "Batch size per device during training, default is 2.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of gradient accumulation steps, default is 4.",
|
||||
type = int,
|
||||
default = 4,
|
||||
help = "Number of gradient accumulation steps, default is 4.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--warmup_steps",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of warmup steps, default is 5.",
|
||||
type = int,
|
||||
default = 5,
|
||||
help = "Number of warmup steps, default is 5.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--max_steps", type=int, default=400, help="Maximum number of training steps."
|
||||
"--max_steps", type = int, default = 400, help = "Maximum number of training steps."
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=2e-4,
|
||||
help="Learning rate, default is 2e-4.",
|
||||
type = float,
|
||||
default = 2e-4,
|
||||
help = "Learning rate, default is 2e-4.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--optim", type=str, default="adamw_8bit", help="Optimizer type."
|
||||
"--optim", type = str, default = "adamw_8bit", help = "Optimizer type."
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--weight_decay",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Weight decay, default is 0.01.",
|
||||
type = float,
|
||||
default = 0.01,
|
||||
help = "Weight decay, default is 0.01.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--lr_scheduler_type",
|
||||
type=str,
|
||||
default="linear",
|
||||
help="Learning rate scheduler type, default is 'linear'.",
|
||||
type = str,
|
||||
default = "linear",
|
||||
help = "Learning rate scheduler type, default is 'linear'.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=3407,
|
||||
help="Seed for reproducibility, default is 3407.",
|
||||
type = int,
|
||||
default = 3407,
|
||||
help = "Seed for reproducibility, default is 3407.",
|
||||
)
|
||||
|
||||
# Report/Logging arguments
|
||||
report_group = parser.add_argument_group("📊 Report Options")
|
||||
report_group.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=[
|
||||
type = str,
|
||||
default = "tensorboard",
|
||||
choices = [
|
||||
"azure_ml",
|
||||
"clearml",
|
||||
"codecarbon",
|
||||
|
|
@ -331,62 +331,62 @@ if __name__ == "__main__":
|
|||
"all",
|
||||
"none",
|
||||
],
|
||||
help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.",
|
||||
help = "The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.",
|
||||
)
|
||||
report_group.add_argument(
|
||||
"--logging_steps", type=int, default=1, help="Logging steps, default is 1"
|
||||
"--logging_steps", type = int, default = 1, help = "Logging steps, default is 1"
|
||||
)
|
||||
|
||||
# Saving and pushing arguments
|
||||
save_group = parser.add_argument_group("💾 Save Model Options")
|
||||
save_group.add_argument(
|
||||
"--output_dir", type=str, default="outputs", help="Output directory"
|
||||
"--output_dir", type = str, default = "outputs", help = "Output directory"
|
||||
)
|
||||
save_group.add_argument(
|
||||
"--save_model", action="store_true", help="Save the model after training"
|
||||
"--save_model", action = "store_true", help = "Save the model after training"
|
||||
)
|
||||
save_group.add_argument(
|
||||
"--save_method",
|
||||
type=str,
|
||||
default="merged_16bit",
|
||||
choices=["merged_16bit", "merged_4bit", "lora"],
|
||||
help="Save method for the model, default is 'merged_16bit'",
|
||||
type = str,
|
||||
default = "merged_16bit",
|
||||
choices = ["merged_16bit", "merged_4bit", "lora"],
|
||||
help = "Save method for the model, default is 'merged_16bit'",
|
||||
)
|
||||
save_group.add_argument(
|
||||
"--save_gguf",
|
||||
action="store_true",
|
||||
help="Convert the model to GGUF after training",
|
||||
action = "store_true",
|
||||
help = "Convert the model to GGUF after training",
|
||||
)
|
||||
save_group.add_argument(
|
||||
"--save_path", type=str, default="model", help="Path to save the model"
|
||||
"--save_path", type = str, default = "model", help = "Path to save the model"
|
||||
)
|
||||
save_group.add_argument(
|
||||
"--quantization",
|
||||
type=str,
|
||||
default="q8_0",
|
||||
nargs="+",
|
||||
help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ",
|
||||
type = str,
|
||||
default = "q8_0",
|
||||
nargs = "+",
|
||||
help = "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ",
|
||||
)
|
||||
|
||||
push_group = parser.add_argument_group("🚀 Push Model Options")
|
||||
push_group.add_argument(
|
||||
"--push_model",
|
||||
action="store_true",
|
||||
help="Push the model to Hugging Face hub after training",
|
||||
action = "store_true",
|
||||
help = "Push the model to Hugging Face hub after training",
|
||||
)
|
||||
push_group.add_argument(
|
||||
"--push_gguf",
|
||||
action="store_true",
|
||||
help="Push the model as GGUF to Hugging Face hub after training",
|
||||
action = "store_true",
|
||||
help = "Push the model as GGUF to Hugging Face hub after training",
|
||||
)
|
||||
push_group.add_argument(
|
||||
"--hub_path",
|
||||
type=str,
|
||||
default="hf/model",
|
||||
help="Path on Hugging Face hub to push the model",
|
||||
type = str,
|
||||
default = "hf/model",
|
||||
help = "Path on Hugging Face hub to push the model",
|
||||
)
|
||||
push_group.add_argument(
|
||||
"--hub_token", type=str, help="Token for pushing the model to Hugging Face hub"
|
||||
"--hub_token", type = str, help = "Token for pushing the model to Hugging Face hub"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -110,7 +110,6 @@ from unsloth_zoo.device_type import (
|
|||
from .import_fixes import (
|
||||
fix_xformers_performance_issue,
|
||||
fix_vllm_aimv2_issue,
|
||||
fix_vllm_guided_decoding_params,
|
||||
ignore_logger_messages,
|
||||
patch_ipykernel_hf_xet,
|
||||
patch_trackio,
|
||||
|
|
@ -119,14 +118,13 @@ from .import_fixes import (
|
|||
|
||||
fix_xformers_performance_issue()
|
||||
fix_vllm_aimv2_issue()
|
||||
fix_vllm_guided_decoding_params()
|
||||
ignore_logger_messages()
|
||||
patch_ipykernel_hf_xet()
|
||||
patch_trackio()
|
||||
patch_datasets()
|
||||
|
||||
del fix_xformers_performance_issue
|
||||
del patch_vllm_imports
|
||||
del fix_vllm_aimv2_issue
|
||||
del ignore_logger_messages
|
||||
del patch_ipykernel_hf_xet
|
||||
del patch_trackio
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from .synthetic_configs import (
|
|||
)
|
||||
|
||||
|
||||
def terminate_tree(proc: subprocess.Popen, timeout=15):
|
||||
def terminate_tree(proc: subprocess.Popen, timeout = 15):
|
||||
if proc is None or proc.poll() is not None:
|
||||
return
|
||||
|
||||
|
|
@ -48,10 +48,10 @@ def terminate_tree(proc: subprocess.Popen, timeout=15):
|
|||
import psutil
|
||||
|
||||
parent = psutil.Process(proc.pid)
|
||||
for child in parent.children(recursive=True):
|
||||
for child in parent.children(recursive = True):
|
||||
child.terminate()
|
||||
parent.terminate()
|
||||
parent.wait(timeout=timeout / 2)
|
||||
parent.wait(timeout = timeout / 2)
|
||||
return
|
||||
except:
|
||||
pass
|
||||
|
|
@ -60,17 +60,17 @@ def terminate_tree(proc: subprocess.Popen, timeout=15):
|
|||
try:
|
||||
subprocess.run(
|
||||
["taskkill", "/T", "/F", "/PID", str(proc.pid)],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
capture_output = True,
|
||||
timeout = 5,
|
||||
)
|
||||
proc.wait(timeout=1)
|
||||
proc.wait(timeout = 1)
|
||||
return
|
||||
except:
|
||||
pass
|
||||
|
||||
proc.kill()
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
proc.wait(timeout = 5)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
@ -81,16 +81,16 @@ class PipeCapture:
|
|||
def __init__(
|
||||
self,
|
||||
pipe,
|
||||
keep_lines=2000,
|
||||
echo=False,
|
||||
name="",
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
ready_regex=None,
|
||||
keep_lines = 2000,
|
||||
echo = False,
|
||||
name = "",
|
||||
text = True,
|
||||
encoding = "utf-8",
|
||||
errors = "replace",
|
||||
ready_regex = None,
|
||||
):
|
||||
self.pipe = pipe
|
||||
self.buf = deque(maxlen=keep_lines)
|
||||
self.buf = deque(maxlen = keep_lines)
|
||||
self.lock = threading.Lock()
|
||||
self.echo = echo
|
||||
self.name = name
|
||||
|
|
@ -107,7 +107,7 @@ class PipeCapture:
|
|||
ready_regex = re.compile(ready_regex)
|
||||
self.ready_regex = ready_regex
|
||||
|
||||
self.t = threading.Thread(target=self._reader, daemon=True)
|
||||
self.t = threading.Thread(target = self._reader, daemon = True)
|
||||
self.t.start()
|
||||
|
||||
def _reader(self):
|
||||
|
|
@ -136,16 +136,16 @@ class PipeCapture:
|
|||
pass
|
||||
self.closed_event.set()
|
||||
|
||||
def wait_for_ready(self, timeout=None):
|
||||
def wait_for_ready(self, timeout = None):
|
||||
return self.ready_event.wait(timeout)
|
||||
|
||||
def has_closed(self):
|
||||
return self.closed_event.is_set()
|
||||
|
||||
def wait_until_closed(self, timeout=None):
|
||||
def wait_until_closed(self, timeout = None):
|
||||
return self.closed_event.wait(timeout)
|
||||
|
||||
def tail(self, n=200):
|
||||
def tail(self, n = 200):
|
||||
with self.lock:
|
||||
return "\n".join(list(self.buf)[-n:])
|
||||
|
||||
|
|
@ -153,13 +153,13 @@ class PipeCapture:
|
|||
class SyntheticDataKit:
|
||||
def __init__(
|
||||
self,
|
||||
model_name="unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
|
||||
max_seq_length=2048,
|
||||
gpu_memory_utilization=0.98,
|
||||
float8_kv_cache=False,
|
||||
conservativeness=1.0,
|
||||
token=None,
|
||||
timeout=1200, # maybe this is not enough for large models if we need to download
|
||||
model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
|
||||
max_seq_length = 2048,
|
||||
gpu_memory_utilization = 0.98,
|
||||
float8_kv_cache = False,
|
||||
conservativeness = 1.0,
|
||||
token = None,
|
||||
timeout = 1200, # maybe this is not enough for large models if we need to download
|
||||
**kwargs,
|
||||
):
|
||||
assert type(model_name) is str
|
||||
|
|
@ -176,25 +176,25 @@ class SyntheticDataKit:
|
|||
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
token = token,
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
token = token,
|
||||
)
|
||||
patch_vllm(debug=False)
|
||||
patch_vllm(debug = False)
|
||||
engine_args = load_vllm(
|
||||
model_name=model_name,
|
||||
config=self.config,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
max_seq_length=max_seq_length,
|
||||
disable_log_stats=True,
|
||||
float8_kv_cache=float8_kv_cache,
|
||||
conservativeness=conservativeness,
|
||||
return_args=True,
|
||||
enable_lora=False,
|
||||
use_bitsandbytes=False,
|
||||
compilation_config=3,
|
||||
model_name = model_name,
|
||||
config = self.config,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
max_seq_length = max_seq_length,
|
||||
disable_log_stats = True,
|
||||
float8_kv_cache = float8_kv_cache,
|
||||
conservativeness = conservativeness,
|
||||
return_args = True,
|
||||
enable_lora = False,
|
||||
use_bitsandbytes = False,
|
||||
compilation_config = 3,
|
||||
**kwargs,
|
||||
)
|
||||
if "dtype" in engine_args:
|
||||
|
|
@ -252,31 +252,31 @@ class SyntheticDataKit:
|
|||
logger.info(subprocess_commands)
|
||||
vllm_process = subprocess.Popen(
|
||||
subprocess_commands,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
start_new_session=True,
|
||||
stdout = subprocess.PIPE,
|
||||
stderr = subprocess.PIPE,
|
||||
start_new_session = True,
|
||||
)
|
||||
ready_re = re.compile(r"Starting vLLM API server(?:\s+\d+)?\s+on\b")
|
||||
self.vllm_process = vllm_process
|
||||
self.stdout_capture = PipeCapture(
|
||||
vllm_process.stdout,
|
||||
keep_lines=1000,
|
||||
echo=True,
|
||||
name="vLLM STDOUT",
|
||||
ready_regex=ready_re,
|
||||
text=False,
|
||||
keep_lines = 1000,
|
||||
echo = True,
|
||||
name = "vLLM STDOUT",
|
||||
ready_regex = ready_re,
|
||||
text = False,
|
||||
)
|
||||
self.stderr_capture = PipeCapture(
|
||||
vllm_process.stderr,
|
||||
keep_lines=2000,
|
||||
echo=False,
|
||||
name="vLLM STDERR",
|
||||
ready_regex=None,
|
||||
text=False,
|
||||
keep_lines = 2000,
|
||||
echo = False,
|
||||
name = "vLLM STDERR",
|
||||
ready_regex = None,
|
||||
text = False,
|
||||
)
|
||||
# we don't print stderr to console but self.stderr_capture.tail(200) will print the last 200 lines
|
||||
|
||||
ready = self.stdout_capture.wait_for_ready(timeout=timeout)
|
||||
ready = self.stdout_capture.wait_for_ready(timeout = timeout)
|
||||
if not ready:
|
||||
if self.stdout_capture.has_closed() or self.vllm_process.poll() is not None:
|
||||
print("Stdout stream ended before readiness message detected.")
|
||||
|
|
@ -305,21 +305,21 @@ class SyntheticDataKit:
|
|||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_name="unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
|
||||
max_seq_length=2048,
|
||||
gpu_memory_utilization=0.9,
|
||||
float8_kv_cache=False,
|
||||
conservativeness=1.0,
|
||||
token=None,
|
||||
model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
|
||||
max_seq_length = 2048,
|
||||
gpu_memory_utilization = 0.9,
|
||||
float8_kv_cache = False,
|
||||
conservativeness = 1.0,
|
||||
token = None,
|
||||
**kwargs,
|
||||
):
|
||||
return SyntheticDataKit(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
float8_kv_cache=float8_kv_cache,
|
||||
conservativeness=conservativeness,
|
||||
token=token,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
float8_kv_cache = float8_kv_cache,
|
||||
conservativeness = conservativeness,
|
||||
token = token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
|
@ -340,7 +340,7 @@ class SyntheticDataKit:
|
|||
print("Attempting to terminate the VLLM server gracefully...")
|
||||
try:
|
||||
vllm_process.terminate()
|
||||
vllm_process.wait(timeout=10)
|
||||
vllm_process.wait(timeout = 10)
|
||||
print("Server terminated gracefully.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
|
|
@ -364,7 +364,7 @@ class SyntheticDataKit:
|
|||
gc.collect()
|
||||
|
||||
# Delete vLLM module as well
|
||||
delete_vllm(llm=None)
|
||||
delete_vllm(llm = None)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
|
@ -375,7 +375,7 @@ class SyntheticDataKit:
|
|||
def __del__(self):
|
||||
self.cleanup()
|
||||
|
||||
def chunk_data(self, filename=None):
|
||||
def chunk_data(self, filename = None):
|
||||
# Chunks data by max tokens and generation length
|
||||
assert filename is not None
|
||||
assert os.path.exists(filename)
|
||||
|
|
@ -387,7 +387,7 @@ class SyntheticDataKit:
|
|||
if not hasattr(self, "overlap") or not hasattr(self, "max_generation_tokens"):
|
||||
raise RuntimeError("Please use prepare_qa_generation first!")
|
||||
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
with open(filename, "r", encoding = "utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
max_tokens = (
|
||||
|
|
@ -395,7 +395,7 @@ class SyntheticDataKit:
|
|||
) # -128 to reduce errors
|
||||
if max_tokens <= 5:
|
||||
raise RuntimeError("Generation length is way too long!")
|
||||
input_ids = self.tokenizer(text, add_special_tokens=False).input_ids
|
||||
input_ids = self.tokenizer(text, add_special_tokens = False).input_ids
|
||||
|
||||
# Get left and right boundaries
|
||||
length = len(input_ids)
|
||||
|
|
@ -416,21 +416,21 @@ class SyntheticDataKit:
|
|||
chunked_text = self.tokenizer.decode(input_ids[left:right])
|
||||
new_filename = f"{filename}_{i}{extension}"
|
||||
all_filenames.append(new_filename)
|
||||
with open(new_filename, "w", encoding="utf-8") as f:
|
||||
with open(new_filename, "w", encoding = "utf-8") as f:
|
||||
f.write(chunked_text)
|
||||
return all_filenames
|
||||
|
||||
def prepare_qa_generation(
|
||||
self,
|
||||
output_folder="data",
|
||||
max_generation_tokens=512,
|
||||
temperature=0.7,
|
||||
top_p=0.95,
|
||||
overlap=64,
|
||||
default_num_pairs=25,
|
||||
cleanup_threshold=1.0,
|
||||
cleanup_batch_size=4,
|
||||
cleanup_temperature=0.3,
|
||||
output_folder = "data",
|
||||
max_generation_tokens = 512,
|
||||
temperature = 0.7,
|
||||
top_p = 0.95,
|
||||
overlap = 64,
|
||||
default_num_pairs = 25,
|
||||
cleanup_threshold = 1.0,
|
||||
cleanup_batch_size = 4,
|
||||
cleanup_temperature = 0.3,
|
||||
):
|
||||
assert hasattr(self, "model_name")
|
||||
assert hasattr(self, "max_seq_length")
|
||||
|
|
@ -439,7 +439,7 @@ class SyntheticDataKit:
|
|||
locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final"
|
||||
locations = locations.split(",")
|
||||
for path in locations:
|
||||
os.makedirs(os.path.join(output_folder, path), exist_ok=True)
|
||||
os.makedirs(os.path.join(output_folder, path), exist_ok = True)
|
||||
|
||||
self.max_generation_tokens = max_generation_tokens
|
||||
|
||||
|
|
@ -459,7 +459,7 @@ class SyntheticDataKit:
|
|||
.replace("{cleanup_temperature}", str(cleanup_temperature))
|
||||
)
|
||||
|
||||
with open("synthetic_data_kit_config.yaml", "w", encoding="utf-8") as f:
|
||||
with open("synthetic_data_kit_config.yaml", "w", encoding = "utf-8") as f:
|
||||
f.write(config)
|
||||
|
||||
self.overlap = overlap
|
||||
|
|
|
|||
|
|
@ -114,16 +114,6 @@ def fix_xformers_performance_issue():
|
|||
print(f"Unsloth: Failed patching Xformers with error = {str(e)}")
|
||||
|
||||
|
||||
def fix_vllm_aimv2_issue():
|
||||
if importlib.util.find_spec("vllm") is None:
|
||||
return
|
||||
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
|
||||
vllm_version = importlib_version("vllm")
|
||||
if Version(vllm_version) < Version("0.10.1"):
|
||||
vllm_version = importlib.util.find_spec("vllm").origin
|
||||
vllm_version = os.path.split(vllm_version)[0]
|
||||
ovis_config = Path(vllm_version) / "transformers_utils" / "configs" / "ovis.py"
|
||||
try:
|
||||
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
|
||||
def fix_vllm_aimv2_issue():
|
||||
if importlib.util.find_spec("vllm") is None:
|
||||
|
|
@ -165,22 +155,6 @@ def fix_vllm_aimv2_issue():
|
|||
print(f"Unsloth: Failed patching vLLM with error = {str(e)}")
|
||||
|
||||
|
||||
def fix_vllm_guided_decoding_params():
|
||||
if importlib.util.find_spec("vllm") is None:
|
||||
return
|
||||
# GuidedDecodingParmas is renamed to StructuredOutputsParams in vLLM
|
||||
# https://github.com/vllm-project/vllm/pull/22772/files
|
||||
# trl still wants to use GuidedDecodingParams. This is a temporary patch till trl updates
|
||||
import vllm
|
||||
|
||||
try:
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
except ImportError:
|
||||
vllm.sampling_params.GuidedDecodingParams = (
|
||||
vllm.sampling_params.StructuredOutputsParams
|
||||
)
|
||||
|
||||
|
||||
def ignore_logger_messages():
|
||||
# Ignore Environment variable `HF_TOKEN` is set
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ def _cross_entropy_forward(
|
|||
mask = col_offsets < VOCAB_SIZE
|
||||
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
|
||||
tl.float32
|
||||
)
|
||||
|
||||
|
|
@ -162,7 +162,7 @@ def _chunked_cross_entropy_forward(
|
|||
mask = col_offsets < VOCAB_SIZE
|
||||
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
|
||||
tl.float32
|
||||
)
|
||||
|
||||
|
|
@ -246,7 +246,7 @@ def _cross_entropy_backward(
|
|||
else:
|
||||
dloss = 0.0
|
||||
|
||||
x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)
|
||||
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
|
||||
# Do logit scaling for Cohere
|
||||
if DO_LOGIT_SCALING:
|
||||
|
|
@ -277,7 +277,7 @@ def _cross_entropy_backward(
|
|||
y = y * (1.0 - partial * partial)
|
||||
|
||||
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
|
||||
tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)
|
||||
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
|
||||
|
||||
|
||||
_cross_entropy_backward = triton.jit(_cross_entropy_backward)
|
||||
|
|
@ -304,7 +304,7 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
|
||||
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
|
||||
n_chunks: int = div + (mod != 0)
|
||||
losses = torch.empty(n_rows, dtype=torch.float32, device=device)
|
||||
losses = torch.empty(n_rows, dtype = torch.float32, device = device)
|
||||
|
||||
DO_SOFTCAPPING: bool = bool(logit_softcapping != 0)
|
||||
DO_LOGIT_SCALING: bool = bool(logit_scaling != 0)
|
||||
|
|
@ -316,7 +316,7 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
|
||||
if is_cdna():
|
||||
num_warps = num_warps // 2
|
||||
logsumexp = torch.empty(n_rows, dtype=torch.float32, device=device)
|
||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)
|
||||
|
||||
with torch_gpu_device(device):
|
||||
_cross_entropy_forward[(n_rows,)](
|
||||
|
|
@ -325,13 +325,13 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
VOCAB_SIZE=vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
DO_SOFTCAPPING=DO_SOFTCAPPING,
|
||||
SOFTCAP=logit_softcapping,
|
||||
DO_LOGIT_SCALING=DO_LOGIT_SCALING,
|
||||
LOGIT_SCALE=logit_scaling,
|
||||
num_warps=num_warps,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
||||
SOFTCAP = logit_softcapping,
|
||||
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
||||
LOGIT_SCALE = logit_scaling,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
else:
|
||||
# For large vocabs > 65336 like Gemma 256K
|
||||
|
|
@ -340,8 +340,8 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
n_rows,
|
||||
n_chunks,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
dtype = torch.float32,
|
||||
device = device,
|
||||
)
|
||||
|
||||
with torch_gpu_device(device):
|
||||
|
|
@ -356,18 +356,18 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
VOCAB_SIZE=vocab_size,
|
||||
N_CHUNKS=n_chunks,
|
||||
BLOCK_SIZE=MAX_FUSED_SIZE,
|
||||
DO_SOFTCAPPING=DO_SOFTCAPPING,
|
||||
SOFTCAP=logit_softcapping,
|
||||
DO_LOGIT_SCALING=DO_LOGIT_SCALING,
|
||||
LOGIT_SCALE=logit_scaling,
|
||||
num_warps=32 if not is_cdna() else 16,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
N_CHUNKS = n_chunks,
|
||||
BLOCK_SIZE = MAX_FUSED_SIZE,
|
||||
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
||||
SOFTCAP = logit_softcapping,
|
||||
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
||||
LOGIT_SCALE = logit_scaling,
|
||||
num_warps = 32 if not is_cdna() else 16,
|
||||
)
|
||||
# logsumexp(chunked_logsumexp) - x
|
||||
# Do the -x separately
|
||||
logsumexp = torch.logsumexp(logsumexp, dim=1) # Row sum
|
||||
logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
|
||||
losses += logsumexp
|
||||
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
|
||||
|
||||
|
|
@ -404,13 +404,13 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
dlosses.stride(0),
|
||||
logsumexp,
|
||||
labels,
|
||||
VOCAB_SIZE=vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
DO_SOFTCAPPING=ctx.DO_SOFTCAPPING,
|
||||
SOFTCAP=ctx.logit_softcapping,
|
||||
DO_LOGIT_SCALING=ctx.DO_LOGIT_SCALING,
|
||||
LOGIT_SCALE=ctx.logit_scaling,
|
||||
num_warps=8,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
|
||||
SOFTCAP = ctx.logit_softcapping,
|
||||
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
|
||||
LOGIT_SCALE = ctx.logit_scaling,
|
||||
num_warps = 8,
|
||||
)
|
||||
return (
|
||||
logits,
|
||||
|
|
@ -423,9 +423,9 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
def fast_cross_entropy_loss(
|
||||
logits,
|
||||
labels,
|
||||
logit_softcapping=0,
|
||||
logit_scaling=0,
|
||||
n_items=None,
|
||||
logit_softcapping = 0,
|
||||
logit_scaling = 0,
|
||||
n_items = None,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
|
|
@ -455,5 +455,5 @@ if (Version(torch.__version__) < Version("2.4.0")) and not hasattr(
|
|||
|
||||
|
||||
# Patch CE Losses in transformers
|
||||
def patch_loss_functions(torch_compile=True):
|
||||
_patch_loss_functions(fast_cross_entropy_loss, torch_compile=torch_compile)
|
||||
def patch_loss_functions(torch_compile = True):
|
||||
_patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ try:
|
|||
)
|
||||
|
||||
_flex_attention = torch.compile(
|
||||
_flex_attention, dynamic=True, options=torch_compile_options
|
||||
_flex_attention, dynamic = True, options = torch_compile_options
|
||||
)
|
||||
HAS_FLEX_ATTENTION = False
|
||||
except:
|
||||
|
|
@ -42,7 +42,7 @@ except:
|
|||
|
||||
if not HAS_FLEX_ATTENTION:
|
||||
# Logit softcapping
|
||||
@torch.compile(fullgraph=True, dynamic=True, options=torch_compile_options)
|
||||
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
|
||||
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
||||
n_heads = self.config.num_attention_heads
|
||||
head_dim = self.head_dim
|
||||
|
|
@ -62,13 +62,13 @@ if not HAS_FLEX_ATTENTION:
|
|||
s = self.config.query_pre_attn_scalar
|
||||
t = self.config.attn_logit_softcapping
|
||||
|
||||
Q = Q * torch.tensor(s**-0.5, dtype=Q.dtype) # Follow Keras exactly
|
||||
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
||||
A = torch.matmul(Q, K.transpose(2, 3))
|
||||
A = t * torch.tanh(A / t) # Logit softcapping
|
||||
A += causal_mask[:q_len, :q_len]
|
||||
# Much slower in torch compile!
|
||||
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
|
||||
A = torch.nn.functional.softmax(A, dim=-1, dtype=torch.float32).to(Q.dtype)
|
||||
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
||||
A = torch.matmul(A, V)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
A = A.reshape(bsz, q_len, n_heads * head_dim)
|
||||
|
|
@ -92,7 +92,7 @@ else:
|
|||
return q_idx >= kv_idx
|
||||
|
||||
@functools.lru_cache
|
||||
def sliding_window_masker(size=4096):
|
||||
def sliding_window_masker(size = 4096):
|
||||
def sliding_window(b, h, q_idx, kv_idx):
|
||||
causal_mask = q_idx >= kv_idx
|
||||
window_mask = q_idx - kv_idx <= size
|
||||
|
|
@ -101,23 +101,23 @@ else:
|
|||
return sliding_window
|
||||
|
||||
@functools.lru_cache
|
||||
def create_block_mask(mask, n=128):
|
||||
def create_block_mask(mask, n = 128):
|
||||
return _create_block_mask(
|
||||
mask,
|
||||
1,
|
||||
1,
|
||||
n,
|
||||
n,
|
||||
BLOCK_SIZE=128,
|
||||
_compile=True,
|
||||
BLOCK_SIZE = 128,
|
||||
_compile = True,
|
||||
)
|
||||
|
||||
def create_flex_attention_causal_mask(max_seq_length=8192):
|
||||
def create_flex_attention_causal_mask(max_seq_length = 8192):
|
||||
causal_mask = create_block_mask(causal_masker, max_seq_length)
|
||||
return causal_mask
|
||||
|
||||
def create_flex_attention_sliding_window_mask(
|
||||
max_seq_length=8192, sliding_window=4096
|
||||
max_seq_length = 8192, sliding_window = 4096
|
||||
):
|
||||
sliding_masker = sliding_window_masker(sliding_window)
|
||||
causal_mask = create_block_mask(sliding_masker, max_seq_length)
|
||||
|
|
@ -129,9 +129,9 @@ else:
|
|||
score_mod = generate_tanh_softcap(t)
|
||||
return functools.partial(
|
||||
_flex_attention,
|
||||
score_mod=score_mod,
|
||||
scale=scale,
|
||||
enable_gqa=True,
|
||||
score_mod = score_mod,
|
||||
scale = scale,
|
||||
enable_gqa = True,
|
||||
)
|
||||
|
||||
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
||||
|
|
@ -140,7 +140,7 @@ else:
|
|||
s = self.config.query_pre_attn_scalar
|
||||
t = self.config.attn_logit_softcapping
|
||||
fx = flex_attention(s, t)
|
||||
A = fx(query=Q, key=K, value=V, block_mask=causal_mask)
|
||||
A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
A = A.reshape(bsz, q_len, n_heads * head_dim)
|
||||
return A
|
||||
|
|
@ -170,17 +170,17 @@ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len)
|
|||
s = self.config.query_pre_attn_scalar
|
||||
t = self.config.attn_logit_softcapping
|
||||
|
||||
Q = Q * torch.tensor(s**-0.5, dtype=Q.dtype) # Follow Keras exactly
|
||||
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
||||
A = torch_matmul(Q, K.transpose(2, 3))
|
||||
|
||||
# Logit softcapping
|
||||
A /= t
|
||||
torch_tanh(A, out=A)
|
||||
torch_tanh(A, out = A)
|
||||
A *= t
|
||||
A += causal_mask[:q_len, :q_len]
|
||||
# Much slower in torch compile!
|
||||
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
|
||||
A = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32).to(Q.dtype)
|
||||
A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
||||
A = torch_matmul(A, V)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
A = A.reshape(bsz, q_len, n_heads * head_dim)
|
||||
|
|
|
|||
|
|
@ -63,21 +63,21 @@ except:
|
|||
|
||||
@triton.jit
|
||||
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_n = tl.program_id(axis=1)
|
||||
pid_m = tl.program_id(axis = 0)
|
||||
pid_n = tl.program_id(axis = 1)
|
||||
n = tl.cdiv(N, BLOCK_SIZE)
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
offs = offs_m[:, None] * N + offs_n[None, :]
|
||||
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
|
||||
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
|
||||
x = tl.load(x_ptr + offs, mask = mask).to(tl.float32)
|
||||
s = tl.load(s_ptr + pid_m * n + pid_n)
|
||||
y = x * s
|
||||
tl.store(y_ptr + offs, y, mask=mask)
|
||||
tl.store(y_ptr + offs, y, mask = mask)
|
||||
|
||||
|
||||
def weight_dequant_block(
|
||||
x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16
|
||||
x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype = torch.bfloat16
|
||||
) -> torch.Tensor:
|
||||
if not x.is_contiguous():
|
||||
x = x.contiguous()
|
||||
|
|
@ -85,16 +85,16 @@ def weight_dequant_block(
|
|||
s = s.contiguous()
|
||||
assert x.dim() == 2 and s.dim() == 2
|
||||
M, N = x.size()
|
||||
y = torch.empty_like(x, dtype=dtype)
|
||||
y = torch.empty_like(x, dtype = dtype)
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(M, meta["BLOCK_SIZE"]),
|
||||
triton.cdiv(N, meta["BLOCK_SIZE"]),
|
||||
)
|
||||
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
|
||||
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)
|
||||
return y
|
||||
|
||||
|
||||
def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
|
||||
def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
|
||||
if s.shape[1] == 1:
|
||||
# this is row quantized weight, just simple multiplication suffices
|
||||
if x.shape[0] == s.shape[0]:
|
||||
|
|
@ -108,13 +108,13 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
|
|||
return y
|
||||
else:
|
||||
# this is block quantized weight
|
||||
return weight_dequant_block(x, s, dtype=dtype)
|
||||
return weight_dequant_block(x, s, dtype = dtype)
|
||||
|
||||
|
||||
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
||||
@triton.jit
|
||||
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
pid = tl.program_id(axis = 0)
|
||||
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
x = tl.load(x_ptr + offs).to(tl.float32)
|
||||
s = tl.max(tl.abs(x)) / 448.0
|
||||
|
|
@ -134,13 +134,13 @@ def act_quant(
|
|||
if not x.is_contiguous():
|
||||
x = x.contiguous()
|
||||
assert x.shape[-1] % block_size == 0
|
||||
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
||||
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
|
||||
y = torch.empty_like(x, dtype = torch.float8_e4m3fn)
|
||||
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype = torch.float32)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
|
||||
|
||||
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
|
||||
act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size)
|
||||
return y, s
|
||||
|
||||
|
||||
|
|
@ -182,7 +182,7 @@ def _w8a8_block_fp8_matmul(
|
|||
store the result in output tensor `C`.
|
||||
"""
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
pid = tl.program_id(axis = 0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
|
|
@ -202,10 +202,10 @@ def _w8a8_block_fp8_matmul(
|
|||
offs_bsn = offs_bn // group_n
|
||||
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
a = tl.load(a_ptrs, mask = offs_k[None, :] < K - k * BLOCK_SIZE_K, other = 0.0)
|
||||
b = tl.load(b_ptrs, mask = offs_k[:, None] < K - k * BLOCK_SIZE_K, other = 0.0)
|
||||
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
|
|
@ -227,7 +227,7 @@ def _w8a8_block_fp8_matmul(
|
|||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
tl.store(c_ptrs, c, mask = c_mask)
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul_triton(
|
||||
|
|
@ -267,7 +267,7 @@ def w8a8_block_fp8_matmul_triton(
|
|||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
C = A.new_empty(C_shape, dtype = output_dtype)
|
||||
|
||||
BLOCK_SIZE_M = 128
|
||||
if M < BLOCK_SIZE_M:
|
||||
|
|
@ -303,10 +303,10 @@ def w8a8_block_fp8_matmul_triton(
|
|||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
GROUP_SIZE_M=8,
|
||||
BLOCK_SIZE_M = BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N = BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K = BLOCK_SIZE_K,
|
||||
GROUP_SIZE_M = 8,
|
||||
)
|
||||
return C
|
||||
|
||||
|
|
@ -324,7 +324,7 @@ def torchao_block_matmul(
|
|||
act_scale.contiguous(),
|
||||
weight_q.contiguous(),
|
||||
weight_scale.contiguous(),
|
||||
block_size=block_size[1],
|
||||
block_size = block_size[1],
|
||||
)
|
||||
return out.to(output_dtype)
|
||||
|
||||
|
|
@ -372,7 +372,7 @@ class FP8BlockQuantLinear(torch.autograd.Function):
|
|||
scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=X.dtype,
|
||||
output_dtype = X.dtype,
|
||||
)
|
||||
ctx.weight = weight
|
||||
ctx.weight_scale = weight_scale
|
||||
|
|
@ -394,7 +394,7 @@ def fp8_torch_block_quant_forward(X, weight, weight_scale):
|
|||
|
||||
class FbgemmFp8Linear_matmul(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, weight_scale, bias=None):
|
||||
def forward(ctx, x, weight, weight_scale, bias = None):
|
||||
if weight.shape[0] == weight_scale.shape[0] and (
|
||||
weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0
|
||||
):
|
||||
|
|
@ -409,7 +409,7 @@ class FbgemmFp8Linear_matmul(torch.autograd.Function):
|
|||
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
|
||||
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
x.view(-1, x.shape[-1]).contiguous(),
|
||||
scale_ub=getattr(weight, "input_scale_ub", None),
|
||||
scale_ub = getattr(weight, "input_scale_ub", None),
|
||||
)
|
||||
# moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
|
||||
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
|
||||
|
|
@ -423,7 +423,7 @@ class FbgemmFp8Linear_matmul(torch.autograd.Function):
|
|||
weight_scale = weight_scale.contiguous()
|
||||
|
||||
output = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum=True
|
||||
x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum = True
|
||||
)
|
||||
output = output + bias if bias is not None else output
|
||||
# Hacky for now, we have the output to the device of x
|
||||
|
|
@ -459,13 +459,13 @@ class FbgemmFp8Linear_matmul(torch.autograd.Function):
|
|||
|
||||
|
||||
@torch_compile
|
||||
def fbgemm_fp8_linear(X, weight, weight_scale, bias=None):
|
||||
def fbgemm_fp8_linear(X, weight, weight_scale, bias = None):
|
||||
return FbgemmFp8Linear_matmul.apply(X, weight, weight_scale, bias)
|
||||
|
||||
|
||||
class FP8_fbgemm_block_linear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, X, weight, weight_scale, bias=None):
|
||||
def forward(ctx, X, weight, weight_scale, bias = None):
|
||||
orig_shape = X.shape
|
||||
X = X.view(-1, X.shape[-1])
|
||||
|
||||
|
|
@ -516,7 +516,7 @@ class FP8_fbgemm_block_linear(torch.autograd.Function):
|
|||
|
||||
|
||||
@torch_compile
|
||||
def fp8_fbgemm_block_linear(X, weight, weight_scale, bias=None):
|
||||
def fp8_fbgemm_block_linear(X, weight, weight_scale, bias = None):
|
||||
return FP8_fbgemm_block_linear.apply(X, weight, weight_scale, bias)
|
||||
|
||||
|
||||
|
|
@ -525,11 +525,11 @@ def test_has_fbgemm():
|
|||
# For example RTX 5090 and RTX 4090 does not work
|
||||
# [TODO] Investigate with TorchAO why FBGEMM fails on consumer GPUs
|
||||
M, N, K = 128, 128, 128
|
||||
xq = torch.ones(M, K, dtype=torch.float8_e4m3fn, device="cuda")
|
||||
xq = torch.ones(M, K, dtype = torch.float8_e4m3fn, device = "cuda")
|
||||
wq = xq
|
||||
M, K = xq.shape
|
||||
N, _ = wq.shape
|
||||
block_scale = torch.ones(M // 128, K // 128, dtype=torch.float32, device="cuda")
|
||||
block_scale = torch.ones(M // 128, K // 128, dtype = torch.float32, device = "cuda")
|
||||
has_fbgemm = False
|
||||
try:
|
||||
out = torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, block_scale, block_scale)
|
||||
|
|
@ -575,7 +575,7 @@ except:
|
|||
|
||||
|
||||
@torch_compile
|
||||
def fp8_linear(X, weight, weight_scale, bias=None):
|
||||
def fp8_linear(X, weight, weight_scale, bias = None):
|
||||
if weight_scale.ndim == 2 and weight_scale.shape[1] > 1:
|
||||
# This is block quantized FP8 matmul
|
||||
out = fp8_block_quant_linear(X, weight, weight_scale)
|
||||
|
|
@ -585,7 +585,7 @@ def fp8_linear(X, weight, weight_scale, bias=None):
|
|||
return out
|
||||
|
||||
|
||||
def module_forward_patch(forward_function, scale_attr="weight_scale"):
|
||||
def module_forward_patch(forward_function, scale_attr = "weight_scale"):
|
||||
def patched_forward(self, X):
|
||||
return forward_function(X, self.weight, getattr(self, scale_attr))
|
||||
|
||||
|
|
|
|||
|
|
@ -49,22 +49,22 @@ def _exact_forward_kernel(
|
|||
|
||||
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
||||
# h = f * up
|
||||
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
|
||||
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
|
||||
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
||||
h_row = f_row * g_row
|
||||
|
||||
# Store h
|
||||
tl.store(h + offsets, h_row, mask=mask)
|
||||
tl.store(h + offsets, h_row, mask = mask)
|
||||
|
||||
|
||||
def geglu_exact_forward_kernel(gate, up):
|
||||
batch, seq_len, hd = gate.shape
|
||||
n_elements = gate.numel()
|
||||
device = gate.device
|
||||
out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=device)
|
||||
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
with torch_gpu_device(device):
|
||||
_exact_forward_kernel[grid](
|
||||
|
|
@ -72,8 +72,8 @@ def geglu_exact_forward_kernel(gate, up):
|
|||
up,
|
||||
out,
|
||||
n_elements,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
)
|
||||
return out
|
||||
|
||||
|
|
@ -107,9 +107,9 @@ def _exact_backward_kernel(
|
|||
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
|
||||
# Break e_row away for re-use
|
||||
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
||||
|
|
@ -132,9 +132,9 @@ def _exact_backward_kernel(
|
|||
de_row = de_row.to(DW_row.dtype)
|
||||
|
||||
# Store derivatives in buffers
|
||||
tl.store(DW + offsets, h_row, mask=mask) # h = f * g
|
||||
tl.store(e + offsets, df_row, mask=mask) # df = DW * f
|
||||
tl.store(g + offsets, de_row, mask=mask) # de
|
||||
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
||||
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
||||
tl.store(g + offsets, de_row, mask = mask) # de
|
||||
|
||||
|
||||
def geglu_exact_backward_kernel(DW, e, g):
|
||||
|
|
@ -147,8 +147,8 @@ def geglu_exact_backward_kernel(DW, e, g):
|
|||
e,
|
||||
g,
|
||||
n_elements,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
)
|
||||
return DW, e, g
|
||||
|
||||
|
|
@ -177,8 +177,8 @@ def _approx_forward_kernel(
|
|||
# h = f * up
|
||||
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
||||
|
||||
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
|
||||
f_row = (
|
||||
0.5 * e_row * (triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0)
|
||||
|
|
@ -187,14 +187,14 @@ def _approx_forward_kernel(
|
|||
h_row = f_row * g_row
|
||||
|
||||
# Store h
|
||||
tl.store(h + offsets, h_row, mask=mask)
|
||||
tl.store(h + offsets, h_row, mask = mask)
|
||||
|
||||
|
||||
def geglu_approx_forward_kernel(gate, up):
|
||||
batch, seq_len, hd = gate.shape
|
||||
n_elements = gate.numel()
|
||||
device = gate.device
|
||||
out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=device)
|
||||
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
with torch_gpu_device(device):
|
||||
_approx_forward_kernel[grid](
|
||||
|
|
@ -202,8 +202,8 @@ def geglu_approx_forward_kernel(gate, up):
|
|||
up,
|
||||
out,
|
||||
n_elements,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
)
|
||||
return out
|
||||
|
||||
|
|
@ -241,9 +241,9 @@ def _approx_backward_kernel(
|
|||
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
|
||||
# See https://www.desmos.com/calculator/nqprfoni6x
|
||||
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
||||
|
|
@ -269,9 +269,9 @@ def _approx_backward_kernel(
|
|||
de_row = de_row.to(DW_row.dtype)
|
||||
|
||||
# Store derivatives in buffers
|
||||
tl.store(DW + offsets, h_row, mask=mask) # h = f * g
|
||||
tl.store(e + offsets, df_row, mask=mask) # df = DW * f
|
||||
tl.store(g + offsets, de_row, mask=mask) # de
|
||||
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
||||
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
||||
tl.store(g + offsets, de_row, mask = mask) # de
|
||||
|
||||
|
||||
def geglu_approx_backward_kernel(DW, e, g):
|
||||
|
|
@ -284,7 +284,7 @@ def geglu_approx_backward_kernel(DW, e, g):
|
|||
e,
|
||||
g,
|
||||
n_elements,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
)
|
||||
return DW, e, g
|
||||
|
|
|
|||
|
|
@ -47,19 +47,19 @@ def layernorm_forward(
|
|||
|
||||
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
||||
# are in float32!
|
||||
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
b_row = tl.load(b + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
mean_X = tl.sum(X_row, axis=0) / n_cols
|
||||
mean_X = tl.sum(X_row, axis = 0) / n_cols
|
||||
# (X[0] - mean) == -mean so we need to mask it out
|
||||
XX = tl.where(mask, X_row - mean_X, 0)
|
||||
row_var = tl.sum(XX * XX, axis=0) / n_cols
|
||||
row_var = tl.sum(XX * XX, axis = 0) / n_cols
|
||||
inv_var = tl.math.rsqrt(row_var + eps)
|
||||
tl.store(r, inv_var)
|
||||
tl.store(mu, mean_X)
|
||||
output = (XX * inv_var) * W_row + b_row
|
||||
tl.store(Y + col_offsets, output, mask=mask)
|
||||
tl.store(Y + col_offsets, output, mask = mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
|
@ -88,10 +88,10 @@ def layernorm_backward(
|
|||
|
||||
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
||||
# are in float32!
|
||||
dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
b_row = tl.load(b + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
inv_var = tl.load(r).to(tl.float32)
|
||||
mean = tl.load(mu).to(tl.float32)
|
||||
|
|
@ -99,11 +99,11 @@ def layernorm_backward(
|
|||
dY_W = dY_row * W_row
|
||||
dX_row = (
|
||||
dY_W
|
||||
- tl.sum(dY_W, axis=0) / n_cols
|
||||
- normed * tl.sum(dY_W * normed, axis=0) / n_cols
|
||||
- tl.sum(dY_W, axis = 0) / n_cols
|
||||
- normed * tl.sum(dY_W * normed, axis = 0) / n_cols
|
||||
)
|
||||
dX_row = dX_row * inv_var
|
||||
tl.store(dY + col_offsets, dX_row, mask=mask)
|
||||
tl.store(dY + col_offsets, dX_row, mask = mask)
|
||||
|
||||
|
||||
class Fast_Layernorm(torch.autograd.Function):
|
||||
|
|
@ -115,9 +115,9 @@ class Fast_Layernorm(torch.autograd.Function):
|
|||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
device = X.device
|
||||
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
|
||||
r = torch.empty(n_rows, dtype=torch.float32, device=device)
|
||||
mu = torch.empty(n_rows, dtype=torch.float32, device=device)
|
||||
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
|
||||
r = torch.empty(n_rows, dtype = torch.float32, device = device)
|
||||
mu = torch.empty(n_rows, dtype = torch.float32, device = device)
|
||||
|
||||
with torch_gpu_device(device):
|
||||
layernorm_forward[(n_rows,)](
|
||||
|
|
@ -131,8 +131,8 @@ class Fast_Layernorm(torch.autograd.Function):
|
|||
mu,
|
||||
n_cols,
|
||||
eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
ctx.eps = eps
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
|
|
@ -160,8 +160,8 @@ class Fast_Layernorm(torch.autograd.Function):
|
|||
mu,
|
||||
n_cols,
|
||||
ctx.eps,
|
||||
BLOCK_SIZE=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
dX = dY.view(*shape)
|
||||
return dX, None, None, None, None
|
||||
|
|
@ -181,26 +181,26 @@ def fast_layernorm(layernorm, X):
|
|||
|
||||
|
||||
def test_layernorm(
|
||||
dim=1024,
|
||||
eps=1e-5,
|
||||
dtype=torch.float16,
|
||||
bsz=21,
|
||||
random_state=3407,
|
||||
seqlen=3341,
|
||||
dim = 1024,
|
||||
eps = 1e-5,
|
||||
dtype = torch.float16,
|
||||
bsz = 21,
|
||||
random_state = 3407,
|
||||
seqlen = 3341,
|
||||
):
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
layernorm = LayerNorm((dim,), eps=eps, device="cuda", dtype=dtype)
|
||||
layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
|
||||
torch.cuda.manual_seed(random_state)
|
||||
torch.manual_seed(random_state)
|
||||
torch.nn.init.uniform_(layernorm.weight)
|
||||
torch.nn.init.uniform_(layernorm.bias)
|
||||
X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda")
|
||||
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
|
||||
XX = X.clone()
|
||||
X.requires_grad_(True)
|
||||
XX.requires_grad_(True)
|
||||
Y = layernorm(X)
|
||||
YY = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda", requires_grad=True)
|
||||
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
|
||||
Y.backward(YY)
|
||||
correct_grad = X.grad.clone()
|
||||
# from unsloth.kernels import fast_layernorm
|
||||
|
|
@ -212,14 +212,14 @@ def test_layernorm(
|
|||
def testing_suite_layernorm():
|
||||
for dim in [512, 1024, 2048]:
|
||||
for dtype in [torch.float16, torch.bfloat16]:
|
||||
with torch.autocast(device_type="cuda", dtype=dtype):
|
||||
with torch.autocast(device_type = "cuda", dtype = dtype):
|
||||
for seqlen in [3341, 2048, 349]:
|
||||
for random_state in [3407, 42]:
|
||||
test_layernorm(
|
||||
dim=dim,
|
||||
eps=1e-5,
|
||||
dtype=dtype,
|
||||
bsz=21,
|
||||
random_state=random_state,
|
||||
seqlen=seqlen,
|
||||
dim = dim,
|
||||
eps = 1e-5,
|
||||
dtype = dtype,
|
||||
bsz = 21,
|
||||
random_state = random_state,
|
||||
seqlen = seqlen,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def post_process_results(
|
|||
dtype: torch.dtype,
|
||||
autotune: bool,
|
||||
):
|
||||
df = KernelResult.to_dataframe(results, sort_by="speedup")
|
||||
df = KernelResult.to_dataframe(results, sort_by = "speedup")
|
||||
df = create_merged_results(df, mode, seqlen, dtype, autotune)
|
||||
return df
|
||||
|
||||
|
|
@ -63,16 +63,16 @@ def save_results(
|
|||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
print(f"Saving results to {save_path}")
|
||||
df.to_csv(save_path, index=False)
|
||||
df.to_csv(save_path, index = False)
|
||||
|
||||
|
||||
def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y: bool):
|
||||
block_m_range = power_of_two_range(args.BLOCK_SIZE_M[0], args.BLOCK_SIZE_M[1])
|
||||
block_n_range = power_of_two_range(args.BLOCK_SIZE_N[0], args.BLOCK_SIZE_N[1])
|
||||
block_k_range = power_of_two_range(args.BLOCK_SIZE_K[0], args.BLOCK_SIZE_K[1])
|
||||
num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step=2)
|
||||
num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step = 2)
|
||||
num_stages_range = multiples_of_range(
|
||||
args.num_stages[0], args.num_stages[1], step=1
|
||||
args.num_stages[0], args.num_stages[1], step = 1
|
||||
)
|
||||
|
||||
mode = args.mode
|
||||
|
|
@ -96,39 +96,39 @@ def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y:
|
|||
):
|
||||
if mode == "forward":
|
||||
kernel_config = KernelConfigForward(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
use_tma_load_w=tma_load_a,
|
||||
use_tma_load_x=tma_load_b,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
use_tma_load_w = tma_load_a,
|
||||
use_tma_load_x = tma_load_b,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
)
|
||||
elif mode == "dW":
|
||||
kernel_config = KernelConfigBackward_dW(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
use_tma_load_dy=tma_load_a,
|
||||
use_tma_load_x=tma_load_b,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
use_tma_load_dy = tma_load_a,
|
||||
use_tma_load_x = tma_load_b,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
)
|
||||
elif mode == "dX":
|
||||
kernel_config = KernelConfigBackward_dX(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
use_tma_load_dy=tma_load_a,
|
||||
use_tma_load_w=tma_load_b,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
use_tma_load_dy = tma_load_a,
|
||||
use_tma_load_w = tma_load_b,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
|
@ -161,7 +161,7 @@ def power_of_two_range(start, end):
|
|||
return [2**i for i in range(int(start), int(end) + 1)]
|
||||
|
||||
|
||||
def multiples_of_range(start, end, step=1):
|
||||
def multiples_of_range(start, end, step = 1):
|
||||
return list(range(start, end + step, step))
|
||||
|
||||
|
||||
|
|
@ -221,8 +221,8 @@ def postprocess_autotune_results(autotuner, mode, ref_time, fused_time, results_
|
|||
print(f"{mode} {key}: {value.all_kwargs()}")
|
||||
save_autotune_results(
|
||||
autotuner.cache,
|
||||
mode=mode,
|
||||
ref_time=ref_time,
|
||||
fused_time=fused_time,
|
||||
results_dir=results_dir,
|
||||
mode = mode,
|
||||
ref_time = ref_time,
|
||||
fused_time = fused_time,
|
||||
results_dir = results_dir,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def get_per_device_per_stream_alloc_fn(device):
|
|||
or _per_stream_tensors[stream].numel() < size
|
||||
):
|
||||
_per_stream_tensors[stream] = torch.empty(
|
||||
size, device=device, dtype=torch.int8
|
||||
size, device = device, dtype = torch.int8
|
||||
)
|
||||
_per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
|
||||
return _per_stream_tensors[stream]
|
||||
|
|
@ -160,7 +160,7 @@ def grouped_gemm_forward(
|
|||
if use_tma or autotune:
|
||||
|
||||
def alloc_fn(size: int, alignment: int, stream: int):
|
||||
return torch.empty(size, device="cuda", dtype=torch.int8)
|
||||
return torch.empty(size, device = "cuda", dtype = torch.int8)
|
||||
|
||||
triton.set_allocator(alloc_fn)
|
||||
|
||||
|
|
@ -211,7 +211,7 @@ def grouped_gemm_forward(
|
|||
f"DEBUG::GROUPED_GEMM {topk_weights.tolist()} {gather_indices.tolist()}"
|
||||
)
|
||||
|
||||
y = torch.empty((total_tokens, N), device=X.device, dtype=X.dtype)
|
||||
y = torch.empty((total_tokens, N), device = X.device, dtype = X.dtype)
|
||||
if total_tokens == 0 or N == 0:
|
||||
return y
|
||||
|
||||
|
|
@ -357,7 +357,7 @@ def grouped_gemm_dX(
|
|||
|
||||
def alloc_fn(size: int, alignment: int, stream: int):
|
||||
# print(f"DEBUG::GROUPED_GEMM alloc_fn {size=} {alignment=} {stream=}")
|
||||
return torch.empty(size, device="cuda", dtype=torch.int8)
|
||||
return torch.empty(size, device = "cuda", dtype = torch.int8)
|
||||
|
||||
triton.set_allocator(alloc_fn)
|
||||
|
||||
|
|
@ -383,7 +383,7 @@ def grouped_gemm_dX(
|
|||
# Note that the output shape is [NUM_TOKENS * TOPK, K] even when `permute_x` is True since we need to accumulate gradients across all experts chosen by the token.
|
||||
# This will be done in a post-processing step reduction step.
|
||||
output_shape = (total_tokens, K)
|
||||
dX = torch.zeros(output_shape, device=dY.device, dtype=dY.dtype)
|
||||
dX = torch.zeros(output_shape, device = dY.device, dtype = dY.dtype)
|
||||
|
||||
NUM_SMS = torch.cuda.get_device_properties(
|
||||
"cuda"
|
||||
|
|
@ -512,7 +512,7 @@ def grouped_gemm_dW(
|
|||
if use_tma or autotune:
|
||||
|
||||
def alloc_fn(size: int, alignment: int, stream: int):
|
||||
return torch.empty(size, device="cuda", dtype=torch.int8)
|
||||
return torch.empty(size, device = "cuda", dtype = torch.int8)
|
||||
|
||||
triton.set_allocator(alloc_fn)
|
||||
|
||||
|
|
@ -538,7 +538,7 @@ def grouped_gemm_dW(
|
|||
|
||||
assert M_grad == total_tokens, f"dY M ({M_grad}) != total_tokens ({total_tokens})"
|
||||
|
||||
dW = torch.zeros((num_experts, N, K), device=X.device, dtype=X.dtype)
|
||||
dW = torch.zeros((num_experts, N, K), device = X.device, dtype = X.dtype)
|
||||
|
||||
if not autotune:
|
||||
BLOCK_SIZE_M = min(total_tokens, BLOCK_SIZE_M)
|
||||
|
|
@ -663,17 +663,17 @@ class GroupedGemm(torch.autograd.Function):
|
|||
fwd_config["use_tma_store"] = kernel_config_fwd.use_tma_store
|
||||
|
||||
return grouped_gemm_forward(
|
||||
X=X,
|
||||
W=W,
|
||||
topk=topk,
|
||||
m_sizes=m_sizes,
|
||||
gather_indices=gather_indices,
|
||||
topk_weights=topk_weights,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
fuse_mul_post=fuse_mul_post,
|
||||
X = X,
|
||||
W = W,
|
||||
topk = topk,
|
||||
m_sizes = m_sizes,
|
||||
gather_indices = gather_indices,
|
||||
topk_weights = topk_weights,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
fuse_mul_post = fuse_mul_post,
|
||||
# Autotune -- this will override the manual kernel config if true
|
||||
autotune=autotune,
|
||||
autotune = autotune,
|
||||
# Manual kernel config
|
||||
**fwd_config,
|
||||
)
|
||||
|
|
@ -719,15 +719,15 @@ class GroupedGemm(torch.autograd.Function):
|
|||
bwd_dW_config["num_stages"] = kernel_config_bwd_dW.num_stages
|
||||
|
||||
dW = grouped_gemm_dW(
|
||||
X=X,
|
||||
dY=dY,
|
||||
m_sizes=m_sizes,
|
||||
gather_indices=gather_indices,
|
||||
topk=topk,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
X = X,
|
||||
dY = dY,
|
||||
m_sizes = m_sizes,
|
||||
gather_indices = gather_indices,
|
||||
topk = topk,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
# Autotune -- this will override the manual kernel config if true
|
||||
autotune=autotune,
|
||||
autotune = autotune,
|
||||
# Manual kernel config
|
||||
**bwd_dW_config,
|
||||
)
|
||||
|
|
@ -747,21 +747,21 @@ class GroupedGemm(torch.autograd.Function):
|
|||
bwd_dX_config["num_stages"] = kernel_config_bwd_dX.num_stages
|
||||
|
||||
dX = grouped_gemm_dX(
|
||||
dY=dY,
|
||||
W=W,
|
||||
m_sizes=m_sizes,
|
||||
gather_indices=gather_indices,
|
||||
topk=topk,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
dY = dY,
|
||||
W = W,
|
||||
m_sizes = m_sizes,
|
||||
gather_indices = gather_indices,
|
||||
topk = topk,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
# Autotune -- this will override the manual kernel config if true
|
||||
autotune=autotune,
|
||||
autotune = autotune,
|
||||
# Manual kernel config
|
||||
**bwd_dX_config,
|
||||
)
|
||||
|
||||
if topk > 1 and permute_x:
|
||||
dX = dX.view(X.shape[0], topk, -1).sum(dim=1)
|
||||
dX = dX.view(X.shape[0], topk, -1).sum(dim = 1)
|
||||
else:
|
||||
dX = None
|
||||
|
||||
|
|
@ -866,8 +866,8 @@ def grouped_gemm(
|
|||
gather_indices: torch.Tensor = None,
|
||||
permute_x: bool = False,
|
||||
permute_y: bool = False,
|
||||
topk_weights=None,
|
||||
fuse_mul_post=False,
|
||||
topk_weights = None,
|
||||
fuse_mul_post = False,
|
||||
kernel_config_fwd: KernelConfigForward = None,
|
||||
kernel_config_bwd_dX: KernelConfigBackward_dX = None,
|
||||
kernel_config_bwd_dW: KernelConfigBackward_dW = None,
|
||||
|
|
@ -908,31 +908,31 @@ def grouped_gemm(
|
|||
check_valid_config_fwd(
|
||||
permute_x,
|
||||
permute_y,
|
||||
use_tma_load_x=kernel_config_fwd.use_tma_load_x,
|
||||
use_tma_load_w=kernel_config_fwd.use_tma_load_w,
|
||||
use_tma_store=kernel_config_fwd.use_tma_store,
|
||||
fuse_mul_post=fuse_mul_post,
|
||||
is_first_gemm=is_first_gemm,
|
||||
use_tma_load_x = kernel_config_fwd.use_tma_load_x,
|
||||
use_tma_load_w = kernel_config_fwd.use_tma_load_w,
|
||||
use_tma_store = kernel_config_fwd.use_tma_store,
|
||||
fuse_mul_post = fuse_mul_post,
|
||||
is_first_gemm = is_first_gemm,
|
||||
)
|
||||
if kernel_config_bwd_dW is not None and not dX_only:
|
||||
check_valid_config_bwd_dW(
|
||||
permute_x,
|
||||
permute_y,
|
||||
use_tma_load_dY=kernel_config_bwd_dW.use_tma_load_dy,
|
||||
use_tma_load_x=kernel_config_bwd_dW.use_tma_load_x,
|
||||
use_tma_store=kernel_config_bwd_dW.use_tma_store,
|
||||
fuse_mul_post=fuse_mul_post,
|
||||
is_first_gemm=is_first_gemm,
|
||||
use_tma_load_dY = kernel_config_bwd_dW.use_tma_load_dy,
|
||||
use_tma_load_x = kernel_config_bwd_dW.use_tma_load_x,
|
||||
use_tma_store = kernel_config_bwd_dW.use_tma_store,
|
||||
fuse_mul_post = fuse_mul_post,
|
||||
is_first_gemm = is_first_gemm,
|
||||
)
|
||||
if kernel_config_bwd_dX is not None and not dW_only:
|
||||
check_valid_config_bwd_dX(
|
||||
permute_x,
|
||||
permute_y,
|
||||
use_tma_load_dY=kernel_config_bwd_dX.use_tma_load_dy,
|
||||
use_tma_load_w=kernel_config_bwd_dX.use_tma_load_w,
|
||||
use_tma_store=kernel_config_bwd_dX.use_tma_store,
|
||||
fuse_mul_post=fuse_mul_post,
|
||||
is_first_gemm=is_first_gemm,
|
||||
use_tma_load_dY = kernel_config_bwd_dX.use_tma_load_dy,
|
||||
use_tma_load_w = kernel_config_bwd_dX.use_tma_load_w,
|
||||
use_tma_store = kernel_config_bwd_dX.use_tma_store,
|
||||
fuse_mul_post = fuse_mul_post,
|
||||
is_first_gemm = is_first_gemm,
|
||||
)
|
||||
|
||||
if permute_x or permute_y:
|
||||
|
|
|
|||
|
|
@ -68,25 +68,25 @@ def _grouped_gemm_forward_kernel(
|
|||
if USE_TMA_LOAD_X:
|
||||
x_desc = tl._experimental_make_tensor_descriptor(
|
||||
x_ptr,
|
||||
shape=[TOTAL_TOKENS, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
shape = [TOTAL_TOKENS, K],
|
||||
strides = [K, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
if USE_TMA_LOAD_W:
|
||||
expert_stride = N * K
|
||||
w_desc = tl._experimental_make_tensor_descriptor(
|
||||
w_ptr,
|
||||
shape=[NUM_EXPERTS, N, K],
|
||||
strides=[expert_stride, K, 1],
|
||||
block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
shape = [NUM_EXPERTS, N, K],
|
||||
strides = [expert_stride, K, 1],
|
||||
block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
m_end = 0
|
||||
processed_tiles = 0
|
||||
m_block_range = tl.arange(0, BLOCK_SIZE_M)
|
||||
|
||||
for expert_idx in tl.range(NUM_EXPERTS, flatten=FLATTEN):
|
||||
for expert_idx in tl.range(NUM_EXPERTS, flatten = FLATTEN):
|
||||
m_start = m_end
|
||||
m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
|
||||
m_end = m_start + m_size
|
||||
|
|
@ -102,9 +102,9 @@ def _grouped_gemm_forward_kernel(
|
|||
if USE_TMA_STORE:
|
||||
y_desc = tl._experimental_make_tensor_descriptor(
|
||||
y_ptr, # + m_start * N,
|
||||
shape=[m_end, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
shape = [m_end, N],
|
||||
strides = [N, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
|
||||
# Process tiles for this expert
|
||||
|
|
@ -127,7 +127,7 @@ def _grouped_gemm_forward_kernel(
|
|||
)
|
||||
expert_token_idx = tl.load(
|
||||
gather_indices_ptr + indices_to_gather,
|
||||
mask=indices_to_gather < TOTAL_TOKENS,
|
||||
mask = indices_to_gather < TOTAL_TOKENS,
|
||||
)
|
||||
expert_token_offsets = expert_token_idx[:, None]
|
||||
|
||||
|
|
@ -178,7 +178,7 @@ def _grouped_gemm_forward_kernel(
|
|||
if SHOULD_FUSE_MUL:
|
||||
topk_load_idx = expert_token_offsets
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = acc_dtype)
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
|
|
@ -194,19 +194,19 @@ def _grouped_gemm_forward_kernel(
|
|||
|
||||
for k_offset in range(0, K, BLOCK_SIZE_K):
|
||||
if not USE_TMA_LOAD_X:
|
||||
x = tl.load(x_ptrs, mask=row_mask)
|
||||
x = tl.load(x_ptrs, mask = row_mask)
|
||||
else:
|
||||
x = x_desc.load([m_start + off_am, k_offset])
|
||||
|
||||
if FUSE_MUL_PRE:
|
||||
# Check for correct broadcasting
|
||||
topk_weights = tl.load(
|
||||
topk_weights_ptr + topk_load_idx, mask=row_mask
|
||||
topk_weights_ptr + topk_load_idx, mask = row_mask
|
||||
)
|
||||
x *= topk_weights.to(x.dtype)
|
||||
|
||||
if not USE_TMA_LOAD_W:
|
||||
w = tl.load(w_ptrs, mask=offs_bn[:, None] < N)
|
||||
w = tl.load(w_ptrs, mask = offs_bn[:, None] < N)
|
||||
else:
|
||||
w = w_desc.load(
|
||||
[expert_idx, tile_n_idx * BLOCK_SIZE_N, k_offset]
|
||||
|
|
@ -228,7 +228,7 @@ def _grouped_gemm_forward_kernel(
|
|||
if FUSE_MUL_POST:
|
||||
# Check for correct broadcasting
|
||||
topk_weights = tl.load(
|
||||
topk_weights_ptr + topk_load_idx, mask=row_mask
|
||||
topk_weights_ptr + topk_load_idx, mask = row_mask
|
||||
)
|
||||
y *= topk_weights.to(output_dtype)
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ def _grouped_gemm_forward_kernel(
|
|||
tl.store(
|
||||
y_ptr + store_idx + offs_bn[None, :],
|
||||
y,
|
||||
mask=store_mask,
|
||||
mask = store_mask,
|
||||
)
|
||||
tidx += NUM_SMS
|
||||
|
||||
|
|
@ -251,9 +251,9 @@ def _grouped_gemm_forward_kernel(
|
|||
|
||||
|
||||
_autotuned_grouped_gemm_forward_kernel = triton.autotune(
|
||||
configs=get_forward_configs(),
|
||||
prune_configs_by={"early_config_prune": prune_kernel_configs_fwd},
|
||||
key=[
|
||||
configs = get_forward_configs(),
|
||||
prune_configs_by = {"early_config_prune": prune_kernel_configs_fwd},
|
||||
key = [
|
||||
"NUM_EXPERTS",
|
||||
"NUM_TOKENS",
|
||||
"N",
|
||||
|
|
|
|||
|
|
@ -109,9 +109,9 @@ class KernelResult:
|
|||
def to_dict(self):
|
||||
return OrderedDict(
|
||||
**asdict(self.kernel_config),
|
||||
torch_time=self.torch_time,
|
||||
triton_time=self.triton_time,
|
||||
speedup=self.speedup,
|
||||
torch_time = self.torch_time,
|
||||
triton_time = self.triton_time,
|
||||
speedup = self.speedup,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -119,7 +119,7 @@ class KernelResult:
|
|||
results: list["KernelResult"], sort_by: str = "speedup", ascending: bool = False
|
||||
):
|
||||
df = pd.DataFrame([result.to_dict() for result in results])
|
||||
df = df.sort_values(by=sort_by, ascending=ascending)
|
||||
df = df.sort_values(by = sort_by, ascending = ascending)
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -130,7 +130,7 @@ class KernelResult:
|
|||
filename: str = "results.csv",
|
||||
):
|
||||
df = KernelResult.to_dataframe(results, sort_by, ascending)
|
||||
df.to_csv(filename, index=False)
|
||||
df.to_csv(filename, index = False)
|
||||
|
||||
@staticmethod
|
||||
def print_table(
|
||||
|
|
@ -140,17 +140,17 @@ class KernelResult:
|
|||
num_results: int = 10,
|
||||
):
|
||||
df = KernelResult.to_dataframe(results, sort_by, ascending)
|
||||
print(df.head(num_results).to_string(index=False))
|
||||
print(df.head(num_results).to_string(index = False))
|
||||
|
||||
|
||||
def get_kernel_configs(
|
||||
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
|
||||
num_warps=DEFAULT_NUM_WARPS,
|
||||
num_stages=DEFAULT_NUM_STAGES,
|
||||
use_tma_loads=BOOLS,
|
||||
fuse_permute=BOOLS,
|
||||
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
|
||||
num_warps = DEFAULT_NUM_WARPS,
|
||||
num_stages = DEFAULT_NUM_STAGES,
|
||||
use_tma_loads = BOOLS,
|
||||
fuse_permute = BOOLS,
|
||||
):
|
||||
kernel_configs_fwd = []
|
||||
kernel_configs_backward_dW = []
|
||||
|
|
@ -160,44 +160,44 @@ def get_kernel_configs(
|
|||
):
|
||||
kernel_configs_fwd.append(
|
||||
KernelConfigForward(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
num_warps=w,
|
||||
num_stages=s,
|
||||
use_tma_load_x=use_tma_load,
|
||||
use_tma_load_w=use_tma_load,
|
||||
use_tma_store=False,
|
||||
permute_x=permute,
|
||||
permute_y=permute,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
num_warps = w,
|
||||
num_stages = s,
|
||||
use_tma_load_x = use_tma_load,
|
||||
use_tma_load_w = use_tma_load,
|
||||
use_tma_store = False,
|
||||
permute_x = permute,
|
||||
permute_y = permute,
|
||||
)
|
||||
)
|
||||
kernel_configs_backward_dW.append(
|
||||
KernelConfigBackward_dW(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
num_warps=w,
|
||||
num_stages=s,
|
||||
use_tma_load_dy=use_tma_load,
|
||||
use_tma_load_x=use_tma_load,
|
||||
use_tma_store=False,
|
||||
permute_x=permute,
|
||||
permute_y=permute,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
num_warps = w,
|
||||
num_stages = s,
|
||||
use_tma_load_dy = use_tma_load,
|
||||
use_tma_load_x = use_tma_load,
|
||||
use_tma_store = False,
|
||||
permute_x = permute,
|
||||
permute_y = permute,
|
||||
)
|
||||
)
|
||||
kernel_configs_backward_dX.append(
|
||||
KernelConfigBackward_dX(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
num_warps=w,
|
||||
num_stages=s,
|
||||
use_tma_load_dy=use_tma_load,
|
||||
use_tma_load_w=use_tma_load,
|
||||
use_tma_store=False,
|
||||
permute_x=permute,
|
||||
permute_y=permute,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
num_warps = w,
|
||||
num_stages = s,
|
||||
use_tma_load_dy = use_tma_load,
|
||||
use_tma_load_w = use_tma_load,
|
||||
use_tma_store = False,
|
||||
permute_x = permute,
|
||||
permute_y = permute,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -78,8 +78,8 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
|
|||
self.gate = torch.nn.Parameter(gate)
|
||||
|
||||
# experts
|
||||
self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad=True)
|
||||
self.down_proj = torch.nn.Parameter(down_proj, requires_grad=True)
|
||||
self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad = True)
|
||||
self.down_proj = torch.nn.Parameter(down_proj, requires_grad = True)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -90,17 +90,17 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
|
|||
gate = moe_block.gate.weight.data
|
||||
gate_proj = torch.stack(
|
||||
[moe_block.experts[i].gate_proj.weight.data for i in range(num_experts)],
|
||||
dim=0,
|
||||
dim = 0,
|
||||
)
|
||||
up_proj = torch.stack(
|
||||
[moe_block.experts[i].up_proj.weight.data for i in range(num_experts)],
|
||||
dim=0,
|
||||
dim = 0,
|
||||
)
|
||||
down_proj = torch.stack(
|
||||
[moe_block.experts[i].down_proj.weight.data for i in range(num_experts)],
|
||||
dim=0,
|
||||
dim = 0,
|
||||
)
|
||||
gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
|
||||
gate_up_proj = torch.cat([gate_proj, up_proj], dim = 1)
|
||||
return gate, gate_up_proj, down_proj
|
||||
|
||||
@classmethod
|
||||
|
|
@ -117,7 +117,7 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
|
|||
moe_block.experts[i].gate_proj.weight.data,
|
||||
moe_block.experts[i].up_proj.weight.data,
|
||||
],
|
||||
dim=0,
|
||||
dim = 0,
|
||||
)
|
||||
)
|
||||
assert self.down_proj[i].equal(moe_block.experts[i].down_proj.weight.data)
|
||||
|
|
@ -132,12 +132,12 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
|
|||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = torch.nn.functional.linear(hidden_states, self.gate)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)
|
||||
routing_weights, selected_experts = torch.topk(
|
||||
routing_weights, self.top_k, dim=-1
|
||||
routing_weights, self.top_k, dim = -1
|
||||
)
|
||||
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
|
|
@ -177,13 +177,13 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
|
|||
|
||||
# Start expert computation
|
||||
first_gemm = torch_grouped_gemm(
|
||||
X=hidden_states, W=self.gate_up_proj, m_sizes=token_counts_by_expert
|
||||
X = hidden_states, W = self.gate_up_proj, m_sizes = token_counts_by_expert
|
||||
)
|
||||
assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
|
||||
intermediate = self.act_and_mul(first_gemm)
|
||||
assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
|
||||
second_gemm = torch_grouped_gemm(
|
||||
X=intermediate, W=self.down_proj, m_sizes=token_counts_by_expert
|
||||
X = intermediate, W = self.down_proj, m_sizes = token_counts_by_expert
|
||||
)
|
||||
assert second_gemm.shape == (total_tokens, hidden_dim)
|
||||
|
||||
|
|
@ -197,19 +197,19 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
|
|||
hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
|
||||
* routing_weights[..., None]
|
||||
)
|
||||
hidden_states = hidden_states.sum(dim=1)
|
||||
hidden_states = hidden_states.sum(dim = 1)
|
||||
assert hidden_states.shape == (num_tokens, hidden_dim)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
|
||||
return GroupedGEMMResult(
|
||||
token_counts_by_expert=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk_weights=routing_weights,
|
||||
first_gemm=first_gemm,
|
||||
intermediate=intermediate,
|
||||
second_gemm=second_gemm,
|
||||
hidden_states_unpermute=hidden_states_unpermute,
|
||||
hidden_states=hidden_states,
|
||||
token_counts_by_expert = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk_weights = routing_weights,
|
||||
first_gemm = first_gemm,
|
||||
intermediate = intermediate,
|
||||
second_gemm = second_gemm,
|
||||
hidden_states_unpermute = hidden_states_unpermute,
|
||||
hidden_states = hidden_states,
|
||||
), router_logits
|
||||
|
||||
|
||||
|
|
@ -267,14 +267,14 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
gate,
|
||||
gate_up_proj,
|
||||
down_proj,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
dW_only=dW_only,
|
||||
dX_only=dX_only,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
dW_only = dW_only,
|
||||
dX_only = dX_only,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -299,37 +299,37 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
hidden_states = permute(hidden_states, gather_indices, self.top_k)
|
||||
# Start expert computation
|
||||
hidden_states = grouped_gemm(
|
||||
X=hidden_states,
|
||||
W=self.gate_up_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=self.permute_x,
|
||||
permute_y=False, # output of first grouped gemm should never be permuted
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=True,
|
||||
dW_only=self.dW_only,
|
||||
dX_only=self.dX_only,
|
||||
X = hidden_states,
|
||||
W = self.gate_up_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = self.permute_x,
|
||||
permute_y = False, # output of first grouped gemm should never be permuted
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = True,
|
||||
dW_only = self.dW_only,
|
||||
dX_only = self.dX_only,
|
||||
)
|
||||
hidden_states = self.act_and_mul(hidden_states)
|
||||
hidden_states = grouped_gemm(
|
||||
X=hidden_states,
|
||||
W=self.down_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=False,
|
||||
permute_y=self.permute_y,
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=False,
|
||||
dW_only=self.dW_only,
|
||||
dX_only=self.dX_only,
|
||||
X = hidden_states,
|
||||
W = self.down_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = False,
|
||||
permute_y = self.permute_y,
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = False,
|
||||
dW_only = self.dW_only,
|
||||
dX_only = self.dX_only,
|
||||
)
|
||||
|
||||
# Post-processing
|
||||
|
|
@ -342,7 +342,7 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
hidden_states.view(num_tokens, self.top_k, hidden_dim)
|
||||
* routing_weights[..., None]
|
||||
)
|
||||
hidden_states = hidden_states.sum(dim=1)
|
||||
hidden_states = hidden_states.sum(dim = 1)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
|
||||
return hidden_states, router_logits
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from grouped_gemm.kernels.tuning import (
|
|||
)
|
||||
|
||||
|
||||
def print_delimiter(char="-", length=80):
|
||||
def print_delimiter(char = "-", length = 80):
|
||||
print(char * length)
|
||||
|
||||
|
||||
|
|
@ -29,28 +29,28 @@ def delimiter_context():
|
|||
print_delimiter()
|
||||
|
||||
|
||||
def make_inputs(M, N, K, E, topk, dtype, requires_grad=False):
|
||||
def make_inputs(M, N, K, E, topk, dtype, requires_grad = False):
|
||||
X1 = (
|
||||
torch.randn((M, K), device="cuda", dtype=dtype, requires_grad=requires_grad)
|
||||
torch.randn((M, K), device = "cuda", dtype = dtype, requires_grad = requires_grad)
|
||||
/ 10
|
||||
)
|
||||
X2 = (
|
||||
torch.randn(
|
||||
(M * topk, N), device="cuda", dtype=dtype, requires_grad=requires_grad
|
||||
(M * topk, N), device = "cuda", dtype = dtype, requires_grad = requires_grad
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
W1 = (
|
||||
torch.randn(
|
||||
(E, 2 * N, K), device="cuda", dtype=dtype, requires_grad=requires_grad
|
||||
(E, 2 * N, K), device = "cuda", dtype = dtype, requires_grad = requires_grad
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
W2 = (
|
||||
torch.randn((E, K, N), device="cuda", dtype=dtype, requires_grad=requires_grad)
|
||||
torch.randn((E, K, N), device = "cuda", dtype = dtype, requires_grad = requires_grad)
|
||||
/ 10
|
||||
)
|
||||
score = torch.randn((M, E), device="cuda", dtype=dtype, requires_grad=requires_grad)
|
||||
score = torch.randn((M, E), device = "cuda", dtype = dtype, requires_grad = requires_grad)
|
||||
if requires_grad:
|
||||
X1.retain_grad()
|
||||
X2.retain_grad()
|
||||
|
|
@ -60,7 +60,7 @@ def make_inputs(M, N, K, E, topk, dtype, requires_grad=False):
|
|||
return X1, X2, W1, W2, score
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@dataclass(kw_only = True)
|
||||
class DataConfig:
|
||||
seq_len: int
|
||||
dtype: torch.dtype
|
||||
|
|
@ -68,7 +68,7 @@ class DataConfig:
|
|||
bs: int = 1
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@dataclass(kw_only = True)
|
||||
class ModelConfig:
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
|
|
@ -77,13 +77,13 @@ class ModelConfig:
|
|||
use_sigmoid: bool
|
||||
renormalize: bool
|
||||
pre_mul: bool = False
|
||||
post_mul: bool = field(init=False)
|
||||
post_mul: bool = field(init = False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.post_mul = not self.pre_mul
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@dataclass(kw_only = True)
|
||||
class GroupedGEMMTestConfig:
|
||||
name: str = "test"
|
||||
data_config: DataConfig
|
||||
|
|
@ -105,7 +105,7 @@ def assert_equal(ref, tri):
|
|||
assert ref == tri, f"ref not equal to tri {ref} != {tri}"
|
||||
|
||||
|
||||
def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
|
||||
def assert_close(ref, tri, maxtol = None, rmstol = None, description = "--", verbose = True):
|
||||
if tri.dtype.itemsize == 1:
|
||||
ref_as_type = ref.to(tri.dtype)
|
||||
if ref.dtype == tri.dtype:
|
||||
|
|
@ -182,11 +182,11 @@ def assert_indx_equal(ref, tri):
|
|||
|
||||
|
||||
def get_kernel_test_configs(
|
||||
BLOCK_SIZE_M=32,
|
||||
BLOCK_SIZE_N=32,
|
||||
BLOCK_SIZE_K=32,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
BLOCK_SIZE_M = 32,
|
||||
BLOCK_SIZE_N = 32,
|
||||
BLOCK_SIZE_K = 32,
|
||||
num_warps = 4,
|
||||
num_stages = 2,
|
||||
) -> list[KernelConfig]:
|
||||
configs_fwd = []
|
||||
configs_bwd_dX = []
|
||||
|
|
@ -199,44 +199,44 @@ def get_kernel_test_configs(
|
|||
for use_tma_store in [True, False]:
|
||||
configs_fwd.append(
|
||||
KernelConfigForward(
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
use_tma_load_w=use_tma_load_w,
|
||||
use_tma_load_x=use_tma_load_x,
|
||||
use_tma_store=use_tma_store,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
BLOCK_SIZE_M = BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N = BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K = BLOCK_SIZE_K,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
use_tma_load_w = use_tma_load_w,
|
||||
use_tma_load_x = use_tma_load_x,
|
||||
use_tma_store = use_tma_store,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
)
|
||||
)
|
||||
configs_bwd_dX.append(
|
||||
KernelConfigBackward_dX(
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
use_tma_load_dy=use_tma_load_x,
|
||||
use_tma_load_w=use_tma_load_w,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
use_tma_store=use_tma_store,
|
||||
BLOCK_SIZE_M = BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N = BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K = BLOCK_SIZE_K,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
use_tma_load_dy = use_tma_load_x,
|
||||
use_tma_load_w = use_tma_load_w,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
use_tma_store = use_tma_store,
|
||||
)
|
||||
)
|
||||
configs_bwd_dW.append(
|
||||
KernelConfigBackward_dW(
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
use_tma_load_dy=use_tma_load_w,
|
||||
use_tma_load_x=use_tma_load_x,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
use_tma_store=use_tma_store,
|
||||
BLOCK_SIZE_M = BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N = BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K = BLOCK_SIZE_K,
|
||||
num_warps = num_warps,
|
||||
num_stages = num_stages,
|
||||
use_tma_load_dy = use_tma_load_w,
|
||||
use_tma_load_x = use_tma_load_x,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
use_tma_store = use_tma_store,
|
||||
)
|
||||
)
|
||||
configs_fwd = prune_kernel_configs_fwd(configs_fwd)
|
||||
|
|
@ -289,39 +289,39 @@ TEST_MODEL_SIZES = [
|
|||
|
||||
SMALL_MODEL_CONFIGS = [
|
||||
ModelConfig(
|
||||
topk=topk,
|
||||
num_experts=num_experts,
|
||||
hidden_size=model_size[0],
|
||||
intermediate_size=model_size[1],
|
||||
use_sigmoid=False,
|
||||
renormalize=False,
|
||||
topk = topk,
|
||||
num_experts = num_experts,
|
||||
hidden_size = model_size[0],
|
||||
intermediate_size = model_size[1],
|
||||
use_sigmoid = False,
|
||||
renormalize = False,
|
||||
)
|
||||
for topk, num_experts, model_size in itertools.product(
|
||||
TOPK, NUM_EXPERTS, TEST_MODEL_SIZES
|
||||
)
|
||||
]
|
||||
LLAMA_MODEL_CONFIG = ModelConfig(
|
||||
topk=1,
|
||||
num_experts=16,
|
||||
hidden_size=5120,
|
||||
intermediate_size=8192,
|
||||
use_sigmoid=True,
|
||||
renormalize=False,
|
||||
topk = 1,
|
||||
num_experts = 16,
|
||||
hidden_size = 5120,
|
||||
intermediate_size = 8192,
|
||||
use_sigmoid = True,
|
||||
renormalize = False,
|
||||
)
|
||||
QWEN_MODEL_CONFIG = ModelConfig(
|
||||
topk=8,
|
||||
num_experts=128,
|
||||
hidden_size=2048,
|
||||
intermediate_size=768,
|
||||
use_sigmoid=False,
|
||||
renormalize=False,
|
||||
topk = 8,
|
||||
num_experts = 128,
|
||||
hidden_size = 2048,
|
||||
intermediate_size = 768,
|
||||
use_sigmoid = False,
|
||||
renormalize = False,
|
||||
)
|
||||
|
||||
SEQLENS = [128, 1024]
|
||||
DTYPE = [torch.bfloat16]
|
||||
|
||||
DATA_CONFIGS = [
|
||||
DataConfig(seq_len=seq_len, dtype=dtype)
|
||||
DataConfig(seq_len = seq_len, dtype = dtype)
|
||||
for seq_len, dtype in itertools.product(SEQLENS, DTYPE)
|
||||
]
|
||||
KERNEL_CONFIGS_FWD, KERNEL_CONFIGS_BWD_dX, KERNEL_CONFIGS_BWD_dW = (
|
||||
|
|
@ -331,6 +331,6 @@ KERNEL_CONFIGS_FWD, KERNEL_CONFIGS_BWD_dX, KERNEL_CONFIGS_BWD_dW = (
|
|||
if __name__ == "__main__":
|
||||
print(
|
||||
KERNEL_CONFIGS_BWD_dX[0].to_string(
|
||||
include_tuning_params=False, include_tma=False
|
||||
include_tuning_params = False, include_tma = False
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -33,13 +33,13 @@ def rebind_experts_to_shared_buffer(
|
|||
dtype = moe_block.experts[0].down_proj.weight.dtype
|
||||
|
||||
buffer_up = torch.empty(
|
||||
num_experts, interm_size, hidden_size, device=device, dtype=dtype
|
||||
num_experts, interm_size, hidden_size, device = device, dtype = dtype
|
||||
)
|
||||
buffer_gate = torch.empty(
|
||||
num_experts, interm_size, hidden_size, device=device, dtype=dtype
|
||||
num_experts, interm_size, hidden_size, device = device, dtype = dtype
|
||||
)
|
||||
buffer_down = torch.empty(
|
||||
num_experts, hidden_size, interm_size, device=device, dtype=dtype
|
||||
num_experts, hidden_size, interm_size, device = device, dtype = dtype
|
||||
)
|
||||
|
||||
# Step 2: Copy existing expert weights into buffers
|
||||
|
|
@ -114,7 +114,7 @@ def check_down_proj_grad(
|
|||
test_grad = grouped_gemm_block.down_proj.grad[i]
|
||||
assert test_grad is not None
|
||||
diff = (ref_grad - test_grad).abs().max()
|
||||
if not torch.allclose(ref_grad, test_grad, atol=atol, rtol=rtol):
|
||||
if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
|
||||
print(f"expert {i} down_proj_grad_diff: {diff.detach().cpu().item():.6f}")
|
||||
|
||||
|
||||
|
|
@ -152,12 +152,12 @@ def check_gate_up_proj_grad(
|
|||
# Check gradients
|
||||
diff = (ref_gate_proj_grad - test_gate_proj_grad).abs().max()
|
||||
if not torch.allclose(
|
||||
ref_gate_proj_grad, test_gate_proj_grad, atol=atol, rtol=rtol
|
||||
ref_gate_proj_grad, test_gate_proj_grad, atol = atol, rtol = rtol
|
||||
):
|
||||
print(f"expert {i} gate_proj_grad_diff: {diff.detach().cpu().item():.6f}")
|
||||
diff = (ref_up_proj_grad - test_up_proj_grad).abs().max()
|
||||
if not torch.allclose(
|
||||
ref_up_proj_grad, test_up_proj_grad, atol=atol, rtol=rtol
|
||||
ref_up_proj_grad, test_up_proj_grad, atol = atol, rtol = rtol
|
||||
):
|
||||
print(f"expert {i} up_proj_grad_diff: {diff.detach().cpu().item():.6f}")
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ def check_gate_grad(
|
|||
test_grad = grouped_gemm_block.gate.grad
|
||||
assert test_grad is not None
|
||||
diff = (ref_grad - test_grad).abs().max()
|
||||
if not torch.allclose(ref_grad, test_grad, atol=atol, rtol=rtol):
|
||||
if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
|
||||
print(f"gate_grad_diff: {diff.detach().cpu().item():.6f}")
|
||||
|
||||
|
||||
|
|
@ -200,7 +200,7 @@ def check_tensor_allclose(
|
|||
if verbose:
|
||||
print(f"{name} diff: {diff.detach().cpu().item():.6f}")
|
||||
assert torch.allclose(
|
||||
X_ref, X_test, atol=atol, rtol=rtol
|
||||
X_ref, X_test, atol = atol, rtol = rtol
|
||||
), f"{name} diff: {diff.detach().cpu().item():.6f}"
|
||||
|
||||
|
||||
|
|
@ -227,7 +227,7 @@ def check_expert_grads(
|
|||
test_grad = test_grads[i]
|
||||
diff = (ref_grad - test_grad).abs().max()
|
||||
assert torch.allclose(
|
||||
ref_grad, test_grad, atol=atol, rtol=rtol
|
||||
ref_grad, test_grad, atol = atol, rtol = rtol
|
||||
), f"{field}[{i}] diff: {diff.detach().cpu().item():.6f}"
|
||||
|
||||
# Test all experts
|
||||
|
|
@ -235,7 +235,7 @@ def check_expert_grads(
|
|||
if verbose:
|
||||
print(f"{field} diff: {diff.detach().cpu().item():.6f}")
|
||||
assert torch.allclose(
|
||||
ref_grads, test_grads, atol=atol, rtol=rtol
|
||||
ref_grads, test_grads, atol = atol, rtol = rtol
|
||||
), f"{field} diff: {diff.detach().cpu().item():.6f}"
|
||||
|
||||
|
||||
|
|
@ -269,7 +269,7 @@ def check_fwd(
|
|||
if verbose:
|
||||
print(f"output diff: {diff.detach().cpu().item():.6f}")
|
||||
assert torch.allclose(
|
||||
ref_output, test_output, atol=atol, rtol=rtol
|
||||
ref_output, test_output, atol = atol, rtol = rtol
|
||||
), f"output diff: {diff.detach().cpu().item():.6f}"
|
||||
|
||||
# Check router logits
|
||||
|
|
@ -279,7 +279,7 @@ def check_fwd(
|
|||
if verbose:
|
||||
print(f"router_logits diff: {diff.detach().cpu().item():.6f}")
|
||||
assert torch.allclose(
|
||||
ref_router_logits, test_router_logits, atol=atol, rtol=rtol
|
||||
ref_router_logits, test_router_logits, atol = atol, rtol = rtol
|
||||
), f"router_logits diff: {diff.detach().cpu().item():.6f}"
|
||||
|
||||
|
||||
|
|
@ -305,7 +305,7 @@ def check_grouped_gemm_results(
|
|||
print(f"{field.name} diff: {diff.detach().cpu().item():.6f}")
|
||||
|
||||
assert torch.allclose(
|
||||
ref_value, test_value, atol=atol, rtol=rtol
|
||||
ref_value, test_value, atol = atol, rtol = rtol
|
||||
), f"{field.name} diff: {diff.detach().cpu().item():.6f}"
|
||||
|
||||
|
||||
|
|
@ -314,13 +314,13 @@ def run_forward(model: nn.Module, X: torch.Tensor, is_grouped_gemm: bool = False
|
|||
output, router_logits = model(X)
|
||||
if is_grouped_gemm:
|
||||
result = ForwardResult(
|
||||
output=output.hidden_states,
|
||||
router_logits=router_logits,
|
||||
X=X,
|
||||
grouped_gemm_result=output,
|
||||
output = output.hidden_states,
|
||||
router_logits = router_logits,
|
||||
X = X,
|
||||
grouped_gemm_result = output,
|
||||
)
|
||||
else:
|
||||
result = ForwardResult(output=output, router_logits=router_logits, X=X)
|
||||
result = ForwardResult(output = output, router_logits = router_logits, X = X)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -344,16 +344,16 @@ def run_backward(
|
|||
)
|
||||
elif isinstance(model, Qwen3MoeGroupedGEMMBlock):
|
||||
gate_grad = model.gate.grad
|
||||
gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim=1)
|
||||
gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim = 1)
|
||||
down_proj_grad = model.down_proj.grad
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
return BackwardResult(
|
||||
X_grad=X.grad,
|
||||
gate_grad=gate_grad,
|
||||
gate_proj_grad=gate_proj_grad,
|
||||
up_proj_grad=up_proj_grad,
|
||||
down_proj_grad=down_proj_grad,
|
||||
X_grad = X.grad,
|
||||
gate_grad = gate_grad,
|
||||
gate_proj_grad = gate_proj_grad,
|
||||
up_proj_grad = up_proj_grad,
|
||||
down_proj_grad = down_proj_grad,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -414,12 +414,12 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
gate,
|
||||
gate_up_proj,
|
||||
down_proj,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, debug: bool = False) -> torch.Tensor:
|
||||
|
|
@ -446,35 +446,35 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
|
||||
# Start expert computation
|
||||
first_gemm = grouped_gemm(
|
||||
X=hidden_states,
|
||||
W=self.gate_up_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=self.permute_x,
|
||||
permute_y=False, # output of first grouped gemm should never be permuted
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=True,
|
||||
X = hidden_states,
|
||||
W = self.gate_up_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = self.permute_x,
|
||||
permute_y = False, # output of first grouped gemm should never be permuted
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = True,
|
||||
)
|
||||
assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
|
||||
intermediate = self.act_and_mul(first_gemm)
|
||||
assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
|
||||
second_gemm = grouped_gemm(
|
||||
X=intermediate,
|
||||
W=self.down_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=False,
|
||||
permute_y=self.permute_y,
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=False,
|
||||
X = intermediate,
|
||||
W = self.down_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = False,
|
||||
permute_y = self.permute_y,
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = False,
|
||||
)
|
||||
assert second_gemm.shape == (total_tokens, hidden_dim)
|
||||
|
||||
|
|
@ -491,17 +491,17 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
|
||||
* routing_weights[..., None]
|
||||
)
|
||||
hidden_states = hidden_states.sum(dim=1)
|
||||
hidden_states = hidden_states.sum(dim = 1)
|
||||
assert hidden_states.shape == (num_tokens, hidden_dim)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
|
||||
return GroupedGEMMResult(
|
||||
token_counts_by_expert=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk_weights=routing_weights,
|
||||
first_gemm=first_gemm,
|
||||
intermediate=intermediate,
|
||||
second_gemm=second_gemm,
|
||||
hidden_states_unpermute=hidden_states_unpermute,
|
||||
hidden_states=hidden_states,
|
||||
token_counts_by_expert = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk_weights = routing_weights,
|
||||
first_gemm = first_gemm,
|
||||
intermediate = intermediate,
|
||||
second_gemm = second_gemm,
|
||||
hidden_states_unpermute = hidden_states_unpermute,
|
||||
hidden_states = hidden_states,
|
||||
), router_logits
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -68,18 +68,18 @@ TOLERANCES = {
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope = "module")
|
||||
def model_id():
|
||||
return "Qwen/Qwen3-30B-A3B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope = "module")
|
||||
def config(model_id: str):
|
||||
return AutoConfig.from_pretrained(model_id)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
|
||||
def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
|
||||
print(char * num_chars)
|
||||
print(prelude)
|
||||
yield
|
||||
|
|
@ -96,16 +96,16 @@ NUM_AUTOTUNE_CONFIGS = 50
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"permute_y", [True], ids=lambda x: "permute_y" if x else "no_permute_y"
|
||||
"permute_y", [True], ids = lambda x: "permute_y" if x else "no_permute_y"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"permute_x", [True], ids=lambda x: "permute_x" if x else "no_permute_x"
|
||||
"permute_x", [True], ids = lambda x: "permute_x" if x else "no_permute_x"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"autotune", [True], ids=lambda x: "autotune" if x else "manual"
|
||||
"autotune", [True], ids = lambda x: "autotune" if x else "manual"
|
||||
)
|
||||
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}")
|
||||
@pytest.mark.parametrize("dtype", DTYPES, ids=str)
|
||||
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
|
||||
@pytest.mark.parametrize("dtype", DTYPES, ids = str)
|
||||
def test_qwen3_moe(
|
||||
config: Qwen3MoeConfig,
|
||||
seqlen: int,
|
||||
|
|
@ -157,36 +157,36 @@ def test_qwen3_moe(
|
|||
# Triton kernel grouped gemm version of MoE Block -- this is what we're testing
|
||||
fused_gemm_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
|
||||
moe_block,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
).to(device, dtype)
|
||||
fused_gemm_block.check_weights(moe_block)
|
||||
|
||||
X = torch.randn(
|
||||
bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
|
||||
bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
|
||||
)
|
||||
|
||||
# Forward
|
||||
ref_result = run_forward(moe_block, X, is_grouped_gemm=False)
|
||||
grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm=True)
|
||||
fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm=True)
|
||||
ref_result = run_forward(moe_block, X, is_grouped_gemm = False)
|
||||
grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm = True)
|
||||
fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm = True)
|
||||
|
||||
with annotated_context(
|
||||
"Testing forward pass",
|
||||
epilogue="Passed forward tests!",
|
||||
char="=",
|
||||
num_chars=100,
|
||||
epilogue = "Passed forward tests!",
|
||||
char = "=",
|
||||
num_chars = 100,
|
||||
):
|
||||
# Sanity checks
|
||||
|
||||
with annotated_context(
|
||||
"Checking HF vs torch grouped gemm MoE forward outputs..."
|
||||
):
|
||||
check_fwd(ref_result, grouped_result, atol, rtol, verbose=False)
|
||||
check_fwd(ref_result, grouped_result, atol, rtol, verbose = False)
|
||||
|
||||
with annotated_context(
|
||||
"Checking torch grouped gemm MoE vs fused grouped gemm MoE forward outputs..."
|
||||
|
|
@ -195,42 +195,42 @@ def test_qwen3_moe(
|
|||
check_grouped_gemm_results(
|
||||
grouped_result.grouped_gemm_result,
|
||||
fused_result.grouped_gemm_result,
|
||||
permute_y=permute_y,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
verbose=False,
|
||||
permute_y = permute_y,
|
||||
atol = atol,
|
||||
rtol = rtol,
|
||||
verbose = False,
|
||||
)
|
||||
# Actual test
|
||||
with annotated_context(
|
||||
"Checking HF vs fused grouped gemm MoE forward outputs..."
|
||||
):
|
||||
check_fwd(ref_result, fused_result, atol, rtol, verbose=True)
|
||||
check_fwd(ref_result, fused_result, atol, rtol, verbose = True)
|
||||
|
||||
# Backward
|
||||
grad_output = torch.randn_like(ref_result.output)
|
||||
ref_backward_result = run_backward(
|
||||
moe_block, grad_output, output=ref_result.output, X=ref_result.X
|
||||
moe_block, grad_output, output = ref_result.output, X = ref_result.X
|
||||
)
|
||||
grouped_backward_result = run_backward(
|
||||
grouped_gemm_block,
|
||||
grad_output,
|
||||
output=grouped_result.output,
|
||||
X=grouped_result.X,
|
||||
output = grouped_result.output,
|
||||
X = grouped_result.X,
|
||||
)
|
||||
fused_backward_result = run_backward(
|
||||
fused_gemm_block, grad_output, output=fused_result.output, X=fused_result.X
|
||||
fused_gemm_block, grad_output, output = fused_result.output, X = fused_result.X
|
||||
)
|
||||
|
||||
with annotated_context(
|
||||
"Testing backward pass",
|
||||
epilogue="Passed backward tests!",
|
||||
char="=",
|
||||
num_chars=100,
|
||||
epilogue = "Passed backward tests!",
|
||||
char = "=",
|
||||
num_chars = 100,
|
||||
):
|
||||
# Sanity checks
|
||||
with annotated_context("Checking HF vs torch grouped gemm MoE grads..."):
|
||||
check_grads(
|
||||
ref_backward_result, grouped_backward_result, atol, rtol, verbose=False
|
||||
ref_backward_result, grouped_backward_result, atol, rtol, verbose = False
|
||||
)
|
||||
with annotated_context(
|
||||
"Checking torch grouped gemm MoE vs fused grouped gemm MoE grads..."
|
||||
|
|
@ -240,25 +240,25 @@ def test_qwen3_moe(
|
|||
fused_backward_result,
|
||||
atol,
|
||||
rtol,
|
||||
verbose=False,
|
||||
verbose = False,
|
||||
)
|
||||
|
||||
# Actual test
|
||||
with annotated_context("Checking HF vs fused grouped gemm MoE grads..."):
|
||||
check_grads(
|
||||
ref_backward_result, fused_backward_result, atol, rtol, verbose=True
|
||||
ref_backward_result, fused_backward_result, atol, rtol, verbose = True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--seqlen", type=int, default=1024)
|
||||
parser.add_argument("--seqlen", type = int, default = 1024)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
|
||||
"--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
|
||||
)
|
||||
parser.add_argument("--permute_x", action="store_true")
|
||||
parser.add_argument("--permute_y", action="store_true")
|
||||
parser.add_argument("--autotune", action="store_true")
|
||||
parser.add_argument("--permute_x", action = "store_true")
|
||||
parser.add_argument("--permute_y", action = "store_true")
|
||||
parser.add_argument("--autotune", action = "store_true")
|
||||
args = parser.parse_args()
|
||||
args.dtype = getattr(torch, args.dtype)
|
||||
args_dict = vars(args)
|
||||
|
|
|
|||
|
|
@ -54,10 +54,10 @@ except:
|
|||
CohereFlashAttention2 = CohereAttention
|
||||
|
||||
|
||||
def fast_layernorm_inference(self, X, out_weight=None):
|
||||
XX = X.to(torch.float32, copy=True)
|
||||
XX -= X.mean(-1, keepdim=True)
|
||||
variance = XX.square().mean(-1, keepdim=True)
|
||||
def fast_layernorm_inference(self, X, out_weight = None):
|
||||
XX = X.to(torch.float32, copy = True)
|
||||
XX -= X.mean(-1, keepdim = True)
|
||||
variance = XX.square().mean(-1, keepdim = True)
|
||||
variance += self.variance_epsilon
|
||||
XX *= variance.rsqrt_()
|
||||
out_weight[:] = self.weight
|
||||
|
|
@ -120,8 +120,8 @@ def CohereAttention_fast_forward(
|
|||
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim=2)
|
||||
V = torch.cat([past_key_value[1], V], dim=2)
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Attention module
|
||||
|
|
@ -143,14 +143,14 @@ def CohereAttention_fast_forward(
|
|||
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
|
||||
else:
|
||||
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
|
||||
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
|
||||
A = A.view(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
elif HAS_FLASH_ATTENTION and attention_mask is None:
|
||||
Q = Q.transpose(1, 2)
|
||||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
A = flash_attn_func(Q, K, V, causal=True)
|
||||
A = flash_attn_func(Q, K, V, causal = True)
|
||||
else:
|
||||
# Grouped query attention
|
||||
if n_groups != 1:
|
||||
|
|
@ -169,7 +169,7 @@ def CohereAttention_fast_forward(
|
|||
# Needs (batch_size, n_heads, seq_len, head_dim)
|
||||
# is_casual and attention_mask must not be both set!
|
||||
A = scaled_dot_product_attention(
|
||||
Q, K, V, attn_mask=attention_mask, is_causal=False
|
||||
Q, K, V, attn_mask = attention_mask, is_causal = False
|
||||
)
|
||||
# Go back to (batch_size, seq_len, n_heads, head_dim)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
|
|
@ -198,7 +198,7 @@ def CohereDecoderLayer_fast_forward(
|
|||
self, "_flag_for_generation"
|
||||
): # past_key_value is not None:
|
||||
out_weight = torch.empty(
|
||||
self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0"
|
||||
self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
|
||||
)
|
||||
|
||||
# Self Attention
|
||||
|
|
@ -207,14 +207,14 @@ def CohereDecoderLayer_fast_forward(
|
|||
self.input_layernorm, hidden_states, out_weight
|
||||
)
|
||||
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
|
|
@ -226,14 +226,14 @@ def CohereDecoderLayer_fast_forward(
|
|||
residual = hidden_states
|
||||
hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)
|
||||
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
|
|
@ -260,8 +260,8 @@ def CohereAttention_fast_forward_inference(
|
|||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
position_ids,
|
||||
do_prefill=False,
|
||||
attention_mask=None,
|
||||
do_prefill = False,
|
||||
attention_mask = None,
|
||||
):
|
||||
Xn = hidden_states
|
||||
bsz, _, hd = hidden_states.size()
|
||||
|
|
@ -284,45 +284,45 @@ def CohereAttention_fast_forward_inference(
|
|||
if do_prefill:
|
||||
self.paged_attention = torch.empty(
|
||||
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
|
||||
dtype=dtype,
|
||||
device="cuda:0",
|
||||
dtype = dtype,
|
||||
device = "cuda:0",
|
||||
)
|
||||
self.paged_attention_K = self.paged_attention[:, 0]
|
||||
self.paged_attention_V = self.paged_attention[:, 1]
|
||||
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
|
||||
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
|
||||
self.temp_QA = torch.empty(
|
||||
(2, bsz, 1, attention_size), dtype=dtype, device="cuda:0"
|
||||
(2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0"
|
||||
)
|
||||
self.temp_KV = torch.empty(
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device="cuda:0"
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = "cuda:0"
|
||||
)
|
||||
self.RH_Q = torch.empty(
|
||||
(bsz, n_heads, 1, head_dim), dtype=dtype, device="cuda:0"
|
||||
(bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0"
|
||||
)
|
||||
|
||||
# Mistral Nemo 12b has weird dimensions
|
||||
if attention_size != hidden_size:
|
||||
self.temp_O = torch.empty(
|
||||
(1, bsz, hidden_size), dtype=dtype, device="cuda:0"
|
||||
(1, bsz, hidden_size), dtype = dtype, device = "cuda:0"
|
||||
)
|
||||
else:
|
||||
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
|
||||
|
||||
self.attention = torch.empty(
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len),
|
||||
dtype=dtype,
|
||||
device="cuda:0",
|
||||
dtype = dtype,
|
||||
device = "cuda:0",
|
||||
)
|
||||
self.scalar = 1.0 / math_sqrt(self.head_dim)
|
||||
self.half_head_dim = head_dim // 2
|
||||
# Cohere has QK layernorms
|
||||
if self.use_qk_norm:
|
||||
self.q_norm_out_weight = torch.empty(
|
||||
self.q_norm.weight.shape, dtype=torch.float32, device="cuda:0"
|
||||
self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0"
|
||||
)
|
||||
self.k_norm_out_weight = torch.empty(
|
||||
self.k_norm.weight.shape, dtype=torch.float32, device="cuda:0"
|
||||
self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0"
|
||||
)
|
||||
else:
|
||||
self.q_norm_out_weight = None
|
||||
|
|
@ -343,9 +343,9 @@ def CohereAttention_fast_forward_inference(
|
|||
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
|
||||
)
|
||||
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
|
||||
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
|
||||
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
|
|
@ -363,7 +363,7 @@ def CohereAttention_fast_forward_inference(
|
|||
RH_Q = self.RH_Q
|
||||
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
|
||||
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
|
||||
torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
|
||||
torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
|
||||
Qn *= cos
|
||||
Qn.addcmul_(RH_Q, sin)
|
||||
|
||||
|
|
@ -372,7 +372,7 @@ def CohereAttention_fast_forward_inference(
|
|||
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
|
||||
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
|
||||
torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
|
||||
torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
|
||||
Kn *= cos
|
||||
Kn.addcmul_(RH_K, sin)
|
||||
|
||||
|
|
@ -414,20 +414,20 @@ def CohereAttention_fast_forward_inference(
|
|||
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
|
||||
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
|
||||
A = torch_matmul(
|
||||
Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
|
||||
Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
|
||||
)
|
||||
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
|
||||
A[:] = torch_nn_functional_softmax(
|
||||
A, dim=-1, dtype=torch.float32
|
||||
A, dim = -1, dtype = torch.float32
|
||||
) # .to(A.dtype)
|
||||
A = torch_matmul(A, Vnn, out=Qn)
|
||||
A = torch_matmul(A, Vnn, out = Qn)
|
||||
else:
|
||||
A = scaled_dot_product_attention(
|
||||
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=False
|
||||
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
|
||||
)
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, attention_size)
|
||||
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
|
||||
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
|
||||
return A, (Kn, Vn)
|
||||
|
||||
|
||||
|
|
@ -438,13 +438,13 @@ def CohereModel_fast_forward_inference(
|
|||
input_ids,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
attention_mask=None,
|
||||
attention_mask = None,
|
||||
):
|
||||
out_weights = tuple(
|
||||
torch.empty_like(
|
||||
self.model.layers[0].input_layernorm.weight,
|
||||
dtype=torch.float32,
|
||||
device=torch.device(x),
|
||||
dtype = torch.float32,
|
||||
device = torch.device(x),
|
||||
)
|
||||
for x in range(DEVICE_COUNT)
|
||||
)
|
||||
|
|
@ -459,7 +459,7 @@ def CohereModel_fast_forward_inference(
|
|||
(bsz, q_len),
|
||||
hidden_states,
|
||||
seq_len,
|
||||
sliding_window=getattr(self.config, "sliding_window", None),
|
||||
sliding_window = getattr(self.config, "sliding_window", None),
|
||||
)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
|
@ -477,11 +477,11 @@ def CohereModel_fast_forward_inference(
|
|||
hidden_states_attention, present_key_value = (
|
||||
CohereAttention_fast_forward_inference(
|
||||
decoder_layer.self_attn,
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=past_key_values[idx],
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
hidden_states = hidden_states,
|
||||
past_key_value = past_key_values[idx],
|
||||
position_ids = position_ids,
|
||||
attention_mask = attention_mask,
|
||||
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -496,10 +496,10 @@ def CohereModel_fast_forward_inference(
|
|||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=[],
|
||||
attentions=[],
|
||||
last_hidden_state = hidden_states,
|
||||
past_key_values = next_decoder_cache,
|
||||
hidden_states = [],
|
||||
attentions = [],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -507,10 +507,10 @@ class FastCohereModel(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="cohere",
|
||||
rope_module=LlamaRotaryEmbedding,
|
||||
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module=CohereAttention,
|
||||
model_name = "cohere",
|
||||
rope_module = LlamaRotaryEmbedding,
|
||||
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module = CohereAttention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ def FalconH1Attention_fast_forward(
|
|||
else:
|
||||
# Extend RoPE dynamically to fit in VRA
|
||||
rotary_emb = self.rotary_emb
|
||||
rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
|
||||
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
|
||||
device_index = Q.device.index
|
||||
|
||||
if position_ids is None:
|
||||
|
|
@ -132,8 +132,8 @@ def FalconH1Attention_fast_forward(
|
|||
Q, K = fast_rope_embedding(Q, K, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim=2)
|
||||
V = torch.cat([past_key_value[1], V], dim=2)
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Attention module
|
||||
|
|
@ -157,7 +157,7 @@ def FalconH1Attention_fast_forward(
|
|||
# Xformers does support the forward pass though
|
||||
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
|
||||
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
|
||||
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
|
||||
A = A.view(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
elif HAS_FLASH_ATTENTION and attention_mask is None:
|
||||
|
|
@ -166,7 +166,7 @@ def FalconH1Attention_fast_forward(
|
|||
V = V.transpose(1, 2)
|
||||
sw = kv_seq_len
|
||||
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
|
||||
A = flash_attn_func(Q, K, V, causal=True, window_size=window)
|
||||
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
|
||||
else:
|
||||
# Grouped query attention
|
||||
# if n_groups != 1:
|
||||
|
|
@ -181,7 +181,7 @@ def FalconH1Attention_fast_forward(
|
|||
# Needs (batch_size, n_heads, seq_len, head_dim)
|
||||
# is_casual and attention_mask must not be both set!
|
||||
A = scaled_dot_product_attention(
|
||||
Q, K, V, attn_mask=attention_mask, is_causal=False
|
||||
Q, K, V, attn_mask = attention_mask, is_causal = False
|
||||
)
|
||||
# Go back to (batch_size, seq_len, n_heads, head_dim)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
|
|
@ -200,8 +200,8 @@ def FalconH1Attention_fast_forward_inference(
|
|||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
position_ids,
|
||||
do_prefill=False,
|
||||
attention_mask=None,
|
||||
do_prefill = False,
|
||||
attention_mask = None,
|
||||
):
|
||||
"""
|
||||
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
|
||||
|
|
@ -253,29 +253,29 @@ def FalconH1Attention_fast_forward_inference(
|
|||
if do_prefill:
|
||||
self.paged_attention = torch.empty(
|
||||
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
dtype = dtype,
|
||||
device = device,
|
||||
)
|
||||
self.paged_attention_K = self.paged_attention[:, 0]
|
||||
self.paged_attention_V = self.paged_attention[:, 1]
|
||||
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
|
||||
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
|
||||
self.temp_QA = torch.empty(
|
||||
(2, bsz, 1, attention_size), dtype=dtype, device=device
|
||||
(2, bsz, 1, attention_size), dtype = dtype, device = device
|
||||
)
|
||||
self.temp_KV = torch.empty(
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
|
||||
)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
|
||||
|
||||
# Mistral Nemo 12b has weird dimensions
|
||||
if attention_size != hidden_size:
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
|
||||
else:
|
||||
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
|
||||
|
||||
self.attention = torch.empty(
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
|
||||
)
|
||||
self.scalar = 1.0 / math_sqrt(self.head_dim)
|
||||
self.half_head_dim = head_dim // 2
|
||||
|
|
@ -295,10 +295,10 @@ def FalconH1Attention_fast_forward_inference(
|
|||
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
|
||||
)
|
||||
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
|
||||
Kn = Kn * self.config.key_multiplier
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
|
||||
Qn = Qn.view(
|
||||
bsz, 1, n_heads, head_dim
|
||||
) # .transpose(1, 2) # we will transpose after normalisation
|
||||
|
|
@ -375,25 +375,25 @@ def FalconH1Attention_fast_forward_inference(
|
|||
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
|
||||
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
|
||||
A = torch_matmul(
|
||||
Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
|
||||
Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
|
||||
)
|
||||
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
|
||||
A[:] = torch_nn_functional_softmax(
|
||||
A, dim=-1, dtype=torch.float32
|
||||
A, dim = -1, dtype = torch.float32
|
||||
) # .to(A.dtype)
|
||||
A = torch_matmul(A, Vnn, out=Qn)
|
||||
A = torch_matmul(A, Vnn, out = Qn)
|
||||
else:
|
||||
if SDPA_HAS_GQA:
|
||||
A = scaled_dot_product_attention(
|
||||
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=False, enable_gqa=True
|
||||
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True
|
||||
)
|
||||
else:
|
||||
A = scaled_dot_product_attention(
|
||||
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=False
|
||||
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
|
||||
)
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, attention_size)
|
||||
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
|
||||
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
|
||||
return A, (Kn, Vn)
|
||||
|
||||
|
||||
|
|
@ -401,7 +401,7 @@ def FalconH1Attention_fast_forward_inference(
|
|||
def FalconH1DecoderLayer_fast_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
causal_mask=None,
|
||||
causal_mask = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
mamba_attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
|
|
@ -433,23 +433,23 @@ def FalconH1DecoderLayer_fast_forward(
|
|||
self.input_layernorm, hidden_states
|
||||
)
|
||||
attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
position_embeddings = position_embeddings,
|
||||
)
|
||||
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
|
||||
|
||||
mamba_hidden_states = self.mamba(
|
||||
hidden_states=hidden_states,
|
||||
cache_params=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=mamba_attention_mask,
|
||||
hidden_states = hidden_states,
|
||||
cache_params = past_key_value,
|
||||
cache_position = cache_position,
|
||||
attention_mask = mamba_attention_mask,
|
||||
)
|
||||
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
|
||||
|
||||
|
|
@ -469,23 +469,23 @@ def FalconH1DecoderLayer_fast_forward(
|
|||
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
|
||||
|
||||
mamba_hidden_states = self.mamba(
|
||||
hidden_states=hidden_states,
|
||||
cache_params=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=mamba_attention_mask,
|
||||
hidden_states = hidden_states,
|
||||
cache_params = past_key_value,
|
||||
cache_position = cache_position,
|
||||
attention_mask = mamba_attention_mask,
|
||||
)
|
||||
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
|
||||
|
||||
attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
position_embeddings = position_embeddings,
|
||||
)
|
||||
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
|
||||
|
||||
|
|
@ -509,8 +509,8 @@ def FalconH1DecoderLayer_fast_forward(
|
|||
|
||||
|
||||
def _FalconH1_fast_forward_inference(
|
||||
attention_fast_forward_inference=FalconH1Attention_fast_forward_inference,
|
||||
mlp_fast_forward_inference=fast_swiglu_inference,
|
||||
attention_fast_forward_inference = FalconH1Attention_fast_forward_inference,
|
||||
mlp_fast_forward_inference = fast_swiglu_inference,
|
||||
):
|
||||
# This makes the attention and MLP customisable.
|
||||
# Now for models like qwen3 or cohere which use custom attention operations, we can use this function
|
||||
|
|
@ -519,9 +519,9 @@ def _FalconH1_fast_forward_inference(
|
|||
input_ids,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
cache_position=None,
|
||||
attention_mask=None,
|
||||
mamba_attention_mask=None,
|
||||
cache_position = None,
|
||||
attention_mask = None,
|
||||
mamba_attention_mask = None,
|
||||
):
|
||||
input_ids = input_ids[:, : self.max_seq_length]
|
||||
bsz, q_len = input_ids.shape
|
||||
|
|
@ -536,11 +536,11 @@ def _FalconH1_fast_forward_inference(
|
|||
bsz, q_len, hd = X.shape
|
||||
assert q_len == 1
|
||||
# Get saved buffers to reduce memory movement
|
||||
residual = torch.empty((bsz, q_len, hd), dtype=torch.float32, device="cuda:0")
|
||||
_XX = torch.empty((2, bsz, q_len, hd), dtype=torch.float32, device="cuda:0")
|
||||
residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
|
||||
_XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
|
||||
XX, XX2 = _XX[0], _XX[1]
|
||||
variance = torch.empty((bsz, q_len, 1), dtype=torch.float32, device="cuda:0")
|
||||
temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype=X.dtype, device="cuda:0")
|
||||
variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0")
|
||||
temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
|
||||
temp_gate, temp_up = temp_mlp[0], temp_mlp[1]
|
||||
seq_len = past_key_values[0][0].shape[-2]
|
||||
if bsz != 1:
|
||||
|
|
@ -549,7 +549,7 @@ def _FalconH1_fast_forward_inference(
|
|||
(bsz, q_len),
|
||||
X,
|
||||
seq_len,
|
||||
sliding_window=getattr(self.config, "sliding_window", None),
|
||||
sliding_window = getattr(self.config, "sliding_window", None),
|
||||
)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
|
@ -561,28 +561,28 @@ def _FalconH1_fast_forward_inference(
|
|||
X = fast_rms_layernorm_inference(
|
||||
decoder_layer.input_layernorm,
|
||||
X,
|
||||
XX=XX,
|
||||
XX2=XX2,
|
||||
variance=variance,
|
||||
XX = XX,
|
||||
XX2 = XX2,
|
||||
variance = variance,
|
||||
)
|
||||
attention_hidden_states, present_key_value = (
|
||||
attention_fast_forward_inference(
|
||||
decoder_layer.self_attn,
|
||||
hidden_states=X * decoder_layer.attention_in_multiplier,
|
||||
past_key_value=past_key_values[idx],
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
hidden_states = X * decoder_layer.attention_in_multiplier,
|
||||
past_key_value = past_key_values[idx],
|
||||
position_ids = position_ids,
|
||||
attention_mask = attention_mask,
|
||||
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
)
|
||||
)
|
||||
attention_hidden_states = (
|
||||
attention_hidden_states * decoder_layer.attn_out_multiplier
|
||||
)
|
||||
mamba_hidden_states = decoder_layer.mamba(
|
||||
hidden_states=X,
|
||||
cache_params=present_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=mamba_attention_mask,
|
||||
hidden_states = X,
|
||||
cache_params = present_key_value,
|
||||
cache_position = cache_position,
|
||||
attention_mask = mamba_attention_mask,
|
||||
)
|
||||
mamba_hidden_states = mamba_hidden_states * decoder_layer.ssm_out_multiplier
|
||||
X = mamba_hidden_states + attention_hidden_states
|
||||
|
|
@ -593,17 +593,17 @@ def _FalconH1_fast_forward_inference(
|
|||
X = fast_rms_layernorm_inference(
|
||||
decoder_layer.pre_ff_layernorm,
|
||||
X,
|
||||
XX=XX,
|
||||
XX2=XX2,
|
||||
variance=variance,
|
||||
XX = XX,
|
||||
XX2 = XX2,
|
||||
variance = variance,
|
||||
)
|
||||
X = mlp_fast_forward_inference(
|
||||
decoder_layer.feed_forward,
|
||||
X,
|
||||
temp_gate=temp_gate,
|
||||
temp_up=temp_up,
|
||||
gate_multiplier=gate_multiplier,
|
||||
down_multiplier=down_multiplier,
|
||||
temp_gate = temp_gate,
|
||||
temp_up = temp_up,
|
||||
gate_multiplier = gate_multiplier,
|
||||
down_multiplier = down_multiplier,
|
||||
)
|
||||
X += residual
|
||||
|
||||
|
|
@ -611,16 +611,16 @@ def _FalconH1_fast_forward_inference(
|
|||
X = fast_rms_layernorm_inference(
|
||||
self.model.final_layernorm,
|
||||
X,
|
||||
XX=XX,
|
||||
XX2=XX2,
|
||||
variance=variance,
|
||||
XX = XX,
|
||||
XX2 = XX2,
|
||||
variance = variance,
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=X,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=[],
|
||||
attentions=[],
|
||||
last_hidden_state = X,
|
||||
past_key_values = next_decoder_cache,
|
||||
hidden_states = [],
|
||||
attentions = [],
|
||||
)
|
||||
|
||||
return FalconH1Model_fast_forward_inference_custom
|
||||
|
|
@ -630,12 +630,12 @@ def _FalconH1_fast_forward_inference(
|
|||
def _fast_prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
past_key_values = None,
|
||||
attention_mask = None,
|
||||
inputs_embeds = None,
|
||||
cache_position = None,
|
||||
position_ids = None,
|
||||
use_cache = True,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
|
||||
|
|
@ -707,10 +707,10 @@ class FastFalconH1Model(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="FalconH1",
|
||||
rope_module=LlamaRotaryEmbedding,
|
||||
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module=FalconH1Attention,
|
||||
model_name = "FalconH1",
|
||||
rope_module = LlamaRotaryEmbedding,
|
||||
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module = FalconH1Attention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
@ -738,30 +738,30 @@ class FastFalconH1Model(FastLlamaModel):
|
|||
|
||||
@staticmethod
|
||||
def from_pretrained( # TODO: Change after release
|
||||
model_name="Qwen/FalconH1-7B",
|
||||
max_seq_length=4096,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
rope_scaling=None,
|
||||
fix_tokenizer=True,
|
||||
model_patcher=None,
|
||||
tokenizer_name=None,
|
||||
trust_remote_code=False,
|
||||
model_name = "Qwen/FalconH1-7B",
|
||||
max_seq_length = 4096,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
rope_scaling = None,
|
||||
fix_tokenizer = True,
|
||||
model_patcher = None,
|
||||
tokenizer_name = None,
|
||||
trust_remote_code = False,
|
||||
**kwargs,
|
||||
):
|
||||
return FastLlamaModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
load_in_4bit=load_in_4bit,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling,
|
||||
fix_tokenizer=fix_tokenizer,
|
||||
model_patcher=FastFalconH1Model,
|
||||
tokenizer_name=tokenizer_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling,
|
||||
fix_tokenizer = fix_tokenizer,
|
||||
model_patcher = FastFalconH1Model,
|
||||
tokenizer_name = tokenizer_name,
|
||||
trust_remote_code = trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -67,11 +67,11 @@ def fast_geglu_inference(self, X):
|
|||
|
||||
gate = fast_linear_forward(self.gate_proj, X) # , out = temp[0])
|
||||
up = fast_linear_forward(self.up_proj, X) # , out = temp[1])
|
||||
gate = torch_nn_functional_gelu(gate, approximate="tanh")
|
||||
gate = torch_nn_functional_gelu(gate, approximate = "tanh")
|
||||
gate *= up
|
||||
|
||||
# X = self.down_proj(gate)
|
||||
down = fast_linear_forward(self.down_proj, gate, out=up[:, :, :hd])
|
||||
down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])
|
||||
return down
|
||||
|
||||
|
||||
|
|
@ -93,7 +93,7 @@ def GemmaDecoderLayer_fast_forward(
|
|||
self, "_flag_for_generation"
|
||||
): # past_key_value is not None:
|
||||
out_weight = torch.empty(
|
||||
self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0"
|
||||
self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
|
||||
)
|
||||
|
||||
# Self Attention
|
||||
|
|
@ -102,14 +102,14 @@ def GemmaDecoderLayer_fast_forward(
|
|||
self.input_layernorm, hidden_states, out_weight
|
||||
)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
)
|
||||
hidden_states += residual
|
||||
|
||||
|
|
@ -123,24 +123,24 @@ def GemmaDecoderLayer_fast_forward(
|
|||
else:
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(
|
||||
self.input_layernorm, hidden_states, gemma=True
|
||||
self.input_layernorm, hidden_states, gemma = True
|
||||
)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(
|
||||
self.post_attention_layernorm, hidden_states, gemma=True
|
||||
self.post_attention_layernorm, hidden_states, gemma = True
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
|
@ -163,13 +163,13 @@ def GemmaModel_fast_forward_inference(
|
|||
input_ids,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
attention_mask=None,
|
||||
attention_mask = None,
|
||||
):
|
||||
out_weights = tuple(
|
||||
torch.empty_like(
|
||||
self.model.layers[0].input_layernorm.weight,
|
||||
dtype=torch.float32,
|
||||
device=torch.device(x),
|
||||
dtype = torch.float32,
|
||||
device = torch.device(x),
|
||||
)
|
||||
for x in range(DEVICE_COUNT)
|
||||
)
|
||||
|
|
@ -179,7 +179,7 @@ def GemmaModel_fast_forward_inference(
|
|||
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
|
||||
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
|
||||
hidden_states *= torch.tensor(
|
||||
math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype
|
||||
math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
|
||||
)
|
||||
|
||||
bsz, q_len, hd = hidden_states.shape
|
||||
|
|
@ -205,11 +205,11 @@ def GemmaModel_fast_forward_inference(
|
|||
)
|
||||
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
|
||||
decoder_layer.self_attn,
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=past_key_values[idx],
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
hidden_states = hidden_states,
|
||||
past_key_value = past_key_values[idx],
|
||||
position_ids = position_ids,
|
||||
attention_mask = attention_mask,
|
||||
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
)
|
||||
hidden_states += residual
|
||||
|
||||
|
|
@ -228,10 +228,10 @@ def GemmaModel_fast_forward_inference(
|
|||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=[],
|
||||
attentions=[],
|
||||
last_hidden_state = hidden_states,
|
||||
past_key_values = next_decoder_cache,
|
||||
hidden_states = [],
|
||||
attentions = [],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -243,11 +243,11 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
|
|||
# The precision of RoPE buffers is not correct, so we cast to int64.
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
config=None, # [TODO] Hack to pass in config - need to remove later
|
||||
dim = None,
|
||||
max_position_embeddings = 2048,
|
||||
base = 10000,
|
||||
device = None,
|
||||
config = None, # [TODO] Hack to pass in config - need to remove later
|
||||
):
|
||||
super().__init__()
|
||||
if config is not None:
|
||||
|
|
@ -274,17 +274,17 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
|
|||
# Build here to make `torch.jit.trace` work.
|
||||
for device in range(DEVICE_COUNT):
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=self.current_rope_size,
|
||||
device=torch.device(device),
|
||||
dtype=torch.get_default_dtype(),
|
||||
seq_len = self.current_rope_size,
|
||||
device = torch.device(device),
|
||||
dtype = torch.get_default_dtype(),
|
||||
)
|
||||
|
||||
# dummy so that patch_utils doesn't fail for now
|
||||
self.cos_cached = torch.empty(
|
||||
1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()
|
||||
1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
|
||||
)
|
||||
self.sin_cached = torch.empty(
|
||||
1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()
|
||||
1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
|
|
@ -294,27 +294,27 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
|
|||
|
||||
# The difference is we do division explicitly instead of t * (1/x) ie we do t/x.
|
||||
freq_exponents = (2.0 / self.dim) * (
|
||||
torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
|
||||
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
|
||||
)
|
||||
timescale = self.base**freq_exponents
|
||||
positions = torch.arange(
|
||||
self.current_rope_size, device="cpu", dtype=torch.int64
|
||||
self.current_rope_size, device = "cpu", dtype = torch.int64
|
||||
).float()
|
||||
radians_new = positions[..., None] / timescale[None, None, :]
|
||||
radians_new = radians_new.squeeze(0)
|
||||
|
||||
emb = torch.cat((radians_new, radians_new), dim=-1)
|
||||
emb = torch.cat((radians_new, radians_new), dim = -1)
|
||||
# We must do RoPE in float32!
|
||||
cos = emb.cos().to(device=device, non_blocking=True) # , dtype = dtype)
|
||||
sin = emb.sin().to(device=device, non_blocking=True) # , dtype = dtype)
|
||||
cos = emb.cos().to(device = device, non_blocking = True) # , dtype = dtype)
|
||||
sin = emb.sin().to(device = device, non_blocking = True) # , dtype = dtype)
|
||||
self.multi_gpu_cos_cached[device.index] = cos
|
||||
self.multi_gpu_sin_cached[device.index] = sin
|
||||
return cos, sin
|
||||
|
||||
def forward(self, x, position_ids=None, seq_len=None):
|
||||
def forward(self, x, position_ids = None, seq_len = None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len is not None and seq_len > self.current_rope_size:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
|
||||
|
||||
device_index = x.device.index
|
||||
|
||||
|
|
@ -323,7 +323,7 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
|
|||
self.multi_gpu_sin_cached[device_index][:seq_len],
|
||||
)
|
||||
|
||||
def get_cached(self, seq_len=None, device_index=None):
|
||||
def get_cached(self, seq_len = None, device_index = None):
|
||||
if device_index is None:
|
||||
device_index = torch.cuda.current_device()
|
||||
return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
|
||||
|
|
@ -337,7 +337,7 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
|
|||
self.current_rope_size = math.ceil(seq_len / 8192) * 8192
|
||||
for device in range(DEVICE_COUNT):
|
||||
self._set_cos_sin_cache(
|
||||
self.current_rope_size, device=torch.device(device), dtype=x.dtype
|
||||
self.current_rope_size, device = torch.device(device), dtype = x.dtype
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -349,20 +349,20 @@ class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
|
|||
# The precision of RoPE buffers is not correct, so we cast to int64.
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
config=None, # [TODO] Hack to pass in config - need to remove later
|
||||
dim = None,
|
||||
max_position_embeddings = 2048,
|
||||
base = 10000,
|
||||
device = None,
|
||||
scaling_factor = 1.0,
|
||||
config = None, # [TODO] Hack to pass in config - need to remove later
|
||||
):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=base,
|
||||
device=device,
|
||||
config=config,
|
||||
dim = dim,
|
||||
max_position_embeddings = max_position_embeddings,
|
||||
base = base,
|
||||
device = device,
|
||||
config = config,
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
|
|
@ -372,20 +372,20 @@ class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
|
|||
|
||||
# The difference is we do division explicitly instead of t * (1/x) ie we do t/x.
|
||||
freq_exponents = (2.0 / self.dim) * (
|
||||
torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
|
||||
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
|
||||
)
|
||||
timescale = self.base**freq_exponents
|
||||
positions = torch.arange(
|
||||
self.current_rope_size, device="cpu", dtype=torch.int64
|
||||
self.current_rope_size, device = "cpu", dtype = torch.int64
|
||||
).float()
|
||||
positions = positions / self.scaling_factor
|
||||
radians_new = positions[..., None] / timescale[None, None, :]
|
||||
radians_new = radians_new.squeeze(0)
|
||||
|
||||
emb = torch.cat((radians_new, radians_new), dim=-1)
|
||||
emb = torch.cat((radians_new, radians_new), dim = -1)
|
||||
# We must do RoPE in float32!
|
||||
cos = emb.cos().to(device=device, non_blocking=True) # , dtype = dtype)
|
||||
sin = emb.sin().to(device=device, non_blocking=True) # , dtype = dtype)
|
||||
cos = emb.cos().to(device = device, non_blocking = True) # , dtype = dtype)
|
||||
sin = emb.sin().to(device = device, non_blocking = True) # , dtype = dtype)
|
||||
self.multi_gpu_cos_cached[device.index] = cos
|
||||
self.multi_gpu_sin_cached[device.index] = sin
|
||||
return cos, sin
|
||||
|
|
@ -395,10 +395,10 @@ class FastGemmaModel(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="gemma",
|
||||
rope_module=GemmaFixedRotaryEmbedding,
|
||||
scaled_rope_module=GemmaFixedLinearScalingRotaryEmbedding,
|
||||
attention_module=GemmaAttention,
|
||||
model_name = "gemma",
|
||||
rope_module = GemmaFixedRotaryEmbedding,
|
||||
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
|
||||
attention_module = GemmaAttention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
@ -430,7 +430,7 @@ class FastGemmaModel(FastLlamaModel):
|
|||
def post_patch(model, tokenizer):
|
||||
# Gemma does not downcast RoPE
|
||||
model, tokenizer = patch_model_and_tokenizer(
|
||||
model, tokenizer, downcast_rope=False
|
||||
model, tokenizer, downcast_rope = False
|
||||
)
|
||||
|
||||
# Add 1 to weight
|
||||
|
|
|
|||
|
|
@ -47,13 +47,13 @@ BAD_MAPPINGS = {
|
|||
|
||||
def __get_model_name(
|
||||
model_name,
|
||||
load_in_4bit=True,
|
||||
INT_TO_FLOAT_MAPPER=None,
|
||||
FLOAT_TO_INT_MAPPER=None,
|
||||
MAP_TO_UNSLOTH_16bit=None,
|
||||
load_in_fp8=False,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER=None,
|
||||
FLOAT_TO_FP8_ROW_MAPPER=None,
|
||||
load_in_4bit = True,
|
||||
INT_TO_FLOAT_MAPPER = None,
|
||||
FLOAT_TO_INT_MAPPER = None,
|
||||
MAP_TO_UNSLOTH_16bit = None,
|
||||
load_in_fp8 = False,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER = None,
|
||||
FLOAT_TO_FP8_ROW_MAPPER = None,
|
||||
):
|
||||
model_name = str(model_name)
|
||||
lower_model_name = model_name.lower()
|
||||
|
|
@ -116,7 +116,7 @@ def _get_new_mapper():
|
|||
import requests
|
||||
|
||||
new_mapper = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py"
|
||||
with requests.get(new_mapper, timeout=3) as new_mapper:
|
||||
with requests.get(new_mapper, timeout = 3) as new_mapper:
|
||||
new_mapper = new_mapper.text
|
||||
new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER") :]
|
||||
new_mapper = (
|
||||
|
|
@ -135,17 +135,17 @@ def _get_new_mapper():
|
|||
return {}, {}, {}
|
||||
|
||||
|
||||
def get_model_name(model_name, load_in_4bit=True, load_in_fp8=False):
|
||||
def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
|
||||
assert load_in_fp8 in (True, False, "block")
|
||||
new_model_name = __get_model_name(
|
||||
model_name=model_name,
|
||||
load_in_4bit=load_in_4bit,
|
||||
INT_TO_FLOAT_MAPPER=INT_TO_FLOAT_MAPPER,
|
||||
FLOAT_TO_INT_MAPPER=FLOAT_TO_INT_MAPPER,
|
||||
MAP_TO_UNSLOTH_16bit=MAP_TO_UNSLOTH_16bit,
|
||||
load_in_fp8=load_in_fp8,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER=FLOAT_TO_FP8_BLOCK_MAPPER,
|
||||
FLOAT_TO_FP8_ROW_MAPPER=FLOAT_TO_FP8_ROW_MAPPER,
|
||||
model_name = model_name,
|
||||
load_in_4bit = load_in_4bit,
|
||||
INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
|
||||
FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
|
||||
MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
|
||||
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
|
||||
)
|
||||
# In the rare case, we convert bad model names to other names
|
||||
# For eg too large dynamic quants or MoEs
|
||||
|
|
@ -166,14 +166,14 @@ def get_model_name(model_name, load_in_4bit=True, load_in_fp8=False):
|
|||
_get_new_mapper()
|
||||
)
|
||||
upgraded_model_name = __get_model_name(
|
||||
model_name=model_name,
|
||||
load_in_4bit=load_in_4bit,
|
||||
INT_TO_FLOAT_MAPPER=NEW_INT_TO_FLOAT_MAPPER,
|
||||
FLOAT_TO_INT_MAPPER=NEW_FLOAT_TO_INT_MAPPER,
|
||||
MAP_TO_UNSLOTH_16bit=NEW_MAP_TO_UNSLOTH_16bit,
|
||||
load_in_fp8=load_in_fp8,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER=FLOAT_TO_FP8_BLOCK_MAPPER,
|
||||
FLOAT_TO_FP8_ROW_MAPPER=FLOAT_TO_FP8_ROW_MAPPER,
|
||||
model_name = model_name,
|
||||
load_in_4bit = load_in_4bit,
|
||||
INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
|
||||
FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
|
||||
MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
|
||||
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
|
||||
)
|
||||
if upgraded_model_name is not None:
|
||||
raise NotImplementedError(
|
||||
|
|
@ -207,8 +207,8 @@ def _get_torchao_fp8_config(fp8_mode: str):
|
|||
raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'")
|
||||
|
||||
return Float8DynamicActivationFloat8WeightConfig(
|
||||
granularity=granularity,
|
||||
activation_value_lb=1e-12,
|
||||
granularity = granularity,
|
||||
activation_value_lb = 1e-12,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -255,12 +255,12 @@ def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
|
|||
auto_processor = AutoProcessor if is_vlm else AutoTokenizer
|
||||
model = auto_model.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=qconfig,
|
||||
torch_dtype = "auto",
|
||||
device_map = "auto",
|
||||
quantization_config = qconfig,
|
||||
)
|
||||
tokenizer = auto_processor.from_pretrained(model_name)
|
||||
model.save_pretrained(new_model_name, safe_serialization=False)
|
||||
model.save_pretrained(new_model_name, safe_serialization = False)
|
||||
del model
|
||||
for _ in range(2):
|
||||
torch.cuda.empty_cache()
|
||||
|
|
@ -276,8 +276,8 @@ def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):
|
|||
try:
|
||||
base_config = _get_torchao_fp8_config(fp8_mode)
|
||||
model.torchao_config = TorchAOConfig(
|
||||
qat_scheme=None,
|
||||
base_config_and_filter_fns=[(base_config, None)],
|
||||
qat_scheme = None,
|
||||
base_config_and_filter_fns = [(base_config, None)],
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ def MistralAttention_fast_forward(
|
|||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
# Extend RoPE dynamically to fit in VRAM
|
||||
self.rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
|
||||
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
|
||||
|
||||
cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)
|
||||
if position_ids is None:
|
||||
|
|
@ -91,8 +91,8 @@ def MistralAttention_fast_forward(
|
|||
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim=2)
|
||||
V = torch.cat([past_key_value[1], V], dim=2)
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Attention module
|
||||
|
|
@ -128,7 +128,7 @@ def MistralAttention_fast_forward(
|
|||
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
|
||||
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
|
||||
|
||||
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
|
||||
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
|
||||
A = A.view(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
elif HAS_FLASH_ATTENTION and attention_mask is None:
|
||||
|
|
@ -138,7 +138,7 @@ def MistralAttention_fast_forward(
|
|||
sw = getattr(self.config, "sliding_window", None)
|
||||
sw = kv_seq_len if (sw is None or sw == "null") else sw
|
||||
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
|
||||
A = flash_attn_func(Q, K, V, causal=True, window_size=window)
|
||||
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
|
||||
else:
|
||||
# Grouped query attention
|
||||
# if n_groups != 1:
|
||||
|
|
@ -153,7 +153,7 @@ def MistralAttention_fast_forward(
|
|||
# Needs (batch_size, n_heads, seq_len, head_dim)
|
||||
# is_casual and attention_mask must not be both set!
|
||||
A = scaled_dot_product_attention(
|
||||
Q, K, V, attn_mask=attention_mask, is_causal=False
|
||||
Q, K, V, attn_mask = attention_mask, is_causal = False
|
||||
)
|
||||
# Go back to (batch_size, seq_len, n_heads, head_dim)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
|
|
@ -199,7 +199,7 @@ def MistralForCausalLM_fast_forward(
|
|||
else:
|
||||
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||
[q_len] * bsz
|
||||
).make_local_attention(window_size=sliding_window)
|
||||
).make_local_attention(window_size = sliding_window)
|
||||
|
||||
# If attention_mask exists, it will be handled in the attention forward
|
||||
|
||||
|
|
@ -213,13 +213,13 @@ def MistralForCausalLM_fast_forward(
|
|||
):
|
||||
# Fully causal mask
|
||||
causal_mask_values = torch.triu(
|
||||
torch.full((q_len, q_len), -torch.inf, device=input_ids.device),
|
||||
diagonal=1,
|
||||
torch.full((q_len, q_len), -torch.inf, device = input_ids.device),
|
||||
diagonal = 1,
|
||||
)
|
||||
else:
|
||||
# Sliding window attention
|
||||
q_indices = torch.arange(q_len, device=input_ids.device).view(-1, 1)
|
||||
k_indices = torch.arange(q_len, device=input_ids.device).view(1, -1)
|
||||
q_indices = torch.arange(q_len, device = input_ids.device).view(-1, 1)
|
||||
k_indices = torch.arange(q_len, device = input_ids.device).view(1, -1)
|
||||
|
||||
causal_bool_mask = k_indices <= q_indices
|
||||
window_bool_mask = (q_indices - k_indices) < sliding_window
|
||||
|
|
@ -243,7 +243,7 @@ def MistralForCausalLM_fast_forward(
|
|||
attention_mask = attention_mask + causal_mask_values[None, None, :, :]
|
||||
|
||||
attention_mask = attention_mask.to(
|
||||
dtype=_get_dtype(dtype_from_config(self.config))
|
||||
dtype = _get_dtype(dtype_from_config(self.config))
|
||||
)
|
||||
|
||||
output_attentions = (
|
||||
|
|
@ -268,21 +268,21 @@ def MistralForCausalLM_fast_forward(
|
|||
self,
|
||||
input_ids,
|
||||
past_key_values,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids = position_ids,
|
||||
attention_mask = attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
input_ids = input_ids,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_values = past_key_values,
|
||||
inputs_embeds = inputs_embeds,
|
||||
use_cache = use_cache,
|
||||
output_attentions = output_attentions,
|
||||
output_hidden_states = output_hidden_states,
|
||||
return_dict = return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
|
@ -302,11 +302,11 @@ def MistralForCausalLM_fast_forward(
|
|||
if num_logits_to_keep != 0:
|
||||
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
|
||||
return CausalLMOutputWithPast(
|
||||
loss=None,
|
||||
logits=hidden_states,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
loss = None,
|
||||
logits = hidden_states,
|
||||
past_key_values = outputs.past_key_values,
|
||||
hidden_states = outputs.hidden_states,
|
||||
attentions = outputs.attentions,
|
||||
)
|
||||
|
||||
if bsz == 1 and q_len == 1:
|
||||
|
|
@ -337,28 +337,28 @@ def MistralForCausalLM_fast_forward(
|
|||
# logit_softcapping = logit_softcapping,
|
||||
# )
|
||||
loss = unsloth_fused_ce_loss(
|
||||
trainer=None,
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=lm_head,
|
||||
lm_head_bias=None,
|
||||
labels=labels,
|
||||
mask=None,
|
||||
n_items=n_items,
|
||||
scaling=getattr(self, "accelerator_scaler", None),
|
||||
target_gb=None,
|
||||
torch_compile=True,
|
||||
logit_softcapping=logit_softcapping,
|
||||
trainer = None,
|
||||
hidden_states = hidden_states,
|
||||
lm_head_weight = lm_head,
|
||||
lm_head_bias = None,
|
||||
labels = labels,
|
||||
mask = None,
|
||||
n_items = n_items,
|
||||
scaling = getattr(self, "accelerator_scaler", None),
|
||||
target_gb = None,
|
||||
torch_compile = True,
|
||||
logit_softcapping = logit_softcapping,
|
||||
)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
output = CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=EMPTY_LOGITS,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
loss = loss,
|
||||
logits = EMPTY_LOGITS,
|
||||
past_key_values = outputs.past_key_values,
|
||||
hidden_states = outputs.hidden_states,
|
||||
attentions = outputs.attentions,
|
||||
)
|
||||
return output
|
||||
pass
|
||||
|
|
@ -377,9 +377,9 @@ def MistralForCausalLM_fast_forward(
|
|||
shift_labels[..., :-1] = labels[..., 1:]
|
||||
shift_labels[..., -1] = -100
|
||||
loss = fast_cross_entropy_loss(
|
||||
logits=shift_logits,
|
||||
labels=shift_labels,
|
||||
n_items=kwargs.get("num_items_in_batch", None)
|
||||
logits = shift_logits,
|
||||
labels = shift_labels,
|
||||
n_items = kwargs.get("num_items_in_batch", None)
|
||||
or kwargs.get("n_items", None),
|
||||
)
|
||||
|
||||
|
|
@ -388,11 +388,11 @@ def MistralForCausalLM_fast_forward(
|
|||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
loss = loss,
|
||||
logits = logits,
|
||||
past_key_values = outputs.past_key_values,
|
||||
hidden_states = outputs.hidden_states,
|
||||
attentions = outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -417,10 +417,10 @@ class FastMistralModel(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="mistral",
|
||||
rope_module=LlamaRotaryEmbedding,
|
||||
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module=MistralAttention,
|
||||
model_name = "mistral",
|
||||
rope_module = LlamaRotaryEmbedding,
|
||||
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module = MistralAttention,
|
||||
)
|
||||
# Just for Mistral Nemo models!
|
||||
if function is not None and init_name is not None:
|
||||
|
|
@ -451,30 +451,30 @@ class FastMistralModel(FastLlamaModel):
|
|||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_name="unsloth/mistral-7b-bnb-4bit",
|
||||
max_seq_length=None,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
rope_scaling=None, # Mistral does not support RoPE scaling
|
||||
fix_tokenizer=True,
|
||||
model_patcher=None,
|
||||
tokenizer_name=None,
|
||||
trust_remote_code=False,
|
||||
model_name = "unsloth/mistral-7b-bnb-4bit",
|
||||
max_seq_length = None,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
rope_scaling = None, # Mistral does not support RoPE scaling
|
||||
fix_tokenizer = True,
|
||||
model_patcher = None,
|
||||
tokenizer_name = None,
|
||||
trust_remote_code = False,
|
||||
**kwargs,
|
||||
):
|
||||
return FastLlamaModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
load_in_4bit=load_in_4bit,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling,
|
||||
fix_tokenizer=fix_tokenizer,
|
||||
model_patcher=FastMistralModel,
|
||||
tokenizer_name=tokenizer_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling,
|
||||
fix_tokenizer = fix_tokenizer,
|
||||
model_patcher = FastMistralModel,
|
||||
tokenizer_name = tokenizer_name,
|
||||
trust_remote_code = trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -59,9 +59,7 @@ def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = Non
|
|||
self.gate_proj, X, out = temp_gate
|
||||
) # pretty much the only change from transformers implementation.
|
||||
|
||||
routing_weights = torch_nn_functional_softmax(
|
||||
router_logits, dim = -1, dtype = torch.float32
|
||||
)
|
||||
routing_weights = torch_nn_functional_softmax(router_logits, dim = -1, dtype = torch.float32)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
|
||||
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
|
||||
# we cast back to the input dtype
|
||||
|
|
|
|||
|
|
@ -329,7 +329,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
try:
|
||||
trainer = eval(f"trl.trainer.{trainer_file}")
|
||||
except Exception as error:
|
||||
print(f"Unsloth: Could not import trl.trainer.{trainer_file}: {error}")
|
||||
return
|
||||
|
||||
# Get SFTTrainer and SFTConfig names
|
||||
|
|
@ -348,14 +347,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
and trainer_file.split("_")[0] in x.lower()
|
||||
]
|
||||
if len(name) != 1:
|
||||
print(
|
||||
f"Unsloth: Could not find Trainer class in trl.trainer.{trainer_file}. Found: {name}"
|
||||
)
|
||||
return
|
||||
if len(config) != 1:
|
||||
print(
|
||||
f"Unsloth: Could not find Config class in trl.trainer.{trainer_file}. Found: {config}"
|
||||
)
|
||||
return
|
||||
|
||||
# Get SFTTrainer, SFTConfig
|
||||
|
|
@ -364,24 +357,16 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
try:
|
||||
RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
|
||||
except:
|
||||
print(
|
||||
f"Unsloth: Could not load {RLTrainer_name} from trl.trainer.{trainer_file}"
|
||||
)
|
||||
return
|
||||
try:
|
||||
RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}")
|
||||
except:
|
||||
print(
|
||||
f"Unsloth: Could not load {RLConfig_name} from trl.trainer.{trainer_file}"
|
||||
)
|
||||
return
|
||||
|
||||
# Check name
|
||||
if RLTrainer.__name__.startswith("Unsloth"):
|
||||
print(f"Unsloth: {RLTrainer.__name__} is already patched.")
|
||||
return
|
||||
if RLConfig.__name__.startswith("Unsloth"):
|
||||
print(f"Unsloth: {RLConfig.__name__} is already patched.")
|
||||
return
|
||||
|
||||
# Get old source
|
||||
|
|
@ -1306,11 +1291,7 @@ def patch_trl_rl_trainers():
|
|||
import trl.trainer
|
||||
|
||||
all_trainers = dir(trl.trainer)
|
||||
all_trainers = [
|
||||
x
|
||||
for x in all_trainers
|
||||
if x.islower() and x.endswith("_trainer") and x != "base_trainer"
|
||||
]
|
||||
all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")]
|
||||
for trainer in all_trainers:
|
||||
_patch_trl_rl_trainers(trainer)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ def sft_trainer_prepare_dataset(function_name, function):
|
|||
matched = re.match(
|
||||
r"[\s]{0,}def _prepare_dataset\(.*?" + params + r".*?\)",
|
||||
function,
|
||||
flags=re.MULTILINE | re.DOTALL,
|
||||
flags = re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
if matched:
|
||||
# Use fast version!
|
||||
|
|
@ -147,7 +147,7 @@ def sft_trainer_prepare_dataset(function_name, function):
|
|||
replacer = re.findall(
|
||||
r"def " + function_name + r"\(.*?\).*?\:\n",
|
||||
function,
|
||||
flags=re.MULTILINE | re.DOTALL,
|
||||
flags = re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
if len(replacer) != 0:
|
||||
replacer = replacer[0]
|
||||
|
|
@ -175,13 +175,13 @@ def sft_trainer_compute_loss(function_name, function):
|
|||
return function
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
self, model, inputs, return_outputs = False, num_items_in_batch = None
|
||||
):
|
||||
outputs = super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
return_outputs = return_outputs,
|
||||
num_items_in_batch = num_items_in_batch,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
|
@ -286,7 +286,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
|
|||
r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"
|
||||
r"\2if self\.use_vllm:)",
|
||||
function,
|
||||
flags=re.DOTALL | re.MULTILINE,
|
||||
flags = re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
if len(found) != 0:
|
||||
replace_part, spacing = found[0]
|
||||
|
|
@ -441,7 +441,7 @@ def grpo_trainer__get_per_token_logps(function_name, function):
|
|||
return function
|
||||
|
||||
def _get_per_token_logps(
|
||||
self, model, input_ids, attention_mask, logits_to_keep, compute_efficient=False
|
||||
self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False
|
||||
):
|
||||
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
|
||||
return None # Unsloth efficient GRPO
|
||||
|
|
@ -456,12 +456,12 @@ def grpo_trainer__get_per_token_logps(function_name, function):
|
|||
self._autocast_dtype = torch.float16
|
||||
|
||||
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
|
||||
with torch.amp.autocast(device_type=DEVICE_TYPE, dtype=self._autocast_dtype):
|
||||
with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = self._autocast_dtype):
|
||||
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
logits_to_keep=logits_to_keep + 1,
|
||||
input_ids = input_ids,
|
||||
attention_mask = attention_mask,
|
||||
logits_to_keep = logits_to_keep + 1,
|
||||
).logits
|
||||
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
return logits
|
||||
|
|
@ -500,9 +500,9 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
|
|||
input_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
batch_size=None,
|
||||
compute_entropy=False,
|
||||
compute_efficient=False,
|
||||
batch_size = None,
|
||||
compute_entropy = False,
|
||||
compute_efficient = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
|
@ -533,33 +533,33 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
|
|||
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
|
||||
|
||||
unwrapped_model = self.accelerator.unwrap_model(
|
||||
model, keep_fp32_wrapper=False
|
||||
model, keep_fp32_wrapper = False
|
||||
)
|
||||
|
||||
with torch.amp.autocast(device_type="cuda", dtype=self._autocast_dtype):
|
||||
with torch.amp.autocast(device_type = "cuda", dtype = self._autocast_dtype):
|
||||
with _get_inference_mode_context_manager(model):
|
||||
if pixel_values is None:
|
||||
attention_mask = input_ids != self.processing_class.pad_token_id
|
||||
attention_mask = attention_mask.to(attention_mask.dtype)
|
||||
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
||||
logits = unwrapped_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
pixel_attention_mask=pixel_attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
input_ids = input_ids,
|
||||
attention_mask = attention_mask,
|
||||
pixel_values = pixel_values,
|
||||
image_grid_thw = image_grid_thw,
|
||||
pixel_attention_mask = pixel_attention_mask,
|
||||
image_sizes = image_sizes,
|
||||
# logits_to_keep = logits_to_keep + 1,
|
||||
).logits
|
||||
else:
|
||||
logits = unwrapped_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
pixel_attention_mask=pixel_attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
logits_to_keep=logits_to_keep + 1,
|
||||
input_ids = input_ids,
|
||||
attention_mask = attention_mask,
|
||||
pixel_values = pixel_values,
|
||||
image_grid_thw = image_grid_thw,
|
||||
pixel_attention_mask = pixel_attention_mask,
|
||||
image_sizes = image_sizes,
|
||||
logits_to_keep = logits_to_keep + 1,
|
||||
).logits
|
||||
|
||||
entropies = None
|
||||
|
|
@ -615,7 +615,7 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
return function
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
self, model, inputs, return_outputs = False, num_items_in_batch = None
|
||||
):
|
||||
if return_outputs:
|
||||
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||
|
|
@ -639,9 +639,9 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
current_gradient_accumulation_steps = self.current_gradient_accumulation_steps
|
||||
num_processes = self.accelerator.num_processes
|
||||
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim = 1)
|
||||
bsz, qlen = input_ids.shape
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1)
|
||||
# attention_mask = None
|
||||
logits_to_keep = completion_ids.size(
|
||||
1
|
||||
|
|
@ -654,9 +654,9 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
input_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
batch_size=None,
|
||||
compute_entropy=False,
|
||||
compute_efficient=False: self._get_per_token_logps(
|
||||
batch_size = None,
|
||||
compute_entropy = False,
|
||||
compute_efficient = False: self._get_per_token_logps(
|
||||
model, input_ids, attention_mask, logits_to_keep, compute_efficient
|
||||
)
|
||||
if hasattr(self, "_get_per_token_logps")
|
||||
|
|
@ -672,7 +672,7 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
) # logps
|
||||
|
||||
per_token_logps = get_logps_func(
|
||||
model, input_ids, attention_mask, logits_to_keep, compute_efficient=True
|
||||
model, input_ids, attention_mask, logits_to_keep, compute_efficient = True
|
||||
)
|
||||
# Compute the KL divergence between the model and the reference model
|
||||
# _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
|
||||
|
|
@ -726,71 +726,71 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
completion_mask,
|
||||
self.beta,
|
||||
advantages,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
loss_type=self.args.loss_type,
|
||||
importance_sampling_level=self.importance_sampling_level,
|
||||
epsilon_low=self.epsilon_low,
|
||||
epsilon_high=self.epsilon_high,
|
||||
max_completion_length=self.args.max_completion_length,
|
||||
delta=self.args.delta,
|
||||
temperature=self.args.temperature,
|
||||
logit_softcapping=logit_softcapping,
|
||||
logit_scale_multiply=logit_scale_multiply,
|
||||
logit_scale_divide=logit_scale_divide,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
current_gradient_accumulation_steps=current_gradient_accumulation_steps,
|
||||
num_processes=num_processes,
|
||||
sampling_per_token_logps=sampling_per_token_logps,
|
||||
pixel_values = pixel_values,
|
||||
image_grid_thw = image_grid_thw,
|
||||
loss_type = self.args.loss_type,
|
||||
importance_sampling_level = self.importance_sampling_level,
|
||||
epsilon_low = self.epsilon_low,
|
||||
epsilon_high = self.epsilon_high,
|
||||
max_completion_length = self.args.max_completion_length,
|
||||
delta = self.args.delta,
|
||||
temperature = self.args.temperature,
|
||||
logit_softcapping = logit_softcapping,
|
||||
logit_scale_multiply = logit_scale_multiply,
|
||||
logit_scale_divide = logit_scale_divide,
|
||||
num_items_in_batch = num_items_in_batch,
|
||||
current_gradient_accumulation_steps = current_gradient_accumulation_steps,
|
||||
num_processes = num_processes,
|
||||
sampling_per_token_logps = sampling_per_token_logps,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if hasattr(self.args, "loss_type"):
|
||||
loss, completion_length, mean_kl, delta, flat_is_ratio = (
|
||||
grpo_accumulated_loss(
|
||||
trainer=self,
|
||||
input_ids=_input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
logits_to_keep=logits_to_keep,
|
||||
completion_mask=completion_mask,
|
||||
advantages=advantages,
|
||||
old_hidden_states=old_hidden_states,
|
||||
ref_hidden_states=ref_hidden_states,
|
||||
n_chunks=self.args.unsloth_num_chunks,
|
||||
loss_type=self.args.loss_type,
|
||||
importance_sampling_level=self.importance_sampling_level,
|
||||
epsilon_low=self.epsilon_low,
|
||||
epsilon_high=self.epsilon_high,
|
||||
max_completion_length=self.args.max_completion_length,
|
||||
delta=self.args.delta,
|
||||
temperature=self.args.temperature,
|
||||
logit_softcapping=logit_softcapping,
|
||||
logit_scale_multiply=logit_scale_multiply,
|
||||
logit_scale_divide=logit_scale_divide,
|
||||
attention_mask=attention_mask,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
current_gradient_accumulation_steps=current_gradient_accumulation_steps,
|
||||
num_processes=num_processes,
|
||||
sampling_per_token_logps=sampling_per_token_logps,
|
||||
trainer = self,
|
||||
input_ids = _input_ids,
|
||||
pixel_values = pixel_values,
|
||||
image_grid_thw = image_grid_thw,
|
||||
logits_to_keep = logits_to_keep,
|
||||
completion_mask = completion_mask,
|
||||
advantages = advantages,
|
||||
old_hidden_states = old_hidden_states,
|
||||
ref_hidden_states = ref_hidden_states,
|
||||
n_chunks = self.args.unsloth_num_chunks,
|
||||
loss_type = self.args.loss_type,
|
||||
importance_sampling_level = self.importance_sampling_level,
|
||||
epsilon_low = self.epsilon_low,
|
||||
epsilon_high = self.epsilon_high,
|
||||
max_completion_length = self.args.max_completion_length,
|
||||
delta = self.args.delta,
|
||||
temperature = self.args.temperature,
|
||||
logit_softcapping = logit_softcapping,
|
||||
logit_scale_multiply = logit_scale_multiply,
|
||||
logit_scale_divide = logit_scale_divide,
|
||||
attention_mask = attention_mask,
|
||||
num_items_in_batch = num_items_in_batch,
|
||||
current_gradient_accumulation_steps = current_gradient_accumulation_steps,
|
||||
num_processes = num_processes,
|
||||
sampling_per_token_logps = sampling_per_token_logps,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
|
||||
loss, completion_length, mean_kl = grpo_accumulated_loss(
|
||||
trainer=self,
|
||||
input_ids=_input_ids,
|
||||
logits_to_keep=logits_to_keep,
|
||||
completion_mask=completion_mask,
|
||||
advantages=advantages,
|
||||
old_hidden_states=old_hidden_states,
|
||||
ref_hidden_states=ref_hidden_states,
|
||||
n_chunks=self.args.unsloth_num_chunks,
|
||||
temperature=self.args.temperature,
|
||||
logit_softcapping=logit_softcapping,
|
||||
logit_scale_multiply=logit_scale_multiply,
|
||||
logit_scale_divide=logit_scale_divide,
|
||||
attention_mask=attention_mask,
|
||||
trainer = self,
|
||||
input_ids = _input_ids,
|
||||
logits_to_keep = logits_to_keep,
|
||||
completion_mask = completion_mask,
|
||||
advantages = advantages,
|
||||
old_hidden_states = old_hidden_states,
|
||||
ref_hidden_states = ref_hidden_states,
|
||||
n_chunks = self.args.unsloth_num_chunks,
|
||||
temperature = self.args.temperature,
|
||||
logit_softcapping = logit_softcapping,
|
||||
logit_scale_multiply = logit_scale_multiply,
|
||||
logit_scale_divide = logit_scale_divide,
|
||||
attention_mask = attention_mask,
|
||||
)
|
||||
|
||||
if "train" in self._metrics:
|
||||
|
|
@ -805,12 +805,12 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
mean_delta = (
|
||||
torch.mean(delta)
|
||||
if delta.numel() > 0
|
||||
else torch.tensor(0.0, device=self.model.device)
|
||||
else torch.tensor(0.0, device = self.model.device)
|
||||
)
|
||||
max_delta = (
|
||||
torch.max(delta)
|
||||
if delta.numel() > 0
|
||||
else torch.tensor(0.0, device=self.model.device)
|
||||
else torch.tensor(0.0, device = self.model.device)
|
||||
)
|
||||
self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
|
||||
self.accelerator.gather(mean_delta).mean().item()
|
||||
|
|
@ -822,17 +822,17 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
min_importance_sampling_ratio = (
|
||||
torch.min(flat_is_ratio)
|
||||
if flat_is_ratio.numel() > 0
|
||||
else torch.tensor(0.0, device=self.model.device)
|
||||
else torch.tensor(0.0, device = self.model.device)
|
||||
)
|
||||
mean_importance_sampling_ratio = (
|
||||
torch.mean(flat_is_ratio)
|
||||
if flat_is_ratio.numel() > 0
|
||||
else torch.tensor(0.0, device=self.model.device)
|
||||
else torch.tensor(0.0, device = self.model.device)
|
||||
)
|
||||
max_importance_sampling_ratio = (
|
||||
torch.max(flat_is_ratio)
|
||||
if flat_is_ratio.numel() > 0
|
||||
else torch.tensor(0.0, device=self.model.device)
|
||||
else torch.tensor(0.0, device = self.model.device)
|
||||
)
|
||||
self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
|
||||
nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
|
||||
|
|
|
|||
|
|
@ -115,9 +115,9 @@ HAS_TORCH_DTYPE = "torch_dtype" in PretrainedConfig.__doc__
|
|||
from transformers import GenerationConfig, CompileConfig, HybridCache
|
||||
|
||||
_compile_config = CompileConfig(
|
||||
fullgraph=False,
|
||||
dynamic=None,
|
||||
mode="reduce-overhead",
|
||||
fullgraph = False,
|
||||
dynamic = None,
|
||||
mode = "reduce-overhead",
|
||||
)
|
||||
_compile_config.disable = True # Must set manually
|
||||
|
||||
|
|
@ -219,10 +219,10 @@ def unsloth_base_fast_generate(
|
|||
|
||||
# Mixed precision autocast
|
||||
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
|
||||
autocaster = torch.autocast(device_type=DEVICE_TYPE_TORCH, dtype=torch.float16)
|
||||
autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = torch.float16)
|
||||
dtype = torch.float16
|
||||
else:
|
||||
autocaster = torch.autocast(device_type=DEVICE_TYPE_TORCH, dtype=dtype)
|
||||
autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype)
|
||||
# Prepare LoRA
|
||||
# state_dict = convert_lora_modules(self, dtype = dtype)
|
||||
|
||||
|
|
@ -316,34 +316,34 @@ def unsloth_base_fast_generate(
|
|||
class FastBaseModel:
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_name="unsloth/Llama-3.2-1B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
load_in_16bit=False,
|
||||
full_finetuning=False,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
trust_remote_code=False,
|
||||
model_types=None,
|
||||
tokenizer_name=None,
|
||||
auto_model=AutoModelForVision2Seq,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
supports_sdpa=True,
|
||||
whisper_language=None,
|
||||
whisper_task=None,
|
||||
auto_config=None,
|
||||
offload_embedding=False,
|
||||
float32_mixed_precision=None, # Forces float32 mixed precision
|
||||
model_name = "unsloth/Llama-3.2-1B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
load_in_16bit = False,
|
||||
full_finetuning = False,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
trust_remote_code = False,
|
||||
model_types = None,
|
||||
tokenizer_name = None,
|
||||
auto_model = AutoModelForVision2Seq,
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
supports_sdpa = True,
|
||||
whisper_language = None,
|
||||
whisper_task = None,
|
||||
auto_config = None,
|
||||
offload_embedding = False,
|
||||
float32_mixed_precision = None, # Forces float32 mixed precision
|
||||
# vLLM parameters
|
||||
fast_inference=False,
|
||||
gpu_memory_utilization=0.5,
|
||||
float8_kv_cache=False,
|
||||
random_state=3407,
|
||||
max_lora_rank=64,
|
||||
disable_log_stats=False,
|
||||
unsloth_vllm_standby=False,
|
||||
fast_inference = False,
|
||||
gpu_memory_utilization = 0.5,
|
||||
float8_kv_cache = False,
|
||||
random_state = 3407,
|
||||
max_lora_rank = 64,
|
||||
disable_log_stats = False,
|
||||
unsloth_vllm_standby = False,
|
||||
**kwargs,
|
||||
):
|
||||
if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") != "1":
|
||||
|
|
@ -539,16 +539,16 @@ class FastBaseModel:
|
|||
)
|
||||
if load_in_4bit:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=bnb_compute_dtype,
|
||||
llm_int8_skip_modules=SKIP_QUANTIZATION_MODULES.copy(),
|
||||
load_in_4bit = True,
|
||||
bnb_4bit_use_double_quant = True,
|
||||
bnb_4bit_quant_type = "nf4",
|
||||
bnb_4bit_compute_dtype = bnb_compute_dtype,
|
||||
llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
|
||||
)
|
||||
elif load_in_8bit:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_skip_modules=SKIP_QUANTIZATION_MODULES.copy(),
|
||||
load_in_8bit = True,
|
||||
llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
|
||||
)
|
||||
elif load_in_16bit:
|
||||
bnb_config = None
|
||||
|
|
@ -597,8 +597,8 @@ class FastBaseModel:
|
|||
if auto_config is None:
|
||||
auto_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
if hasattr(auto_config, "quantization_config"):
|
||||
from transformers.quantizers.auto import (
|
||||
|
|
@ -648,9 +648,9 @@ class FastBaseModel:
|
|||
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
attn_implementation="sdpa" if supports_sdpa else "eager",
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
attn_implementation = "sdpa" if supports_sdpa else "eager",
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
verify_fp8_support_if_applicable(model_config)
|
||||
|
||||
|
|
@ -660,11 +660,11 @@ class FastBaseModel:
|
|||
load_in_fp8 = kwargs.pop("load_in_fp8", None)
|
||||
model = auto_model.from_pretrained(
|
||||
model_name,
|
||||
device_map=device_map,
|
||||
device_map = device_map,
|
||||
# torch_dtype = torch_dtype, # Transformers removed torch_dtype
|
||||
# quantization_config = bnb_config,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
# attn_implementation = attn_implementation,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
@ -715,18 +715,18 @@ class FastBaseModel:
|
|||
|
||||
allowed_args = inspect.getfullargspec(load_vllm).args
|
||||
load_vllm_kwargs = dict(
|
||||
model_name=model_name,
|
||||
config=model_config,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
float8_kv_cache=float8_kv_cache,
|
||||
enable_lora=vllm_enable_lora,
|
||||
max_lora_rank=max_lora_rank,
|
||||
disable_log_stats=disable_log_stats,
|
||||
use_bitsandbytes=load_in_4bit,
|
||||
unsloth_vllm_standby=unsloth_vllm_standby,
|
||||
is_vision_model=is_vlm,
|
||||
model_name = model_name,
|
||||
config = model_config,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
float8_kv_cache = float8_kv_cache,
|
||||
enable_lora = vllm_enable_lora,
|
||||
max_lora_rank = max_lora_rank,
|
||||
disable_log_stats = disable_log_stats,
|
||||
use_bitsandbytes = load_in_4bit,
|
||||
unsloth_vllm_standby = unsloth_vllm_standby,
|
||||
is_vision_model = is_vlm,
|
||||
)
|
||||
for allowed_arg in allowed_args:
|
||||
if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
|
||||
|
|
@ -738,15 +738,15 @@ class FastBaseModel:
|
|||
# Convert to HF format
|
||||
_, quant_state_dict = get_vllm_state_dict(
|
||||
llm,
|
||||
config=model_config,
|
||||
is_vision_model=is_vlm,
|
||||
config = model_config,
|
||||
is_vision_model = is_vlm,
|
||||
)
|
||||
model = convert_vllm_to_huggingface(
|
||||
quant_state_dict,
|
||||
model_config,
|
||||
dtype,
|
||||
bnb_config,
|
||||
is_vision_model=is_vlm,
|
||||
is_vision_model = is_vlm,
|
||||
)
|
||||
model.vllm_engine = llm
|
||||
model.fast_generate = model.vllm_engine.generate
|
||||
|
|
@ -788,26 +788,26 @@ class FastBaseModel:
|
|||
):
|
||||
tokenizer = auto_processor.from_pretrained(
|
||||
tokenizer_name,
|
||||
padding_side="left",
|
||||
token=token,
|
||||
language=whisper_language,
|
||||
task=whisper_task,
|
||||
trust_remote_code=trust_remote_code,
|
||||
padding_side = "left",
|
||||
token = token,
|
||||
language = whisper_language,
|
||||
task = whisper_task,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
tokenizer = auto_processor.from_pretrained(
|
||||
tokenizer_name,
|
||||
padding_side="left",
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
padding_side = "left",
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
except:
|
||||
tokenizer = get_auto_processor(
|
||||
tokenizer_name,
|
||||
padding_side="left",
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
padding_side = "left",
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
if hasattr(tokenizer, "tokenizer"):
|
||||
__tokenizer = tokenizer.tokenizer
|
||||
|
|
@ -827,10 +827,10 @@ class FastBaseModel:
|
|||
model, tokenizer = patch_model_and_tokenizer(
|
||||
model,
|
||||
tokenizer,
|
||||
downcast_rope=False,
|
||||
fix_embeddings=False,
|
||||
do_forced_float32=do_forced_float32,
|
||||
correct_dtype=correct_dtype,
|
||||
downcast_rope = False,
|
||||
fix_embeddings = False,
|
||||
do_forced_float32 = do_forced_float32,
|
||||
correct_dtype = correct_dtype,
|
||||
)
|
||||
model, tokenizer = patch_tokenizer(model, tokenizer)
|
||||
model = post_patch_loss_function(model)
|
||||
|
|
@ -838,13 +838,13 @@ class FastBaseModel:
|
|||
# Log Unsloth version for future fastpaths for inference
|
||||
if hasattr(model, "config"):
|
||||
model.config.update({"unsloth_version": __version__})
|
||||
patch_saving_functions(model, vision=True)
|
||||
patch_saving_functions(model, vision = True)
|
||||
if tokenizer is None:
|
||||
del model
|
||||
raise RuntimeError(
|
||||
"Unsloth: The tokenizer is weirdly not loaded? Please check if there is one."
|
||||
)
|
||||
patch_saving_functions(tokenizer, vision=True)
|
||||
patch_saving_functions(tokenizer, vision = True)
|
||||
|
||||
# Fix gradient accumulation
|
||||
from transformers.trainer import Trainer
|
||||
|
|
@ -882,11 +882,11 @@ class FastBaseModel:
|
|||
# Post patches
|
||||
model = FastBaseModel.post_patch_model(
|
||||
model,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_type=model_type_arch,
|
||||
tokenizer=tokenizer,
|
||||
float32_mixed_precision=float32_mixed_precision,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
trust_remote_code = trust_remote_code,
|
||||
model_type = model_type_arch,
|
||||
tokenizer = tokenizer,
|
||||
float32_mixed_precision = float32_mixed_precision,
|
||||
)
|
||||
# Clear deleted GPU items
|
||||
for _ in range(3):
|
||||
|
|
@ -900,27 +900,27 @@ class FastBaseModel:
|
|||
@staticmethod
|
||||
def get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=None,
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.0,
|
||||
bias="none",
|
||||
finetune_vision_layers=True,
|
||||
finetune_language_layers=True,
|
||||
finetune_attention_modules=True,
|
||||
finetune_mlp_modules=True,
|
||||
layers_to_transform=None,
|
||||
layers_pattern=None,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
max_seq_length=2048, # not used anymore
|
||||
use_rslora=False,
|
||||
modules_to_save=None,
|
||||
init_lora_weights=True,
|
||||
loftq_config={},
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
temporary_location="_unsloth_temporary_saved_buffers",
|
||||
qat_scheme=None,
|
||||
r = 16,
|
||||
target_modules = None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0.0,
|
||||
bias = "none",
|
||||
finetune_vision_layers = True,
|
||||
finetune_language_layers = True,
|
||||
finetune_attention_modules = True,
|
||||
finetune_mlp_modules = True,
|
||||
layers_to_transform = None,
|
||||
layers_pattern = None,
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
max_seq_length = 2048, # not used anymore
|
||||
use_rslora = False,
|
||||
modules_to_save = None,
|
||||
init_lora_weights = True,
|
||||
loftq_config = {},
|
||||
task_type = TaskType.CAUSAL_LM,
|
||||
temporary_location = "_unsloth_temporary_saved_buffers",
|
||||
qat_scheme = None,
|
||||
**kwargs,
|
||||
):
|
||||
if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
|
||||
|
|
@ -948,10 +948,10 @@ class FastBaseModel:
|
|||
if target_modules is None or target_modules == "all-linear":
|
||||
target_modules = get_peft_regex(
|
||||
model,
|
||||
finetune_vision_layers=finetune_vision_layers,
|
||||
finetune_language_layers=finetune_language_layers,
|
||||
finetune_attention_modules=finetune_attention_modules,
|
||||
finetune_mlp_modules=finetune_mlp_modules,
|
||||
finetune_vision_layers = finetune_vision_layers,
|
||||
finetune_language_layers = finetune_language_layers,
|
||||
finetune_attention_modules = finetune_attention_modules,
|
||||
finetune_mlp_modules = finetune_mlp_modules,
|
||||
)
|
||||
else:
|
||||
assert type(target_modules) in (
|
||||
|
|
@ -1013,7 +1013,7 @@ class FastBaseModel:
|
|||
)
|
||||
model = prepare_model_for_kbit_training(
|
||||
model,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
)
|
||||
model = _get_peft_model(model, lora_config)
|
||||
# Apply QAT + LoRA if specified
|
||||
|
|
@ -1027,8 +1027,8 @@ class FastBaseModel:
|
|||
trust_remote_code = getattr(model, "_unsloth_trust_remote_code", False)
|
||||
model = FastBaseModel.post_patch_model(
|
||||
model,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
model.max_seq_length = max_seq_length
|
||||
# Save to modules as well
|
||||
|
|
@ -1041,7 +1041,7 @@ class FastBaseModel:
|
|||
torch.cuda.empty_cache()
|
||||
elif DEVICE_TYPE == "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
patch_saving_functions(model, vision=True)
|
||||
patch_saving_functions(model, vision = True)
|
||||
patch_peft_fast_inference(model)
|
||||
|
||||
# Add for_inference and for_training
|
||||
|
|
@ -1057,11 +1057,11 @@ class FastBaseModel:
|
|||
@staticmethod
|
||||
def post_patch_model(
|
||||
model,
|
||||
use_gradient_checkpointing=True,
|
||||
trust_remote_code=False,
|
||||
model_type=None,
|
||||
tokenizer=None,
|
||||
float32_mixed_precision=None,
|
||||
use_gradient_checkpointing = True,
|
||||
trust_remote_code = False,
|
||||
model_type = None,
|
||||
tokenizer = None,
|
||||
float32_mixed_precision = None,
|
||||
):
|
||||
full_finetuning = os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1"
|
||||
|
||||
|
|
@ -1079,14 +1079,14 @@ class FastBaseModel:
|
|||
|
||||
model = prepare_model_for_training(
|
||||
model,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_reentrant=True,
|
||||
full_finetuning=full_finetuning,
|
||||
train_layernorms=full_finetuning,
|
||||
train_embedding=full_finetuning,
|
||||
train_lm_head=full_finetuning,
|
||||
float32_mixed_precision=float32_mixed_precision,
|
||||
patch_modules_to_save=True,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
use_reentrant = True,
|
||||
full_finetuning = full_finetuning,
|
||||
train_layernorms = full_finetuning,
|
||||
train_embedding = full_finetuning,
|
||||
train_lm_head = full_finetuning,
|
||||
float32_mixed_precision = float32_mixed_precision,
|
||||
patch_modules_to_save = True,
|
||||
)
|
||||
|
||||
from transformers.trainer import Trainer
|
||||
|
|
@ -1096,7 +1096,7 @@ class FastBaseModel:
|
|||
and trust_remote_code == False
|
||||
):
|
||||
raise RuntimeError("Unsloth: Unsuccessfully patched inner_training_loop")
|
||||
patch_saving_functions(model, vision=True)
|
||||
patch_saving_functions(model, vision = True)
|
||||
|
||||
# Patch tokenizer to pad to the left
|
||||
m = model
|
||||
|
|
@ -1194,11 +1194,11 @@ class FastBaseModel:
|
|||
os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
|
||||
# Turn off skip guards and set stance to default
|
||||
if torch_compiler_set_stance is not None:
|
||||
torch_compiler_set_stance(stance="default", skip_guard_eval_unsafe=False)
|
||||
torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def for_training(model, use_gradient_checkpointing=True):
|
||||
def for_training(model, use_gradient_checkpointing = True):
|
||||
if not hasattr(model, "parameters"):
|
||||
raise TypeError(
|
||||
"Unsloth: I think you're passing a tokenizer, not the model to for_training!"
|
||||
|
|
@ -1250,5 +1250,5 @@ class FastBaseModel:
|
|||
os.environ["UNSLOTH_RETURN_LOGITS"] = "0"
|
||||
# Turn off skip guards and set stance to default
|
||||
if torch_compiler_set_stance is not None:
|
||||
torch_compiler_set_stance(stance="default", skip_guard_eval_unsafe=False)
|
||||
torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -25,50 +25,50 @@ class LlamaVisionModelInfo(ModelInfo):
|
|||
|
||||
# Llama 3.1
|
||||
LlamaMeta_3_1 = ModelMeta(
|
||||
org="meta-llama",
|
||||
base_name="Llama",
|
||||
instruct_tags=[None, "Instruct"],
|
||||
model_version="3.1",
|
||||
model_sizes=["8"],
|
||||
model_info_cls=LlamaModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
org = "meta-llama",
|
||||
base_name = "Llama",
|
||||
instruct_tags = [None, "Instruct"],
|
||||
model_version = "3.1",
|
||||
model_sizes = ["8"],
|
||||
model_info_cls = LlamaModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
)
|
||||
|
||||
# Llama 3.2 Base Models
|
||||
LlamaMeta_3_2_Base = ModelMeta(
|
||||
org="meta-llama",
|
||||
base_name="Llama",
|
||||
instruct_tags=[None],
|
||||
model_version="3.2",
|
||||
model_sizes=["1", "3"],
|
||||
model_info_cls=LlamaModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
org = "meta-llama",
|
||||
base_name = "Llama",
|
||||
instruct_tags = [None],
|
||||
model_version = "3.2",
|
||||
model_sizes = ["1", "3"],
|
||||
model_info_cls = LlamaModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
)
|
||||
|
||||
# Llama 3.2 Instruction Tuned Models
|
||||
LlamaMeta_3_2_Instruct = ModelMeta(
|
||||
org="meta-llama",
|
||||
base_name="Llama",
|
||||
instruct_tags=["Instruct"],
|
||||
model_version="3.2",
|
||||
model_sizes=["1", "3"],
|
||||
model_info_cls=LlamaModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
org = "meta-llama",
|
||||
base_name = "Llama",
|
||||
instruct_tags = ["Instruct"],
|
||||
model_version = "3.2",
|
||||
model_sizes = ["1", "3"],
|
||||
model_info_cls = LlamaModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
)
|
||||
|
||||
# Llama 3.2 Vision
|
||||
LlamaMeta_3_2_Vision = ModelMeta(
|
||||
org="meta-llama",
|
||||
base_name="Llama",
|
||||
instruct_tags=[None, "Instruct"],
|
||||
model_version="3.2",
|
||||
model_sizes=["11", "90"],
|
||||
model_info_cls=LlamaVisionModelInfo,
|
||||
is_multimodal=True,
|
||||
quant_types={
|
||||
org = "meta-llama",
|
||||
base_name = "Llama",
|
||||
instruct_tags = [None, "Instruct"],
|
||||
model_version = "3.2",
|
||||
model_sizes = ["11", "90"],
|
||||
model_info_cls = LlamaVisionModelInfo,
|
||||
is_multimodal = True,
|
||||
quant_types = {
|
||||
"11": [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
"90": [QuantType.NONE],
|
||||
},
|
||||
|
|
@ -79,7 +79,7 @@ def register_llama_3_1_models(include_original_model: bool = False):
|
|||
global _IS_LLAMA_3_1_REGISTERED
|
||||
if _IS_LLAMA_3_1_REGISTERED:
|
||||
return
|
||||
_register_models(LlamaMeta_3_1, include_original_model=include_original_model)
|
||||
_register_models(LlamaMeta_3_1, include_original_model = include_original_model)
|
||||
_IS_LLAMA_3_1_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -87,9 +87,9 @@ def register_llama_3_2_models(include_original_model: bool = False):
|
|||
global _IS_LLAMA_3_2_REGISTERED
|
||||
if _IS_LLAMA_3_2_REGISTERED:
|
||||
return
|
||||
_register_models(LlamaMeta_3_2_Base, include_original_model=include_original_model)
|
||||
_register_models(LlamaMeta_3_2_Base, include_original_model = include_original_model)
|
||||
_register_models(
|
||||
LlamaMeta_3_2_Instruct, include_original_model=include_original_model
|
||||
LlamaMeta_3_2_Instruct, include_original_model = include_original_model
|
||||
)
|
||||
_IS_LLAMA_3_2_REGISTERED = True
|
||||
|
||||
|
|
@ -99,15 +99,15 @@ def register_llama_3_2_vision_models(include_original_model: bool = False):
|
|||
if _IS_LLAMA_3_2_VISION_REGISTERED:
|
||||
return
|
||||
_register_models(
|
||||
LlamaMeta_3_2_Vision, include_original_model=include_original_model
|
||||
LlamaMeta_3_2_Vision, include_original_model = include_original_model
|
||||
)
|
||||
_IS_LLAMA_3_2_VISION_REGISTERED = True
|
||||
|
||||
|
||||
def register_llama_models(include_original_model: bool = False):
|
||||
register_llama_3_1_models(include_original_model=include_original_model)
|
||||
register_llama_3_2_models(include_original_model=include_original_model)
|
||||
register_llama_3_2_vision_models(include_original_model=include_original_model)
|
||||
register_llama_3_1_models(include_original_model = include_original_model)
|
||||
register_llama_3_2_models(include_original_model = include_original_model)
|
||||
register_llama_3_2_vision_models(include_original_model = include_original_model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -115,7 +115,7 @@ if __name__ == "__main__":
|
|||
|
||||
MODEL_REGISTRY.clear()
|
||||
|
||||
register_llama_models(include_original_model=True)
|
||||
register_llama_models(include_original_model = True)
|
||||
|
||||
for model_id, model_info in MODEL_REGISTRY.items():
|
||||
model_info = _check_model_info(model_id)
|
||||
|
|
|
|||
|
|
@ -15,26 +15,26 @@ class PhiModelInfo(ModelInfo):
|
|||
|
||||
# Phi Model Meta
|
||||
PhiMeta4 = ModelMeta(
|
||||
org="microsoft",
|
||||
base_name="phi",
|
||||
instruct_tags=[None],
|
||||
model_version="4",
|
||||
model_sizes=["1"], # Assuming only one size
|
||||
model_info_cls=PhiModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
org = "microsoft",
|
||||
base_name = "phi",
|
||||
instruct_tags = [None],
|
||||
model_version = "4",
|
||||
model_sizes = ["1"], # Assuming only one size
|
||||
model_info_cls = PhiModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
)
|
||||
|
||||
# Phi Instruct Model Meta
|
||||
PhiInstructMeta4 = ModelMeta(
|
||||
org="microsoft",
|
||||
base_name="phi",
|
||||
instruct_tags=["mini-instruct"],
|
||||
model_version="4",
|
||||
model_sizes=["1"], # Assuming only one size
|
||||
model_info_cls=PhiModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
org = "microsoft",
|
||||
base_name = "phi",
|
||||
instruct_tags = ["mini-instruct"],
|
||||
model_version = "4",
|
||||
model_sizes = ["1"], # Assuming only one size
|
||||
model_info_cls = PhiModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ def register_phi_4_models(include_original_model: bool = False):
|
|||
global _IS_PHI_4_REGISTERED
|
||||
if _IS_PHI_4_REGISTERED:
|
||||
return
|
||||
_register_models(PhiMeta4, include_original_model=include_original_model)
|
||||
_register_models(PhiMeta4, include_original_model = include_original_model)
|
||||
_IS_PHI_4_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -50,13 +50,13 @@ def register_phi_4_instruct_models(include_original_model: bool = False):
|
|||
global _IS_PHI_4_INSTRUCT_REGISTERED
|
||||
if _IS_PHI_4_INSTRUCT_REGISTERED:
|
||||
return
|
||||
_register_models(PhiInstructMeta4, include_original_model=include_original_model)
|
||||
_register_models(PhiInstructMeta4, include_original_model = include_original_model)
|
||||
_IS_PHI_4_INSTRUCT_REGISTERED = True
|
||||
|
||||
|
||||
def register_phi_models(include_original_model: bool = False):
|
||||
register_phi_4_models(include_original_model=include_original_model)
|
||||
register_phi_4_instruct_models(include_original_model=include_original_model)
|
||||
register_phi_4_models(include_original_model = include_original_model)
|
||||
register_phi_4_instruct_models(include_original_model = include_original_model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -64,7 +64,7 @@ if __name__ == "__main__":
|
|||
|
||||
MODEL_REGISTRY.clear()
|
||||
|
||||
register_phi_models(include_original_model=True)
|
||||
register_phi_models(include_original_model = True)
|
||||
|
||||
for model_id, model_info in MODEL_REGISTRY.items():
|
||||
model_info = _check_model_info(model_id)
|
||||
|
|
|
|||
578
unsloth/save.py
578
unsloth/save.py
File diff suppressed because it is too large
Load diff
|
|
@ -71,7 +71,7 @@ KAGGLE_TMP = "/tmp"
|
|||
del keynames
|
||||
|
||||
|
||||
def try_fix_tokenizer(tokenizer, prepend=True):
|
||||
def try_fix_tokenizer(tokenizer, prepend = True):
|
||||
if hasattr(tokenizer, "_tokenizer"):
|
||||
converted_tokenizer = tokenizer._tokenizer
|
||||
else:
|
||||
|
|
@ -130,7 +130,7 @@ def get_sorted_dict(dictionary):
|
|||
|
||||
def convert_to_fast_tokenizer(
|
||||
slow_tokenizer,
|
||||
temporary_location="_unsloth_sentencepiece_temp",
|
||||
temporary_location = "_unsloth_sentencepiece_temp",
|
||||
):
|
||||
is_fast = getattr(slow_tokenizer, "is_fast", False)
|
||||
if is_fast:
|
||||
|
|
@ -152,20 +152,20 @@ def convert_to_fast_tokenizer(
|
|||
# Get all arguments (bos_token, etc)
|
||||
docs = FastTokenizer.__doc__
|
||||
docs = docs[docs.find("Args:") :]
|
||||
args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags=re.MULTILINE)
|
||||
args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
|
||||
args = [x for x in args if not x.endswith("_file")]
|
||||
|
||||
# Also some missing maybe!
|
||||
docs = PreTrainedTokenizerFast.__doc__
|
||||
docs = docs[docs.find("Args:") :]
|
||||
args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags=re.MULTILINE)
|
||||
args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
|
||||
args2 = [x for x in args2 if not x.endswith("_file")]
|
||||
args = list(set(args + args2))
|
||||
|
||||
kwargs = {}
|
||||
for arg in args:
|
||||
kwargs[arg] = getattr(slow_tokenizer, arg, None)
|
||||
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend=True)
|
||||
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = True)
|
||||
fast_tokenizer = FastTokenizer(**kwargs)
|
||||
|
||||
# Check if they're similar!
|
||||
|
|
@ -184,7 +184,7 @@ def convert_to_fast_tokenizer(
|
|||
# Now confirm if they match
|
||||
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
|
||||
# Maybe remove prepending of __apple?
|
||||
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend=False)
|
||||
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = False)
|
||||
fast_tokenizer = FastTokenizer(**kwargs)
|
||||
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
|
||||
# Failure :(
|
||||
|
|
@ -347,7 +347,7 @@ def fix_sentencepiece_tokenizer(
|
|||
old_tokenizer,
|
||||
new_tokenizer,
|
||||
token_mapping,
|
||||
temporary_location="_unsloth_sentencepiece_temp",
|
||||
temporary_location = "_unsloth_sentencepiece_temp",
|
||||
):
|
||||
# From https://github.com/google/sentencepiece/issues/121
|
||||
# We need to manually edit the sentencepiece tokenizer!
|
||||
|
|
@ -390,7 +390,7 @@ def fix_sentencepiece_tokenizer(
|
|||
|
||||
# Now correct the old tokenizer's .model file
|
||||
for old_token, new_token in token_mapping.items():
|
||||
ids = old_tokenizer([old_token], add_special_tokens=False).input_ids
|
||||
ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
|
||||
ids = ids[0]
|
||||
if len(ids) != 1:
|
||||
# Skip this token!
|
||||
|
|
@ -416,8 +416,8 @@ def fix_sentencepiece_tokenizer(
|
|||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
temporary_location,
|
||||
eos_token=new_tokenizer.eos_token,
|
||||
pad_token=new_tokenizer.pad_token,
|
||||
eos_token = new_tokenizer.eos_token,
|
||||
pad_token = new_tokenizer.pad_token,
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
|
|
@ -453,13 +453,13 @@ def fix_sentencepiece_gguf(saved_location):
|
|||
# Load added_tokens_json
|
||||
if not os.path.isfile(f"{saved_location}/added_tokens.json"):
|
||||
return
|
||||
with open(f"{saved_location}/added_tokens.json", "r", encoding="utf-8") as file:
|
||||
with open(f"{saved_location}/added_tokens.json", "r", encoding = "utf-8") as file:
|
||||
added_tokens_json = json.load(file)
|
||||
if len(added_tokens_json) == 0:
|
||||
return
|
||||
|
||||
added_tokens_json = dict(
|
||||
sorted(added_tokens_json.items(), key=lambda item: item[1])
|
||||
sorted(added_tokens_json.items(), key = lambda item: item[1])
|
||||
)
|
||||
new_size = sentence_piece_size + len(added_tokens_json)
|
||||
|
||||
|
|
@ -496,12 +496,12 @@ def fix_sentencepiece_gguf(saved_location):
|
|||
|
||||
def _load_correct_tokenizer(
|
||||
tokenizer_name,
|
||||
model_max_length=None,
|
||||
padding_side="right",
|
||||
token=None,
|
||||
trust_remote_code=False,
|
||||
cache_dir="huggingface_tokenizers_cache",
|
||||
fix_tokenizer=True,
|
||||
model_max_length = None,
|
||||
padding_side = "right",
|
||||
token = None,
|
||||
trust_remote_code = False,
|
||||
cache_dir = "huggingface_tokenizers_cache",
|
||||
fix_tokenizer = True,
|
||||
):
|
||||
if IS_COLAB_ENVIRONMENT:
|
||||
cache_dir = cache_dir
|
||||
|
|
@ -518,15 +518,15 @@ def _load_correct_tokenizer(
|
|||
try:
|
||||
slow_tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name,
|
||||
model_max_length=model_max_length,
|
||||
padding_side=padding_side,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_max_length = model_max_length,
|
||||
padding_side = padding_side,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
# Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
|
||||
use_fast=False,
|
||||
legacy=False,
|
||||
from_slow=True,
|
||||
cache_dir=cache_dir,
|
||||
use_fast = False,
|
||||
legacy = False,
|
||||
from_slow = True,
|
||||
cache_dir = cache_dir,
|
||||
)
|
||||
except:
|
||||
slow_tokenizer = None
|
||||
|
|
@ -540,11 +540,11 @@ def _load_correct_tokenizer(
|
|||
|
||||
fast_tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name,
|
||||
model_max_length=model_max_length,
|
||||
padding_side=padding_side,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
cache_dir=cache_dir,
|
||||
model_max_length = model_max_length,
|
||||
padding_side = padding_side,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
cache_dir = cache_dir,
|
||||
)
|
||||
|
||||
if not fix_tokenizer or tokenizer_name in IGNORED_TOKENIZER_NAMES:
|
||||
|
|
@ -580,21 +580,21 @@ def _load_correct_tokenizer(
|
|||
|
||||
def load_correct_tokenizer(
|
||||
tokenizer_name,
|
||||
model_max_length=None,
|
||||
padding_side="right",
|
||||
token=None,
|
||||
trust_remote_code=False,
|
||||
cache_dir="huggingface_tokenizers_cache",
|
||||
fix_tokenizer=True,
|
||||
model_max_length = None,
|
||||
padding_side = "right",
|
||||
token = None,
|
||||
trust_remote_code = False,
|
||||
cache_dir = "huggingface_tokenizers_cache",
|
||||
fix_tokenizer = True,
|
||||
):
|
||||
tokenizer = _load_correct_tokenizer(
|
||||
tokenizer_name=tokenizer_name,
|
||||
model_max_length=model_max_length,
|
||||
padding_side=padding_side,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
cache_dir=cache_dir,
|
||||
fix_tokenizer=fix_tokenizer,
|
||||
tokenizer_name = tokenizer_name,
|
||||
model_max_length = model_max_length,
|
||||
padding_side = padding_side,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
cache_dir = cache_dir,
|
||||
fix_tokenizer = fix_tokenizer,
|
||||
)
|
||||
|
||||
### 1. Fixup tokenizer's chat_template
|
||||
|
|
@ -683,7 +683,7 @@ def fix_chat_template(tokenizer):
|
|||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=False, tokenize=False
|
||||
messages, add_generation_prompt = False, tokenize = False
|
||||
)
|
||||
is_sharegpt = False
|
||||
except:
|
||||
|
|
@ -692,7 +692,7 @@ def fix_chat_template(tokenizer):
|
|||
{"from": "human", "value": "Who are you?"},
|
||||
]
|
||||
tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=False, tokenize=False
|
||||
messages, add_generation_prompt = False, tokenize = False
|
||||
)
|
||||
is_sharegpt = True
|
||||
except:
|
||||
|
|
@ -709,10 +709,10 @@ def fix_chat_template(tokenizer):
|
|||
else {"from": "human", "value": "Who are you?"}
|
||||
]
|
||||
no = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=False, tokenize=False
|
||||
messages, add_generation_prompt = False, tokenize = False
|
||||
)
|
||||
yes = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False
|
||||
messages, add_generation_prompt = True, tokenize = False
|
||||
)
|
||||
|
||||
if no == yes:
|
||||
|
|
@ -750,11 +750,11 @@ def fix_chat_template(tokenizer):
|
|||
def check_tokenizer(
|
||||
model,
|
||||
tokenizer,
|
||||
model_name="unsloth/llama-2-7b-bnb-4bit",
|
||||
model_max_length=4096,
|
||||
padding_side="right",
|
||||
token=None,
|
||||
_reload=True,
|
||||
model_name = "unsloth/llama-2-7b-bnb-4bit",
|
||||
model_max_length = 4096,
|
||||
padding_side = "right",
|
||||
token = None,
|
||||
_reload = True,
|
||||
):
|
||||
# Checks tokenizer for out of bounds ids.
|
||||
# Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
|
||||
|
|
@ -861,23 +861,23 @@ def check_tokenizer(
|
|||
# Try slow tokenizer which can fix things!
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
model_max_length=model_max_length,
|
||||
padding_side=padding_side,
|
||||
token=token,
|
||||
model_max_length = model_max_length,
|
||||
padding_side = padding_side,
|
||||
token = token,
|
||||
# Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
|
||||
use_fast=False,
|
||||
legacy=False,
|
||||
from_slow=True,
|
||||
cache_dir=cache_dir,
|
||||
use_fast = False,
|
||||
legacy = False,
|
||||
from_slow = True,
|
||||
cache_dir = cache_dir,
|
||||
)
|
||||
return check_tokenizer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model_name=model_name,
|
||||
model_max_length=model_max_length,
|
||||
padding_side=padding_side,
|
||||
token=token,
|
||||
_reload=False,
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_name = model_name,
|
||||
model_max_length = model_max_length,
|
||||
padding_side = padding_side,
|
||||
token = token,
|
||||
_reload = False,
|
||||
)
|
||||
break
|
||||
except:
|
||||
|
|
@ -993,7 +993,7 @@ def patch_sft_trainer_tokenizer():
|
|||
replacer = re.findall(
|
||||
f"def {function_name}" + r"\(.*?\).*?\:\n",
|
||||
function,
|
||||
flags=re.MULTILINE | re.DOTALL,
|
||||
flags = re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
if len(replacer) == 0:
|
||||
continue
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue