Revert "[FIX] Vllm guided decoding params (#3662)"

This reverts commit fb4f0fdf56.
Daniel Han 2025-12-01 05:43:45 -08:00
parent fb4f0fdf56
commit ba2897a318
51 changed files with 2649 additions and 2698 deletions


@@ -22,10 +22,10 @@ def safe_remove_directory(path):
print("🔥 Loading the 16-bit merged model from disk...")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./gpt-oss-finetuned-merged",
max_seq_length=1024,
load_in_4bit=True,
load_in_8bit=False,
model_name = "./gpt-oss-finetuned-merged",
max_seq_length = 1024,
load_in_4bit = True,
load_in_8bit = False,
)
print("✅ Merged model loaded successfully.")
@@ -36,14 +36,14 @@ messages = [
]
inputs = merged_tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
reasoning_effort="low", # **NEW!** Set reasoning effort to low, medium or high
add_generation_prompt = True,
return_tensors = "pt",
return_dict = True,
reasoning_effort = "low", # **NEW!** Set reasoning effort to low, medium or high
).to(merged_model.device)
_ = merged_model.generate(
**inputs, max_new_tokens=512, streamer=TextStreamer(merged_tokenizer)
**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer)
)
print("\n✅ Inference complete.")


@@ -29,7 +29,7 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@@ -40,17 +40,17 @@ def formatting_prompts_func(examples):
print("Loading 4-bit Mxfp4 gpt-oss model for training...")
max_seq_length = 1024
model, tokenizer = FastLanguageModel.from_pretrained(
"unsloth/gpt-oss-20b", max_seq_length=max_seq_length, load_in_4bit=True
"unsloth/gpt-oss-20b", max_seq_length = max_seq_length, load_in_4bit = True
)
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:50]").map(
formatting_prompts_func, batched=True
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split = "train[:50]").map(
formatting_prompts_func, batched = True
)
model = FastLanguageModel.get_peft_model(
model,
r=8,
target_modules=[
r = 8,
target_modules = [
"q_proj",
"k_proj",
"v_proj",
@@ -59,22 +59,22 @@ model = FastLanguageModel.get_peft_model(
"up_proj",
"down_proj",
],
lora_alpha=16,
use_gradient_checkpointing="unsloth",
random_state=3407,
lora_alpha = 16,
use_gradient_checkpointing = "unsloth",
random_state = 3407,
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
args=SFTConfig(
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
max_steps=10,
learning_rate=2e-4,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
args = SFTConfig(
per_device_train_batch_size = 1,
gradient_accumulation_steps = 4,
max_steps = 10,
learning_rate = 2e-4,
output_dir = "outputs",
report_to = "none",
),
)
@@ -85,7 +85,7 @@ print("Fine-tuning complete.")
# --- Merge and Save ---
print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
model.save_pretrained_merged(
save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer
save_directory = "./gpt-oss-finetuned-merged", tokenizer = tokenizer
)
print("✅ Model merged and saved.")


@@ -17,7 +17,7 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@@ -36,25 +36,25 @@ else:
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.1-8B-Instruct",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/Llama-3.1-8B-Instruct",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
# Load small dataset for quick training
dataset_train = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="train[:100]"
"allenai/openassistant-guanaco-reformatted", split = "train[:100]"
)
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
print("✅ Base model loaded successfully!")
@@ -64,8 +64,8 @@ print(f"{'='*80}")
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -74,40 +74,40 @@ model = FastLanguageModel.get_peft_model(
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=10, # Very short training for test
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=5,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 10, # Very short training for test
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 5,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@@ -119,9 +119,9 @@ print("🔍 PHASE 3: Save with Forced 4bit Merge")
print(f"{'='*80}")
model.save_pretrained_merged(
save_directory="./test_4bit_model",
tokenizer=tokenizer,
save_method="forced_merged_4bit",
save_directory = "./test_4bit_model",
tokenizer = tokenizer,
save_method = "forced_merged_4bit",
)
print("✅ Model saved with forced 4bit merge!")
@@ -137,15 +137,15 @@ torch.cuda.empty_cache()
# Load the 4bit merged model
model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
model_name="./test_4bit_model",
max_seq_length=2048,
load_in_4bit=True,
load_in_8bit=False,
model_name = "./test_4bit_model",
max_seq_length = 2048,
load_in_4bit = True,
load_in_8bit = False,
)
tokenizer_4bit = get_chat_template(
tokenizer_4bit,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
print("✅ 4bit model loaded successfully!")
@@ -153,8 +153,8 @@ print("✅ 4bit model loaded successfully!")
# Add LoRA adapters to the 4bit model
model_4bit = FastLanguageModel.get_peft_model(
model_4bit,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -163,39 +163,39 @@ model_4bit = FastLanguageModel.get_peft_model(
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
# Second fine-tuning
trainer_4bit = SFTTrainer(
model=model_4bit,
tokenizer=tokenizer_4bit,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_4bit),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=10, # Very short training for test
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=5,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs_4bit",
report_to="none",
model = model_4bit,
tokenizer = tokenizer_4bit,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer_4bit),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 10, # Very short training for test
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 5,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs_4bit",
report_to = "none",
),
)
@@ -208,8 +208,8 @@ print(f"{'='*80}")
try:
model_4bit.save_pretrained_merged(
save_directory="./test_should_fail",
tokenizer=tokenizer_4bit,
save_directory = "./test_should_fail",
tokenizer = tokenizer_4bit,
# No save_method specified, should default to regular merge
)
assert False, "Expected TypeError but merge succeeded!"
@@ -225,9 +225,9 @@ print(f"{'='*80}")
try:
model_4bit.save_pretrained_merged(
save_directory="./test_4bit_second",
tokenizer=tokenizer_4bit,
save_method="forced_merged_4bit",
save_directory = "./test_4bit_second",
tokenizer = tokenizer_4bit,
save_method = "forced_merged_4bit",
)
print("✅ Successfully saved 4bit model with forced 4bit method!")
except Exception as e:


@@ -36,14 +36,14 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
return {"text": texts}
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
"""Load model and compute perplexity in subprocess"""
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
@@ -51,20 +51,20 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Load model
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_llama_text_model",
max_seq_length=2048,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
model_name = "./unsloth_out/merged_llama_text_model",
max_seq_length = 2048,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
)
# Set up tokenizer
merged_tokenizer = get_chat_template(
merged_tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
# Load dataset fresh in subprocess
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
# Format the dataset
@@ -72,13 +72,13 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
convos = examples["messages"]
texts = [
merged_tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
return {"text": texts}
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
# Compute perplexity using the passed dataset
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
@@ -104,7 +104,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Main execution code should be wrapped in this guard
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force = True)
if torch.cuda.is_bf16_supported():
compute_dtype = torch.bfloat16
@@ -114,38 +114,38 @@ if __name__ == "__main__":
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.2-3B-Instruct",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/Llama-3.2-3B-Instruct",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
from unsloth.chat_templates import standardize_sharegpt
dataset_train = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="train"
"allenai/openassistant-guanaco-reformatted", split = "train"
)
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -154,40 +154,40 @@ if __name__ == "__main__":
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=10,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=50,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 10,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 50,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@@ -195,8 +195,8 @@ if __name__ == "__main__":
trainer = train_on_responses_only(
trainer,
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# run training
@@ -207,7 +207,7 @@ if __name__ == "__main__":
# saving and merging the model to local disk
print("merge and save to local disk")
model.save_pretrained_merged(
save_directory="./unsloth_out/merged_llama_text_model", tokenizer=tokenizer
save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
)
# print("cleaning")
@@ -219,10 +219,10 @@ if __name__ == "__main__":
# load model from local disk and test
print("Loading merged model in 4 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_llama_text_model",
max_seq_length=2048,
load_in_4bit=True,
load_in_8bit=False,
model_name = "./unsloth_out/merged_llama_text_model",
max_seq_length = 2048,
load_in_4bit = True,
load_in_8bit = False,
)
add_to_comparison(
@@ -231,7 +231,7 @@ if __name__ == "__main__":
print("Computing 8-bit model perplexity in subprocess...")
result_queue = mp.Queue()
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
p.start()
p.join()
@@ -240,10 +240,10 @@ if __name__ == "__main__":
print("Loading merged model in 16 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_llama_text_model",
max_seq_length=2048,
load_in_4bit=False,
load_in_8bit=False,
model_name = "./unsloth_out/merged_llama_text_model",
max_seq_length = 2048,
load_in_4bit = False,
load_in_8bit = False,
)
add_to_comparison(


@@ -30,17 +30,17 @@ from tests.utils.perplexity_eval import (
)
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
"""Load model and compute perplexity in subprocess"""
from unsloth import FastLanguageModel
from tests.utils.perplexity_eval import ppl_model
# Load model
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_mistral_text_model",
max_seq_length=2048,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
model_name = "./unsloth_out/merged_mistral_text_model",
max_seq_length = 2048,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
)
# Set up tokenizer
# merged_tokenizer = get_chat_template(
@@ -50,7 +50,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Load dataset fresh in subprocess
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -103,7 +103,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
"text": texts,
}
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
# Compute perplexity using the passed dataset
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
@@ -129,7 +129,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Main execution code should be wrapped in this guard
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force = True)
if torch.cuda.is_bf16_supported():
compute_dtype = torch.bfloat16
@@ -139,13 +139,13 @@ if __name__ == "__main__":
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b-v0.3",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/mistral-7b-v0.3",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
EOS_TOKEN = tokenizer.eos_token
@@ -200,21 +200,21 @@ if __name__ == "__main__":
}
dataset_train = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="train"
"allenai/openassistant-guanaco-reformatted", split = "train"
)
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -223,39 +223,39 @@ if __name__ == "__main__":
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=200,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=50,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 200,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 50,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@@ -267,7 +267,7 @@ if __name__ == "__main__":
# saving and merging the model to local disk
print("merge and save to local disk")
model.save_pretrained_merged(
save_directory="./unsloth_out/merged_mistral_text_model", tokenizer=tokenizer
save_directory = "./unsloth_out/merged_mistral_text_model", tokenizer = tokenizer
)
# print("cleaning")
@@ -279,10 +279,10 @@ if __name__ == "__main__":
# load model from local disk and test
print("Loading merged model in 4 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_mistral_text_model",
max_seq_length=2048,
load_in_4bit=True,
load_in_8bit=False,
model_name = "./unsloth_out/merged_mistral_text_model",
max_seq_length = 2048,
load_in_4bit = True,
load_in_8bit = False,
)
add_to_comparison(
@@ -291,7 +291,7 @@ if __name__ == "__main__":
print("Computing 8-bit model perplexity in subprocess...")
result_queue = mp.Queue()
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
p.start()
p.join()
@@ -300,10 +300,10 @@ if __name__ == "__main__":
print("Loading merged model in 16 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_mistral_text_model",
max_seq_length=2048,
load_in_4bit=False,
load_in_8bit=False,
model_name = "./unsloth_out/merged_mistral_text_model",
max_seq_length = 2048,
load_in_4bit = False,
load_in_8bit = False,
)
add_to_comparison(


@@ -36,7 +36,7 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@@ -45,7 +45,7 @@ def formatting_prompts_func(examples):
}
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
"""Load model and compute perplexity in subprocess"""
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
@@ -53,20 +53,20 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Load model
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_phi4_text_model",
max_seq_length=2048,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
model_name = "./unsloth_out/merged_phi4_text_model",
max_seq_length = 2048,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
)
# Set up tokenizer
merged_tokenizer = get_chat_template(
merged_tokenizer,
chat_template="phi-4",
chat_template = "phi-4",
)
# Load dataset fresh in subprocess
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
# Format the dataset
@@ -74,13 +74,13 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
convos = examples["messages"]
texts = [
merged_tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
return {"text": texts}
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
# Compute perplexity using the passed dataset
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
@@ -106,7 +106,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Main execution code should be wrapped in this guard
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force = True)
if torch.cuda.is_bf16_supported():
compute_dtype = torch.bfloat16
@@ -116,36 +116,36 @@ if __name__ == "__main__":
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Phi-4",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/Phi-4",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="phi-4",
chat_template = "phi-4",
)
dataset_train = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="train"
"allenai/openassistant-guanaco-reformatted", split = "train"
)
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -154,40 +154,40 @@ if __name__ == "__main__":
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=200,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=50,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 200,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 50,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@@ -195,8 +195,8 @@ if __name__ == "__main__":
trainer = train_on_responses_only(
trainer,
instruction_part="<|im_start|>user<|im_sep|>\n\n",
response_part="<|im_start|>assistant<|im_sep|>\n\n",
instruction_part = "<|im_start|>user<|im_sep|>\n\n",
response_part = "<|im_start|>assistant<|im_sep|>\n\n",
)
# run training
@@ -207,7 +207,7 @@ if __name__ == "__main__":
# saving and merging the model to local disk
print("merge and save to local disk")
model.save_pretrained_merged(
save_directory="./unsloth_out/merged_phi4_text_model", tokenizer=tokenizer
save_directory = "./unsloth_out/merged_phi4_text_model", tokenizer = tokenizer
)
# print("cleaning")
@@ -219,10 +219,10 @@ if __name__ == "__main__":
# load model from local disk and test
print("Loading merged model in 4 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_phi4_text_model",
max_seq_length=2048,
load_in_4bit=True,
load_in_8bit=False,
model_name = "./unsloth_out/merged_phi4_text_model",
max_seq_length = 2048,
load_in_4bit = True,
load_in_8bit = False,
)
add_to_comparison(
@@ -231,7 +231,7 @@ if __name__ == "__main__":
print("Computing 8-bit model perplexity in subprocess...")
result_queue = mp.Queue()
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
p.start()
p.join()
@@ -240,10 +240,10 @@ if __name__ == "__main__":
print("Loading merged model in 16 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_phi4_text_model",
max_seq_length=2048,
load_in_4bit=False,
load_in_8bit=False,
model_name = "./unsloth_out/merged_phi4_text_model",
max_seq_length = 2048,
load_in_4bit = False,
load_in_8bit = False,
)
add_to_comparison(


@@ -35,14 +35,14 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
return {"text": texts}
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
"""Load model and compute perplexity in subprocess"""
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
@@ -50,20 +50,20 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Load model
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_llama_text_model",
max_seq_length=2048,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
model_name = "./unsloth_out/merged_llama_text_model",
max_seq_length = 2048,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
)
# Set up tokenizer
merged_tokenizer = get_chat_template(
merged_tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
# Load dataset fresh in subprocess
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
# Format the dataset
@@ -71,13 +71,13 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
convos = examples["messages"]
texts = [
merged_tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
return {"text": texts}
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
# Compute perplexity using the passed dataset
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
@@ -103,7 +103,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Main execution code should be wrapped in this guard
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force = True)
if torch.cuda.is_bf16_supported():
compute_dtype = torch.bfloat16
@@ -113,31 +113,31 @@ if __name__ == "__main__":
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.1-8B-Instruct",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/Llama-3.1-8B-Instruct",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
from unsloth.chat_templates import standardize_sharegpt
dataset_train = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="train"
"allenai/openassistant-guanaco-reformatted", split = "train"
)
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
print("\n dataset sample [0]")
print(dataset_train[0])
@@ -146,8 +146,8 @@ if __name__ == "__main__":
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -156,40 +156,40 @@ if __name__ == "__main__":
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=200,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=50,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 200,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 50,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@@ -197,8 +197,8 @@ if __name__ == "__main__":
trainer = train_on_responses_only(
trainer,
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
tokenizer.decode(trainer.train_dataset[0]["input_ids"])
@@ -211,7 +211,7 @@ if __name__ == "__main__":
# saving and merging the model to local disk
print("merge and save to local disk")
model.save_pretrained_merged(
save_directory="./unsloth_out/merged_llama_text_model", tokenizer=tokenizer
save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
)
# print("cleaning")
@@ -223,10 +223,10 @@ if __name__ == "__main__":
# load model from local disk and test
print("Loading merged model in 4 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_llama_text_model",
max_seq_length=2048,
load_in_4bit=True,
load_in_8bit=False,
model_name = "./unsloth_out/merged_llama_text_model",
max_seq_length = 2048,
load_in_4bit = True,
load_in_8bit = False,
)
add_to_comparison(
@@ -235,7 +235,7 @@ if __name__ == "__main__":
print("Computing 8-bit model perplexity in subprocess...")
result_queue = mp.Queue()
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
p.start()
p.join()
@@ -244,10 +244,10 @@ if __name__ == "__main__":
print("Loading merged model in 16 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_llama_text_model",
max_seq_length=2048,
load_in_4bit=False,
load_in_8bit=False,
model_name = "./unsloth_out/merged_llama_text_model",
max_seq_length = 2048,
load_in_4bit = False,
load_in_8bit = False,
)
add_to_comparison(


@@ -16,12 +16,12 @@ print("🔍 PHASE 1: Loading Base Model")
print(f"{'='*80}")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b-v0.3",
max_seq_length=2048,
dtype=None,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
model_name = "unsloth/mistral-7b-v0.3",
max_seq_length = 2048,
dtype = None,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
)
@@ -34,7 +34,7 @@ print(f"\n{'='*80}")
print("🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)")
print(f"{'='*80}")
with warnings.catch_warnings(record=True) as w:
with warnings.catch_warnings(record = True) as w:
warnings.simplefilter("always")
model.save_pretrained_merged("test_output", tokenizer)


@@ -16,12 +16,12 @@ print("🔍 PHASE 1: Loading Base Model")
print(f"{'='*80}")
model, tokenizer = FastModel.from_pretrained(
model_name="unsloth/whisper-large-v3",
dtype=None, # Leave as None for auto detection
load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
auto_model=WhisperForConditionalGeneration,
whisper_language="English",
whisper_task="transcribe",
model_name = "unsloth/whisper-large-v3",
dtype = None, # Leave as None for auto detection
load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
auto_model = WhisperForConditionalGeneration,
whisper_language = "English",
whisper_task = "transcribe",
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@@ -34,7 +34,7 @@ print(f"\n{'='*80}")
print("🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)")
print(f"{'='*80}")
with warnings.catch_warnings(record=True) as w:
with warnings.catch_warnings(record = True) as w:
warnings.simplefilter("always")
model.save_pretrained_merged("test_output", tokenizer)


@@ -30,10 +30,10 @@ print(f"{'='*80}")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/orpheus-3b-0.1-ft",
max_seq_length=2048, # Choose any for long context!
dtype=None, # Select None for auto detection
load_in_4bit=False, # Select True for 4bit which reduces memory usage
model_name = "unsloth/orpheus-3b-0.1-ft",
max_seq_length = 2048, # Choose any for long context!
dtype = None, # Select None for auto detection
load_in_4bit = False, # Select True for 4bit which reduces memory usage
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@@ -42,8 +42,8 @@ base_model_class = model.__class__.__name__
model = FastLanguageModel.get_peft_model(
model,
r=64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = [
"q_proj",
"k_proj",
"v_proj",
@@ -52,14 +52,14 @@ model = FastLanguageModel.get_peft_model(
"up_proj",
"down_proj",
],
lora_alpha=64,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
lora_alpha = 64,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
print("✅ Model and LoRA adapters loaded successfully!")
@@ -112,10 +112,10 @@ print(f"{'='*80}")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/orpheus-3b-0.1-ft",
max_seq_length=2048, # Choose any for long context!
dtype=None, # Select None for auto detection
load_in_4bit=False, # Select True for 4bit which reduces memory usage
model_name = "unsloth/orpheus-3b-0.1-ft",
max_seq_length = 2048, # Choose any for long context!
dtype = None, # Select None for auto detection
load_in_4bit = False, # Select True for 4bit which reduces memory usage
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@@ -148,18 +148,18 @@ prompts_ = [(f"{chosen_voice}: " + p) if chosen_voice else p for p in prompts]
all_input_ids = []
for prompt in prompts_:
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = tokenizer(prompt, return_tensors = "pt").input_ids
all_input_ids.append(input_ids)
start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
start_token = torch.tensor([[128259]], dtype = torch.int64) # Start of human
end_tokens = torch.tensor(
[[128009, 128260]], dtype=torch.int64
[[128009, 128260]], dtype = torch.int64
) # End of text, End of human
all_modified_input_ids = []
for input_ids in all_input_ids:
modified_input_ids = torch.cat(
[start_token, input_ids, end_tokens], dim=1
[start_token, input_ids, end_tokens], dim = 1
) # SOH SOT Text EOT EOH
all_modified_input_ids.append(modified_input_ids)
@@ -171,39 +171,39 @@ max_length = max(
for modified_input_ids in all_modified_input_ids:
padding = max_length - modified_input_ids.shape[1]
padded_tensor = torch.cat(
[torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1
[torch.full((1, padding), 128263, dtype = torch.int64), modified_input_ids], dim = 1
)
attention_mask = torch.cat(
[
torch.zeros((1, padding), dtype=torch.int64),
torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64),
torch.zeros((1, padding), dtype = torch.int64),
torch.ones((1, modified_input_ids.shape[1]), dtype = torch.int64),
],
dim=1,
dim = 1,
)
all_padded_tensors.append(padded_tensor)
all_attention_masks.append(attention_mask)
all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
all_attention_masks = torch.cat(all_attention_masks, dim=0)
all_padded_tensors = torch.cat(all_padded_tensors, dim = 0)
all_attention_masks = torch.cat(all_attention_masks, dim = 0)
input_ids = all_padded_tensors.to("cuda")
attention_mask = all_attention_masks.to("cuda")
generated_ids = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=1200,
do_sample=True,
temperature=0.6,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1,
eos_token_id=128258,
use_cache=True,
input_ids = input_ids,
attention_mask = attention_mask,
max_new_tokens = 1200,
do_sample = True,
temperature = 0.6,
top_p = 0.95,
repetition_penalty = 1.1,
num_return_sequences = 1,
eos_token_id = 128258,
use_cache = True,
)
token_to_find = 128257
token_to_remove = 128258
token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
token_indices = (generated_ids == token_to_find).nonzero(as_tuple = True)
if len(token_indices[1]) > 0:
last_occurrence_idx = token_indices[1][-1].item()


@@ -22,7 +22,7 @@ from tests.utils.ocr_eval import OCRModelEvaluator
## Dataset Preparation
from datasets import load_dataset
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
# To select the first 2000 examples
train_dataset = dataset.select(range(2000))
@@ -81,39 +81,39 @@ model_comparison_results = {}
# Load Base Model
model, tokenizer = FastVisionModel.from_pretrained(
model_name="unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
max_seq_length=2048, # Choose any for long context!
load_in_4bit=True, # 4 bit quantization to reduce memory
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning=False, # [NEW!] We have full finetuning now!
model_name = "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
max_seq_length = 2048, # Choose any for long context!
load_in_4bit = True, # 4 bit quantization to reduce memory
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning = False, # [NEW!] We have full finetuning now!
)
# benchmark base model performance
model_name = "Unsloth Base model"
FastVisionModel.for_inference(model)
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results"
model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
## Lora Finetuning
model = FastVisionModel.get_peft_model(
model,
finetune_vision_layers=True, # Turn off for just text!
finetune_language_layers=True, # Should leave on!
finetune_attention_modules=True, # Attention good for GRPO
finetune_mlp_modules=True, # Should leave on always!
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
finetune_vision_layers = True, # Turn off for just text!
finetune_language_layers = True, # Should leave on!
finetune_attention_modules = True, # Attention good for GRPO
finetune_mlp_modules = True, # Should leave on always!
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
# target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
# "gate_proj", "up_proj", "down_proj",],
lora_alpha=32,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
lora_alpha = 32,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
from unsloth import is_bf16_supported
@@ -124,40 +124,40 @@ model.config.use_cache = False
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
data_collator=UnslothVisionDataCollator(model, tokenizer),
train_dataset=train_dataset,
args=SFTConfig(
model = model,
tokenizer = tokenizer,
data_collator = UnslothVisionDataCollator(model, tokenizer),
train_dataset = train_dataset,
args = SFTConfig(
# per_device_train_batch_size = 4,
# gradient_accumulation_steps = 8,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
gradient_checkpointing = True,
gradient_checkpointing_kwargs = {
"use_reentrant": False
}, # use reentrant checkpointing
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03,
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
warmup_ratio = 0.03,
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
max_steps=60,
learning_rate=2e-4,
fp16=not is_bf16_supported(),
bf16=is_bf16_supported(),
logging_steps=5,
save_strategy="epoch",
optim="adamw_torch_fused",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="unsloth-qwen2.5-vl-32b-french-ocr-checkpoints",
report_to="none", # For Weights and Biases
max_steps = 60,
learning_rate = 2e-4,
fp16 = not is_bf16_supported(),
bf16 = is_bf16_supported(),
logging_steps = 5,
save_strategy = "epoch",
optim = "adamw_torch_fused",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "unsloth-qwen2.5-vl-32b-french-ocr-checkpoints",
report_to = "none", # For Weights and Biases
# You MUST put the below items for vision finetuning:
remove_unused_columns=False,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
dataset_num_proc=4,
max_seq_length=2048,
remove_unused_columns = False,
dataset_text_field = "",
dataset_kwargs = {"skip_prepare_dataset": True},
dataset_num_proc = 4,
max_seq_length = 2048,
),
)
@@ -173,7 +173,7 @@ tokenizer.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter")
model_name = "Unsloth lora adapter model"
FastVisionModel.for_inference(model)
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results"
model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
@@ -195,7 +195,7 @@ print((base.__class__.__name__))
# merge default 16 bits
model.save_pretrained_merged(
save_directory="qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer
save_directory = "qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
)
@@ -204,7 +204,7 @@ model.save_pretrained_merged(
### 16 bits merged model
model, tokenizer = FastVisionModel.from_pretrained(
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=False
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
)
# benchmark 4bit loaded, 16bits merged model performance
@@ -215,13 +215,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model,
tokenizer,
eval_dataset,
output_dir="unsloth_16bits_merged_model_load_16bits_results",
output_dir = "unsloth_16bits_merged_model_load_16bits_results",
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
# load 16bits-merged model in 4 bits
model, tokenizer = FastVisionModel.from_pretrained(
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=True, load_in_8bit=False
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
)
# benchmark 4bit loaded, 16bits merged model performance
@@ -232,13 +232,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model,
tokenizer,
eval_dataset,
output_dir="unsloth_16bits_merged_model_load_4bits_results",
output_dir = "unsloth_16bits_merged_model_load_4bits_results",
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
# load model in 8 bits
model, tokenizer = FastVisionModel.from_pretrained(
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=True
"./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
)
# benchmark 4bit loaded, 16bits merged model performance
@@ -247,7 +247,7 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model,
tokenizer,
eval_dataset,
output_dir="unsloth_16bits_merged_model_load_8bits_results",
output_dir = "unsloth_16bits_merged_model_load_8bits_results",
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)


@@ -22,7 +22,7 @@ from tests.utils.ocr_eval import OCRModelEvaluator
## Dataset Preparation
from datasets import load_dataset
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
# To select the first 2000 examples
train_dataset = dataset.select(range(2000))
@@ -81,39 +81,39 @@ model_comparison_results = {}
# Load Base Model
model, tokenizer = FastVisionModel.from_pretrained(
model_name="unsloth/Qwen2-VL-7B-Instruct",
max_seq_length=2048, # Choose any for long context!
load_in_4bit=True, # 4 bit quantization to reduce memory
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning=False, # [NEW!] We have full finetuning now!
model_name = "unsloth/Qwen2-VL-7B-Instruct",
max_seq_length = 2048, # Choose any for long context!
load_in_4bit = True, # 4 bit quantization to reduce memory
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning = False, # [NEW!] We have full finetuning now!
)
# benchmark base model performance
model_name = "Unsloth Base model"
FastVisionModel.for_inference(model)
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results"
model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
## Lora Finetuning
model = FastVisionModel.get_peft_model(
model,
finetune_vision_layers=True, # Turn off for just text!
finetune_language_layers=True, # Should leave on!
finetune_attention_modules=True, # Attention good for GRPO
finetune_mlp_modules=True, # Should leave on always!
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
finetune_vision_layers = True, # Turn off for just text!
finetune_language_layers = True, # Should leave on!
finetune_attention_modules = True, # Attention good for GRPO
finetune_mlp_modules = True, # Should leave on always!
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
# target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
# "gate_proj", "up_proj", "down_proj",],
lora_alpha=32,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
lora_alpha = 32,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
from unsloth import is_bf16_supported
@@ -124,40 +124,40 @@ model.config.use_cache = False
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
data_collator=UnslothVisionDataCollator(model, tokenizer),
train_dataset=train_dataset,
args=SFTConfig(
model = model,
tokenizer = tokenizer,
data_collator = UnslothVisionDataCollator(model, tokenizer),
train_dataset = train_dataset,
args = SFTConfig(
# per_device_train_batch_size = 4,
# gradient_accumulation_steps = 8,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
gradient_checkpointing = True,
gradient_checkpointing_kwargs = {
"use_reentrant": False
}, # use reentrant checkpointing
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03,
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
warmup_ratio = 0.03,
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
max_steps=60,
learning_rate=2e-4,
fp16=not is_bf16_supported(),
bf16=is_bf16_supported(),
logging_steps=5,
save_strategy="epoch",
optim="adamw_torch_fused",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="unsloth-qwen2-7vl-french-ocr-checkpoints",
report_to="none", # For Weights and Biases
max_steps = 60,
learning_rate = 2e-4,
fp16 = not is_bf16_supported(),
bf16 = is_bf16_supported(),
logging_steps = 5,
save_strategy = "epoch",
optim = "adamw_torch_fused",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "unsloth-qwen2-7vl-french-ocr-checkpoints",
report_to = "none", # For Weights and Biases
# You MUST put the below items for vision finetuning:
remove_unused_columns=False,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
dataset_num_proc=4,
max_seq_length=2048,
remove_unused_columns = False,
dataset_text_field = "",
dataset_kwargs = {"skip_prepare_dataset": True},
dataset_num_proc = 4,
max_seq_length = 2048,
),
)
@ -173,7 +173,7 @@ tokenizer.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter")
model_name = "Unsloth lora adapter model"
FastVisionModel.for_inference(model)
avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results"
model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
@ -195,7 +195,7 @@ print((base.__class__.__name__))
# merge default 16 bits
model.save_pretrained_merged(
save_directory="qwen2-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer
save_directory = "qwen2-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
)
@ -204,7 +204,7 @@ model.save_pretrained_merged(
### 16 bits merged model
model, tokenizer = FastVisionModel.from_pretrained(
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=False
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
)
# benchmark 16bit loaded, 16bits merged model performance
@ -215,13 +215,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model,
tokenizer,
eval_dataset,
output_dir="unsloth_16bits_merged_model_load_16bits_results",
output_dir = "unsloth_16bits_merged_model_load_16bits_results",
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
# load 16bits-merged model in 4 bits
model, tokenizer = FastVisionModel.from_pretrained(
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit=True, load_in_8bit=False
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
)
# benchmark 4bit loaded, 16bits merged model performance
@ -232,13 +232,13 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model,
tokenizer,
eval_dataset,
output_dir="unsloth_16bits_merged_model_load_4bits_results",
output_dir = "unsloth_16bits_merged_model_load_4bits_results",
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
# load model in 8 bits
model, tokenizer = FastVisionModel.from_pretrained(
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit=False, load_in_8bit=True
"./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
)
# benchmark 8bit loaded, 16bits merged model performance
@ -247,7 +247,7 @@ avg_wer, avg_cer = ocr_evaluator.evaluate_model(
model,
tokenizer,
eval_dataset,
output_dir="unsloth_16bits_merged_model_load_8bits_results",
output_dir = "unsloth_16bits_merged_model_load_8bits_results",
)
ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
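For reference, the WER and CER reported by the comparison above are word- and character-level edit distances normalized by reference length. A minimal, self-contained sketch of that metric (an illustrative helper, not the OCRModelEvaluator code):

def _levenshtein(a, b):
    # Classic dynamic-programming edit distance between two sequences.
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start = 1):
        curr = [i]
        for j, y in enumerate(b, start = 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]

def word_error_rate(reference, hypothesis):
    ref_words = reference.split()
    return _levenshtein(ref_words, hypothesis.split()) / max(len(ref_words), 1)

def char_error_rate(reference, hypothesis):
    return _levenshtein(reference, hypothesis) / max(len(reference), 1)

print(word_error_rate("le chat noir", "le chat noire"))  # 0.333...
print(char_error_rate("le chat noir", "le chat noire"))  # 0.083...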


@ -64,7 +64,7 @@ TestParams = [
# Test that model registration methods register respective models
@pytest.mark.parametrize("model_test_param", TestParams, ids=lambda param: param.name)
@pytest.mark.parametrize("model_test_param", TestParams, ids = lambda param: param.name)
def test_model_registration(model_test_param: ModelTestParam):
MODEL_REGISTRY.clear()
registration_method = model_test_param.register_models
@ -86,7 +86,7 @@ def test_all_model_registration():
def test_quant_type():
# Test that the quant_type is correctly set for model paths
# NOTE: for models registered under org="unsloth", QuantType.NONE aliases QuantType.UNSLOTH
dynamic_quant_models = search_models(quant_types=[QuantType.UNSLOTH])
dynamic_quant_models = search_models(quant_types = [QuantType.UNSLOTH])
assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models)
quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH]
assert all(quant_tag in m.model_path for m in dynamic_quant_models)


@ -25,7 +25,7 @@ def timer(name):
@contextmanager
def header_footer_context(title: str, char="-"):
def header_footer_context(title: str, char = "-"):
print()
print(f"{char}" * 50 + f" {title} " + f"{char}" * 50)
yield


@ -7,7 +7,7 @@ import sys
import warnings
def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
def clear_memory(variables_to_clear = None, verbose = False, clear_all_caches = True):
"""
Comprehensive memory clearing for persistent memory leaks.
@ -104,7 +104,7 @@ def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
logger.setLevel(level)
def clear_all_lru_caches(verbose=True):
def clear_all_lru_caches(verbose = True):
"""Clear all LRU caches in loaded modules."""
cleared_caches = []
@ -210,7 +210,7 @@ def monitor_cache_sizes():
except:
pass
return sorted(cache_info, key=lambda x: x["size"], reverse=True)
return sorted(cache_info, key = lambda x: x["size"], reverse = True)
def safe_remove_directory(path):


@ -32,10 +32,10 @@ def create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = N
dataset = create_instruction_dataset(messages)
def _apply_chat_template(example):
chat = tokenizer.apply_chat_template(example["messages"], tokenize=False)
chat = tokenizer.apply_chat_template(example["messages"], tokenize = False)
return {"text": chat}
dataset = dataset.map(_apply_chat_template, remove_columns="messages")
dataset = dataset.map(_apply_chat_template, remove_columns = "messages")
if num_examples is not None:
if len(dataset) < num_examples:
num_repeats = num_examples // len(dataset) + 1
@ -139,11 +139,11 @@ def get_peft_weights(model):
def describe_peft_weights(model):
for name, param in get_peft_weights(model).items():
yield name, describe_param(param, as_str=True)
yield name, describe_param(param, as_str = True)
def check_responses(responses: list[str], answer: str, prompt: str = None) -> bool:
for i, response in enumerate(responses, start=1):
for i, response in enumerate(responses, start = 1):
if answer in response:
print(f"\u2713 response {i} contains answer")
else:


@ -41,14 +41,14 @@ class OCRModelEvaluator:
Evaluate a model on an OCR dataset.
"""
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok = True)
# Initialize results storage
results = []
# Process each sample in the dataset
for i, sample in enumerate(
tqdm(dataset, desc="Evaluating OCR performance", disable=not verbose)
tqdm(dataset, desc = "Evaluating OCR performance", disable = not verbose)
):
try:
# Extract components from sample
@ -187,7 +187,7 @@ class OCRModelEvaluator:
# Preparation for inference using Qwen's specific processing
text = processor.apply_chat_template(
input_messages, tokenize=False, add_generation_prompt=True
input_messages, tokenize = False, add_generation_prompt = True
)
# Process vision info (images/videos) from messages
@ -195,11 +195,11 @@ class OCRModelEvaluator:
# Create model inputs
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
text = [text],
images = image_inputs,
videos = video_inputs,
padding = True,
return_tensors = "pt",
)
inputs = inputs.to(model.device)
@ -207,10 +207,10 @@ class OCRModelEvaluator:
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
min_p=min_p,
use_cache=True,
max_new_tokens = max_new_tokens,
temperature = temperature,
min_p = min_p,
use_cache = True,
)
# Extract only the generated part (not the input)
@ -222,8 +222,8 @@ class OCRModelEvaluator:
# Decode the generated text
generated_response = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
skip_special_tokens = True,
clean_up_tokenization_spaces = False,
)[0]
return generated_response
@ -240,7 +240,7 @@ class OCRModelEvaluator:
):
"""Save individual sample result to file."""
output_file = os.path.join(output_dir, f"sample_{sample_idx}.txt")
with open(output_file, "w", encoding="utf-8") as f:
with open(output_file, "w", encoding = "utf-8") as f:
f.write(f"Sample {sample_idx}\n")
f.write(f"Question: {question}\n\n")
f.write(f"Model output:\n{generated_response.strip()}\n\n")
@ -268,7 +268,7 @@ class OCRModelEvaluator:
f.write(f"Average CER: {avg_cer:.4f}\n")
# Save detailed results
df.to_csv(os.path.join(output_dir, "detailed_results.csv"), index=False)
df.to_csv(os.path.join(output_dir, "detailed_results.csv"), index = False)
if verbose:
print("\nResults Summary:")
@ -310,12 +310,12 @@ class OCRModelEvaluator:
# Display the comparison table
print("\nComparison Table (sorted by WER):")
print(comparison_df.to_string(index=False))
print(comparison_df.to_string(index = False))
# Save the comparison table
if save_csv:
comparison_file = "model_comparison_results.csv"
comparison_df.to_csv(comparison_file, index=False)
comparison_df.to_csv(comparison_file, index = False)
print(f"\nComparison table saved to {comparison_file}")
# Generate a bar chart visualization
@ -326,23 +326,23 @@ class OCRModelEvaluator:
def _create_comparison_plot(self, comparison_df: pd.DataFrame):
"""Create and save comparison plot."""
plt.figure(figsize=(12, 6))
plt.figure(figsize = (12, 6))
# Plot WER
plt.subplot(1, 2, 1)
plt.bar(comparison_df["Model"], comparison_df["WER"], color="skyblue")
plt.bar(comparison_df["Model"], comparison_df["WER"], color = "skyblue")
plt.title("Word Error Rate Comparison")
plt.ylabel("WER (lower is better)")
plt.ylim(bottom=0)
plt.xticks(rotation=45, ha="right")
plt.ylim(bottom = 0)
plt.xticks(rotation = 45, ha = "right")
# Plot CER
plt.subplot(1, 2, 2)
plt.bar(comparison_df["Model"], comparison_df["CER"], color="lightgreen")
plt.bar(comparison_df["Model"], comparison_df["CER"], color = "lightgreen")
plt.title("Character Error Rate Comparison")
plt.ylabel("CER (lower is better)")
plt.ylim(bottom=0)
plt.xticks(rotation=45, ha="right")
plt.ylim(bottom = 0)
plt.xticks(rotation = 45, ha = "right")
plt.tight_layout()
plt.savefig("ocr_model_comparison.png")
@ -360,7 +360,7 @@ class OCRModelEvaluator:
def evaluate_ocr_model(
model, processor, dataset, output_dir="ocr_evaluation_results", **kwargs
model, processor, dataset, output_dir = "ocr_evaluation_results", **kwargs
):
"""
Convenience function that maintains backward compatibility with the original function.


@ -21,7 +21,7 @@ def detect_package_manager():
return None
def check_package_installed(package_name, package_manager=None):
def check_package_installed(package_name, package_manager = None):
"""Check if a package is installed using the system package manager"""
if package_manager is None:
@ -35,26 +35,26 @@ def check_package_installed(package_name, package_manager=None):
if package_manager == "apt":
# Check with dpkg
result = subprocess.run(
["dpkg", "-l", package_name], capture_output=True, text=True
["dpkg", "-l", package_name], capture_output = True, text = True
)
return result.returncode == 0
elif package_manager in ["yum", "dnf"]:
# Check with rpm
result = subprocess.run(
["rpm", "-q", package_name], capture_output=True, text=True
["rpm", "-q", package_name], capture_output = True, text = True
)
return result.returncode == 0
elif package_manager == "pacman":
result = subprocess.run(
["pacman", "-Q", package_name], capture_output=True, text=True
["pacman", "-Q", package_name], capture_output = True, text = True
)
return result.returncode == 0
elif package_manager == "zypper":
result = subprocess.run(
["zypper", "se", "-i", package_name], capture_output=True, text=True
["zypper", "se", "-i", package_name], capture_output = True, text = True
)
return package_name in result.stdout
@ -63,7 +63,7 @@ def check_package_installed(package_name, package_manager=None):
return None
def require_package(package_name, executable_name=None):
def require_package(package_name, executable_name = None):
"""Require a package to be installed, exit if not found"""
# First check if executable is in PATH (most reliable)
@ -109,7 +109,7 @@ def require_package(package_name, executable_name=None):
# require_package("ffmpeg", "ffmpeg")
def require_python_package(package_name, import_name=None, pip_name=None):
def require_python_package(package_name, import_name = None, pip_name = None):
"""Require a Python package to be installed, exit if not found"""
if import_name is None:
import_name = package_name


@ -12,7 +12,7 @@ def ppl_model(model, tokenizer, dataset):
max_length = 2048
stride = 512
for s in tqdm(range(len(dataset["text"]))):
encodings = tokenizer(dataset["text"][s], return_tensors="pt")
encodings = tokenizer(dataset["text"][s], return_tensors = "pt")
seq_len = encodings.input_ids.size(1)
prev_end_loc = 0
for begin_loc in range(0, seq_len, stride):
@ -28,7 +28,7 @@ def ppl_model(model, tokenizer, dataset):
attention_mask = (input_ids != pad_token_id).long()
with torch.no_grad():
outputs = model(
input_ids, labels=target_ids, attention_mask=attention_mask
input_ids, labels = target_ids, attention_mask = attention_mask
)
neg_log_likelihood = outputs.loss
nlls.append(neg_log_likelihood)
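The strided loop above collects one negative log-likelihood per window; perplexity is the exponential of their mean. A minimal sketch of that final reduction (assuming nlls is the list of scalar loss tensors built by the loop):

import torch

nlls = [torch.tensor(2.1), torch.tensor(1.9), torch.tensor(2.3)]  # stand-in values
ppl = torch.exp(torch.stack(nlls).mean())  # perplexity = exp(mean NLL)
print(f"perplexity = {ppl.item():.2f}")  # ~8.17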
@ -78,4 +78,4 @@ def print_model_comparison():
# Display the comparison table
print("\nComparison Table:")
print(comparison_df.to_string(index=False))
print(comparison_df.to_string(index = False))


@ -32,15 +32,15 @@ def _get_model(qat_scheme: str, full_finetuning: bool):
to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.
"""
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Qwen3-1.7B",
load_in_4bit=False,
full_finetuning=full_finetuning,
qat_scheme=qat_scheme if full_finetuning else None,
model_name = "unsloth/Qwen3-1.7B",
load_in_4bit = False,
full_finetuning = full_finetuning,
qat_scheme = qat_scheme if full_finetuning else None,
)
if not full_finetuning:
model = FastLanguageModel.get_peft_model(
model,
qat_scheme=qat_scheme,
qat_scheme = qat_scheme,
)
return model, tokenizer
@ -140,7 +140,7 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
_test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
_test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
_test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
inputs = tokenizer("How are you?", return_tensors="pt")
inputs = tokenizer("How are you?", return_tensors = "pt")
_test_fake_quantizers_are_called(model, inputs, full_finetuning)
@ -148,9 +148,9 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
# how to disable model caching before re-enabling this test
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def _test_full_model_fake_quantize(qat_scheme: bool):
_test_model_fake_quantize(qat_scheme, full_finetuning=True)
_test_model_fake_quantize(qat_scheme, full_finetuning = True)
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def test_lora_model_fake_quantize(qat_scheme: bool):
_test_model_fake_quantize(qat_scheme, full_finetuning=False)
_test_model_fake_quantize(qat_scheme, full_finetuning = False)


@ -47,17 +47,17 @@ def run(args):
# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model_name,
max_seq_length=args.max_seq_length,
dtype=args.dtype,
load_in_4bit=args.load_in_4bit,
model_name = args.model_name,
max_seq_length = args.max_seq_length,
dtype = args.dtype,
load_in_4bit = args.load_in_4bit,
)
# Configure PEFT model
model = FastLanguageModel.get_peft_model(
model,
r=args.r,
target_modules=[
r = args.r,
target_modules = [
"q_proj",
"k_proj",
"v_proj",
@ -66,13 +66,13 @@ def run(args):
"up_proj",
"down_proj",
],
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias=args.bias,
use_gradient_checkpointing=args.use_gradient_checkpointing,
random_state=args.random_state,
use_rslora=args.use_rslora,
loftq_config=args.loftq_config,
lora_alpha = args.lora_alpha,
lora_dropout = args.lora_dropout,
bias = args.bias,
use_gradient_checkpointing = args.use_gradient_checkpointing,
random_state = args.random_state,
use_rslora = args.use_rslora,
loftq_config = args.loftq_config,
)
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@ -102,40 +102,40 @@ def run(args):
if use_modelscope:
from modelscope import MsDataset
dataset = MsDataset.load(args.dataset, split="train")
dataset = MsDataset.load(args.dataset, split = "train")
else:
# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
dataset = dataset.map(formatting_prompts_func, batched=True)
dataset = load_dataset(args.dataset, split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True)
print("Data is formatted and ready!")
# Configure training arguments
training_args = SFTConfig(
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
warmup_steps=args.warmup_steps,
max_steps=args.max_steps,
learning_rate=args.learning_rate,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=args.logging_steps,
optim=args.optim,
weight_decay=args.weight_decay,
lr_scheduler_type=args.lr_scheduler_type,
seed=args.seed,
output_dir=args.output_dir,
report_to=args.report_to,
max_length=args.max_seq_length,
dataset_num_proc=2,
packing=False,
per_device_train_batch_size = args.per_device_train_batch_size,
gradient_accumulation_steps = args.gradient_accumulation_steps,
warmup_steps = args.warmup_steps,
max_steps = args.max_steps,
learning_rate = args.learning_rate,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = args.logging_steps,
optim = args.optim,
weight_decay = args.weight_decay,
lr_scheduler_type = args.lr_scheduler_type,
seed = args.seed,
output_dir = args.output_dir,
report_to = args.report_to,
max_length = args.max_seq_length,
dataset_num_proc = 2,
packing = False,
)
# Initialize trainer
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
train_dataset=dataset,
args=training_args,
model = model,
processing_class = tokenizer,
train_dataset = dataset,
args = training_args,
)
# Train model
@ -153,24 +153,24 @@ def run(args):
model.save_pretrained_gguf(
args.save_path,
tokenizer,
quantization_method=quantization_method,
quantization_method = quantization_method,
)
if args.push_model:
model.push_to_hub_gguf(
hub_path=args.hub_path,
hub_token=args.hub_token,
quantization_method=quantization_method,
hub_path = args.hub_path,
hub_token = args.hub_token,
quantization_method = quantization_method,
)
else:
print(f"Saving model with quantization method: {args.quantization}")
model.save_pretrained_gguf(
args.save_path, tokenizer, quantization_method=args.quantization
args.save_path, tokenizer, quantization_method = args.quantization
)
if args.push_model:
model.push_to_hub_gguf(
hub_path=args.hub_path,
hub_token=args.hub_token,
quantization_method=quantization_method,
hub_path = args.hub_path,
hub_token = args.hub_token,
quantization_method = quantization_method,
)
else:
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
@ -183,38 +183,38 @@ def run(args):
if __name__ == "__main__":
# Define argument parser
parser = argparse.ArgumentParser(
description="🦥 Fine-tune your llm faster using unsloth!"
description = "🦥 Fine-tune your llm faster using unsloth!"
)
model_group = parser.add_argument_group("🤖 Model Options")
model_group.add_argument(
"--model_name",
type=str,
default="unsloth/llama-3-8b",
help="Model name to load",
type = str,
default = "unsloth/llama-3-8b",
help = "Model name to load",
)
model_group.add_argument(
"--max_seq_length",
type=int,
default=2048,
help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!",
type = int,
default = 2048,
help = "Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!",
)
model_group.add_argument(
"--dtype",
type=str,
default=None,
help="Data type for model (None for auto detection)",
type = str,
default = None,
help = "Data type for model (None for auto detection)",
)
model_group.add_argument(
"--load_in_4bit",
action="store_true",
help="Use 4bit quantization to reduce memory usage",
action = "store_true",
help = "Use 4bit quantization to reduce memory usage",
)
model_group.add_argument(
"--dataset",
type=str,
default="yahma/alpaca-cleaned",
help="Huggingface dataset to use for training",
type = str,
default = "yahma/alpaca-cleaned",
help = "Huggingface dataset to use for training",
)
lora_group = parser.add_argument_group(
@ -222,101 +222,101 @@ if __name__ == "__main__":
)
lora_group.add_argument(
"--r",
type=int,
default=16,
help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)",
type = int,
default = 16,
help = "Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)",
)
lora_group.add_argument(
"--lora_alpha",
type=int,
default=16,
help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)",
type = int,
default = 16,
help = "LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)",
)
lora_group.add_argument(
"--lora_dropout",
type=float,
default=0.0,
help="LoRA dropout rate, default is 0.0 which is optimized.",
type = float,
default = 0.0,
help = "LoRA dropout rate, default is 0.0 which is optimized.",
)
lora_group.add_argument(
"--bias", type=str, default="none", help="Bias setting for LoRA"
"--bias", type = str, default = "none", help = "Bias setting for LoRA"
)
lora_group.add_argument(
"--use_gradient_checkpointing",
type=str,
default="unsloth",
help="Use gradient checkpointing",
type = str,
default = "unsloth",
help = "Use gradient checkpointing",
)
lora_group.add_argument(
"--random_state",
type=int,
default=3407,
help="Random state for reproducibility, default is 3407.",
type = int,
default = 3407,
help = "Random state for reproducibility, default is 3407.",
)
lora_group.add_argument(
"--use_rslora", action="store_true", help="Use rank stabilized LoRA"
"--use_rslora", action = "store_true", help = "Use rank stabilized LoRA"
)
lora_group.add_argument(
"--loftq_config", type=str, default=None, help="Configuration for LoftQ"
"--loftq_config", type = str, default = None, help = "Configuration for LoftQ"
)
training_group = parser.add_argument_group("🎓 Training Options")
training_group.add_argument(
"--per_device_train_batch_size",
type=int,
default=2,
help="Batch size per device during training, default is 2.",
type = int,
default = 2,
help = "Batch size per device during training, default is 2.",
)
training_group.add_argument(
"--gradient_accumulation_steps",
type=int,
default=4,
help="Number of gradient accumulation steps, default is 4.",
type = int,
default = 4,
help = "Number of gradient accumulation steps, default is 4.",
)
training_group.add_argument(
"--warmup_steps",
type=int,
default=5,
help="Number of warmup steps, default is 5.",
type = int,
default = 5,
help = "Number of warmup steps, default is 5.",
)
training_group.add_argument(
"--max_steps", type=int, default=400, help="Maximum number of training steps."
"--max_steps", type = int, default = 400, help = "Maximum number of training steps."
)
training_group.add_argument(
"--learning_rate",
type=float,
default=2e-4,
help="Learning rate, default is 2e-4.",
type = float,
default = 2e-4,
help = "Learning rate, default is 2e-4.",
)
training_group.add_argument(
"--optim", type=str, default="adamw_8bit", help="Optimizer type."
"--optim", type = str, default = "adamw_8bit", help = "Optimizer type."
)
training_group.add_argument(
"--weight_decay",
type=float,
default=0.01,
help="Weight decay, default is 0.01.",
type = float,
default = 0.01,
help = "Weight decay, default is 0.01.",
)
training_group.add_argument(
"--lr_scheduler_type",
type=str,
default="linear",
help="Learning rate scheduler type, default is 'linear'.",
type = str,
default = "linear",
help = "Learning rate scheduler type, default is 'linear'.",
)
training_group.add_argument(
"--seed",
type=int,
default=3407,
help="Seed for reproducibility, default is 3407.",
type = int,
default = 3407,
help = "Seed for reproducibility, default is 3407.",
)
# Report/Logging arguments
report_group = parser.add_argument_group("📊 Report Options")
report_group.add_argument(
"--report_to",
type=str,
default="tensorboard",
choices=[
type = str,
default = "tensorboard",
choices = [
"azure_ml",
"clearml",
"codecarbon",
@ -331,62 +331,62 @@ if __name__ == "__main__":
"all",
"none",
],
help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.",
help = "The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.",
)
report_group.add_argument(
"--logging_steps", type=int, default=1, help="Logging steps, default is 1"
"--logging_steps", type = int, default = 1, help = "Logging steps, default is 1"
)
# Saving and pushing arguments
save_group = parser.add_argument_group("💾 Save Model Options")
save_group.add_argument(
"--output_dir", type=str, default="outputs", help="Output directory"
"--output_dir", type = str, default = "outputs", help = "Output directory"
)
save_group.add_argument(
"--save_model", action="store_true", help="Save the model after training"
"--save_model", action = "store_true", help = "Save the model after training"
)
save_group.add_argument(
"--save_method",
type=str,
default="merged_16bit",
choices=["merged_16bit", "merged_4bit", "lora"],
help="Save method for the model, default is 'merged_16bit'",
type = str,
default = "merged_16bit",
choices = ["merged_16bit", "merged_4bit", "lora"],
help = "Save method for the model, default is 'merged_16bit'",
)
save_group.add_argument(
"--save_gguf",
action="store_true",
help="Convert the model to GGUF after training",
action = "store_true",
help = "Convert the model to GGUF after training",
)
save_group.add_argument(
"--save_path", type=str, default="model", help="Path to save the model"
"--save_path", type = str, default = "model", help = "Path to save the model"
)
save_group.add_argument(
"--quantization",
type=str,
default="q8_0",
nargs="+",
help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ",
type = str,
default = "q8_0",
nargs = "+",
help = "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ",
)
push_group = parser.add_argument_group("🚀 Push Model Options")
push_group.add_argument(
"--push_model",
action="store_true",
help="Push the model to Hugging Face hub after training",
action = "store_true",
help = "Push the model to Hugging Face hub after training",
)
push_group.add_argument(
"--push_gguf",
action="store_true",
help="Push the model as GGUF to Hugging Face hub after training",
action = "store_true",
help = "Push the model as GGUF to Hugging Face hub after training",
)
push_group.add_argument(
"--hub_path",
type=str,
default="hf/model",
help="Path on Hugging Face hub to push the model",
type = str,
default = "hf/model",
help = "Path on Hugging Face hub to push the model",
)
push_group.add_argument(
"--hub_token", type=str, help="Token for pushing the model to Hugging Face hub"
"--hub_token", type = str, help = "Token for pushing the model to Hugging Face hub"
)
args = parser.parse_args()


@ -110,7 +110,6 @@ from unsloth_zoo.device_type import (
from .import_fixes import (
fix_xformers_performance_issue,
fix_vllm_aimv2_issue,
fix_vllm_guided_decoding_params,
ignore_logger_messages,
patch_ipykernel_hf_xet,
patch_trackio,
@ -119,14 +118,13 @@ from .import_fixes import (
fix_xformers_performance_issue()
fix_vllm_aimv2_issue()
fix_vllm_guided_decoding_params()
ignore_logger_messages()
patch_ipykernel_hf_xet()
patch_trackio()
patch_datasets()
del fix_xformers_performance_issue
del patch_vllm_imports
del fix_vllm_aimv2_issue
del ignore_logger_messages
del patch_ipykernel_hf_xet
del patch_trackio


@ -40,7 +40,7 @@ from .synthetic_configs import (
)
def terminate_tree(proc: subprocess.Popen, timeout=15):
def terminate_tree(proc: subprocess.Popen, timeout = 15):
if proc is None or proc.poll() is not None:
return
@ -48,10 +48,10 @@ def terminate_tree(proc: subprocess.Popen, timeout=15):
import psutil
parent = psutil.Process(proc.pid)
for child in parent.children(recursive=True):
for child in parent.children(recursive = True):
child.terminate()
parent.terminate()
parent.wait(timeout=timeout / 2)
parent.wait(timeout = timeout / 2)
return
except:
pass
@ -60,17 +60,17 @@ def terminate_tree(proc: subprocess.Popen, timeout=15):
try:
subprocess.run(
["taskkill", "/T", "/F", "/PID", str(proc.pid)],
capture_output=True,
timeout=5,
capture_output = True,
timeout = 5,
)
proc.wait(timeout=1)
proc.wait(timeout = 1)
return
except:
pass
proc.kill()
try:
proc.wait(timeout=5)
proc.wait(timeout = 5)
except:
pass
@ -81,16 +81,16 @@ class PipeCapture:
def __init__(
self,
pipe,
keep_lines=2000,
echo=False,
name="",
text=True,
encoding="utf-8",
errors="replace",
ready_regex=None,
keep_lines = 2000,
echo = False,
name = "",
text = True,
encoding = "utf-8",
errors = "replace",
ready_regex = None,
):
self.pipe = pipe
self.buf = deque(maxlen=keep_lines)
self.buf = deque(maxlen = keep_lines)
self.lock = threading.Lock()
self.echo = echo
self.name = name
@ -107,7 +107,7 @@ class PipeCapture:
ready_regex = re.compile(ready_regex)
self.ready_regex = ready_regex
self.t = threading.Thread(target=self._reader, daemon=True)
self.t = threading.Thread(target = self._reader, daemon = True)
self.t.start()
def _reader(self):
@ -136,16 +136,16 @@ class PipeCapture:
pass
self.closed_event.set()
def wait_for_ready(self, timeout=None):
def wait_for_ready(self, timeout = None):
return self.ready_event.wait(timeout)
def has_closed(self):
return self.closed_event.is_set()
def wait_until_closed(self, timeout=None):
def wait_until_closed(self, timeout = None):
return self.closed_event.wait(timeout)
def tail(self, n=200):
def tail(self, n = 200):
with self.lock:
return "\n".join(list(self.buf)[-n:])
@ -153,13 +153,13 @@ class PipeCapture:
class SyntheticDataKit:
def __init__(
self,
model_name="unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
max_seq_length=2048,
gpu_memory_utilization=0.98,
float8_kv_cache=False,
conservativeness=1.0,
token=None,
timeout=1200, # maybe this is not enough for large models if we need to download
model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
max_seq_length = 2048,
gpu_memory_utilization = 0.98,
float8_kv_cache = False,
conservativeness = 1.0,
token = None,
timeout = 1200, # maybe this is not enough for large models if we need to download
**kwargs,
):
assert type(model_name) is str
@ -176,25 +176,25 @@ class SyntheticDataKit:
self.config = AutoConfig.from_pretrained(
model_name,
token=token,
token = token,
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
token=token,
token = token,
)
patch_vllm(debug=False)
patch_vllm(debug = False)
engine_args = load_vllm(
model_name=model_name,
config=self.config,
gpu_memory_utilization=gpu_memory_utilization,
max_seq_length=max_seq_length,
disable_log_stats=True,
float8_kv_cache=float8_kv_cache,
conservativeness=conservativeness,
return_args=True,
enable_lora=False,
use_bitsandbytes=False,
compilation_config=3,
model_name = model_name,
config = self.config,
gpu_memory_utilization = gpu_memory_utilization,
max_seq_length = max_seq_length,
disable_log_stats = True,
float8_kv_cache = float8_kv_cache,
conservativeness = conservativeness,
return_args = True,
enable_lora = False,
use_bitsandbytes = False,
compilation_config = 3,
**kwargs,
)
if "dtype" in engine_args:
@ -252,31 +252,31 @@ class SyntheticDataKit:
logger.info(subprocess_commands)
vllm_process = subprocess.Popen(
subprocess_commands,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
start_new_session=True,
stdout = subprocess.PIPE,
stderr = subprocess.PIPE,
start_new_session = True,
)
ready_re = re.compile(r"Starting vLLM API server(?:\s+\d+)?\s+on\b")
self.vllm_process = vllm_process
self.stdout_capture = PipeCapture(
vllm_process.stdout,
keep_lines=1000,
echo=True,
name="vLLM STDOUT",
ready_regex=ready_re,
text=False,
keep_lines = 1000,
echo = True,
name = "vLLM STDOUT",
ready_regex = ready_re,
text = False,
)
self.stderr_capture = PipeCapture(
vllm_process.stderr,
keep_lines=2000,
echo=False,
name="vLLM STDERR",
ready_regex=None,
text=False,
keep_lines = 2000,
echo = False,
name = "vLLM STDERR",
ready_regex = None,
text = False,
)
# we don't print stderr to console but self.stderr_capture.tail(200) will print the last 200 lines
ready = self.stdout_capture.wait_for_ready(timeout=timeout)
ready = self.stdout_capture.wait_for_ready(timeout = timeout)
if not ready:
if self.stdout_capture.has_closed() or self.vllm_process.poll() is not None:
print("Stdout stream ended before readiness message detected.")
@ -305,21 +305,21 @@ class SyntheticDataKit:
@staticmethod
def from_pretrained(
model_name="unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
max_seq_length=2048,
gpu_memory_utilization=0.9,
float8_kv_cache=False,
conservativeness=1.0,
token=None,
model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
max_seq_length = 2048,
gpu_memory_utilization = 0.9,
float8_kv_cache = False,
conservativeness = 1.0,
token = None,
**kwargs,
):
return SyntheticDataKit(
model_name=model_name,
max_seq_length=max_seq_length,
gpu_memory_utilization=gpu_memory_utilization,
float8_kv_cache=float8_kv_cache,
conservativeness=conservativeness,
token=token,
model_name = model_name,
max_seq_length = max_seq_length,
gpu_memory_utilization = gpu_memory_utilization,
float8_kv_cache = float8_kv_cache,
conservativeness = conservativeness,
token = token,
**kwargs,
)
@ -340,7 +340,7 @@ class SyntheticDataKit:
print("Attempting to terminate the VLLM server gracefully...")
try:
vllm_process.terminate()
vllm_process.wait(timeout=10)
vllm_process.wait(timeout = 10)
print("Server terminated gracefully.")
except subprocess.TimeoutExpired:
print(
@ -364,7 +364,7 @@ class SyntheticDataKit:
gc.collect()
# Delete vLLM module as well
delete_vllm(llm=None)
delete_vllm(llm = None)
def __enter__(self):
return self
@ -375,7 +375,7 @@ class SyntheticDataKit:
def __del__(self):
self.cleanup()
def chunk_data(self, filename=None):
def chunk_data(self, filename = None):
# Chunks data by max tokens and generation length
assert filename is not None
assert os.path.exists(filename)
@ -387,7 +387,7 @@ class SyntheticDataKit:
if not hasattr(self, "overlap") or not hasattr(self, "max_generation_tokens"):
raise RuntimeError("Please use prepare_qa_generation first!")
with open(filename, "r", encoding="utf-8") as f:
with open(filename, "r", encoding = "utf-8") as f:
text = f.read()
max_tokens = (
@ -395,7 +395,7 @@ class SyntheticDataKit:
) # -128 to reduce errors
if max_tokens <= 5:
raise RuntimeError("Generation length is way too long!")
input_ids = self.tokenizer(text, add_special_tokens=False).input_ids
input_ids = self.tokenizer(text, add_special_tokens = False).input_ids
# Get left and right boundaries
length = len(input_ids)
@ -416,21 +416,21 @@ class SyntheticDataKit:
chunked_text = self.tokenizer.decode(input_ids[left:right])
new_filename = f"{filename}_{i}{extension}"
all_filenames.append(new_filename)
with open(new_filename, "w", encoding="utf-8") as f:
with open(new_filename, "w", encoding = "utf-8") as f:
f.write(chunked_text)
return all_filenames
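chunk_data above slices the tokenized document into windows of at most max_tokens, stepping so that consecutive windows share overlap tokens. A standalone sketch of the same windowing idea (hypothetical helper name, not part of SyntheticDataKit):

def window_token_ids(input_ids, max_tokens, overlap):
    # Step by (max_tokens - overlap) so neighbouring windows share `overlap` tokens.
    step = max(max_tokens - overlap, 1)
    return [input_ids[left : left + max_tokens] for left in range(0, len(input_ids), step)]

print(window_token_ids(list(range(10)), max_tokens = 4, overlap = 1))
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]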
def prepare_qa_generation(
self,
output_folder="data",
max_generation_tokens=512,
temperature=0.7,
top_p=0.95,
overlap=64,
default_num_pairs=25,
cleanup_threshold=1.0,
cleanup_batch_size=4,
cleanup_temperature=0.3,
output_folder = "data",
max_generation_tokens = 512,
temperature = 0.7,
top_p = 0.95,
overlap = 64,
default_num_pairs = 25,
cleanup_threshold = 1.0,
cleanup_batch_size = 4,
cleanup_temperature = 0.3,
):
assert hasattr(self, "model_name")
assert hasattr(self, "max_seq_length")
@ -439,7 +439,7 @@ class SyntheticDataKit:
locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final"
locations = locations.split(",")
for path in locations:
os.makedirs(os.path.join(output_folder, path), exist_ok=True)
os.makedirs(os.path.join(output_folder, path), exist_ok = True)
self.max_generation_tokens = max_generation_tokens
@ -459,7 +459,7 @@ class SyntheticDataKit:
.replace("{cleanup_temperature}", str(cleanup_temperature))
)
with open("synthetic_data_kit_config.yaml", "w", encoding="utf-8") as f:
with open("synthetic_data_kit_config.yaml", "w", encoding = "utf-8") as f:
f.write(config)
self.overlap = overlap


@ -114,16 +114,6 @@ def fix_xformers_performance_issue():
print(f"Unsloth: Failed patching Xformers with error = {str(e)}")
def fix_vllm_aimv2_issue():
if importlib.util.find_spec("vllm") is None:
return
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
vllm_version = importlib_version("vllm")
if Version(vllm_version) < Version("0.10.1"):
vllm_version = importlib.util.find_spec("vllm").origin
vllm_version = os.path.split(vllm_version)[0]
ovis_config = Path(vllm_version) / "transformers_utils" / "configs" / "ovis.py"
try:
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
def fix_vllm_aimv2_issue():
if importlib.util.find_spec("vllm") is None:
@ -165,22 +155,6 @@ def fix_vllm_aimv2_issue():
print(f"Unsloth: Failed patching vLLM with error = {str(e)}")
def fix_vllm_guided_decoding_params():
if importlib.util.find_spec("vllm") is None:
return
# GuidedDecodingParams is renamed to StructuredOutputsParams in vLLM
# https://github.com/vllm-project/vllm/pull/22772/files
# trl still wants to use GuidedDecodingParams. This is a temporary patch till trl updates
import vllm
try:
from vllm.sampling_params import GuidedDecodingParams
except ImportError:
vllm.sampling_params.GuidedDecodingParams = (
vllm.sampling_params.StructuredOutputsParams
)
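The function removed above follows a common backward-compatibility pattern: if an old attribute name no longer imports, point it at its renamed replacement so downstream callers keep working. A generic sketch of that pattern (illustrative names, not a claim about any specific vLLM API):

import importlib.util

def alias_renamed_attribute(module_name, old_name, new_name):
    if importlib.util.find_spec(module_name) is None:
        return  # dependency not installed, nothing to patch
    module = importlib.import_module(module_name)
    if not hasattr(module, old_name) and hasattr(module, new_name):
        # Let legacy code keep importing the old name.
        setattr(module, old_name, getattr(module, new_name))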
def ignore_logger_messages():
# Ignore Environment variable `HF_TOKEN` is set
try:


@ -75,7 +75,7 @@ def _cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
tl.float32
)
@ -162,7 +162,7 @@ def _chunked_cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
tl.float32
)
@ -246,7 +246,7 @@ def _cross_entropy_backward(
else:
dloss = 0.0
x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
# Do logit scaling for Cohere
if DO_LOGIT_SCALING:
@ -277,7 +277,7 @@ def _cross_entropy_backward(
y = y * (1.0 - partial * partial)
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
_cross_entropy_backward = triton.jit(_cross_entropy_backward)
@ -304,7 +304,7 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks: int = div + (mod != 0)
losses = torch.empty(n_rows, dtype=torch.float32, device=device)
losses = torch.empty(n_rows, dtype = torch.float32, device = device)
DO_SOFTCAPPING: bool = bool(logit_softcapping != 0)
DO_LOGIT_SCALING: bool = bool(logit_scaling != 0)
@ -316,7 +316,7 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
if is_cdna():
num_warps = num_warps // 2
logsumexp = torch.empty(n_rows, dtype=torch.float32, device=device)
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)
with torch_gpu_device(device):
_cross_entropy_forward[(n_rows,)](
@ -325,13 +325,13 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
losses,
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
DO_SOFTCAPPING=DO_SOFTCAPPING,
SOFTCAP=logit_softcapping,
DO_LOGIT_SCALING=DO_LOGIT_SCALING,
LOGIT_SCALE=logit_scaling,
num_warps=num_warps,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
LOGIT_SCALE = logit_scaling,
num_warps = num_warps,
)
else:
# For large vocabs > 65336 like Gemma 256K
@ -340,8 +340,8 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
n_rows,
n_chunks,
),
dtype=torch.float32,
device=device,
dtype = torch.float32,
device = device,
)
with torch_gpu_device(device):
@ -356,18 +356,18 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
losses,
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
N_CHUNKS=n_chunks,
BLOCK_SIZE=MAX_FUSED_SIZE,
DO_SOFTCAPPING=DO_SOFTCAPPING,
SOFTCAP=logit_softcapping,
DO_LOGIT_SCALING=DO_LOGIT_SCALING,
LOGIT_SCALE=logit_scaling,
num_warps=32 if not is_cdna() else 16,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
LOGIT_SCALE = logit_scaling,
num_warps = 32 if not is_cdna() else 16,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
logsumexp = torch.logsumexp(logsumexp, dim=1) # Row sum
logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
losses += logsumexp
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
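The chunked path works because logsumexp composes: the logsumexp of a full row equals the logsumexp of its per-chunk logsumexps, which is exactly what the reduction above exploits. A quick standalone check of that identity:

import torch

x = torch.randn(4, 1024)                                   # 4 rows, vocab split into 8 chunks of 128
chunked = torch.logsumexp(x.view(4, 8, 128), dim = -1)     # per-chunk logsumexp, shape (4, 8)
combined = torch.logsumexp(chunked, dim = 1)               # row-wise combination, shape (4,)
print(torch.allclose(combined, torch.logsumexp(x, dim = 1), atol = 1e-5))  # True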
@ -404,13 +404,13 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
DO_SOFTCAPPING=ctx.DO_SOFTCAPPING,
SOFTCAP=ctx.logit_softcapping,
DO_LOGIT_SCALING=ctx.DO_LOGIT_SCALING,
LOGIT_SCALE=ctx.logit_scaling,
num_warps=8,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
LOGIT_SCALE = ctx.logit_scaling,
num_warps = 8,
)
return (
logits,
@ -423,9 +423,9 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
def fast_cross_entropy_loss(
logits,
labels,
logit_softcapping=0,
logit_scaling=0,
n_items=None,
logit_softcapping = 0,
logit_scaling = 0,
n_items = None,
):
"""
Arguments:
@ -455,5 +455,5 @@ if (Version(torch.__version__) < Version("2.4.0")) and not hasattr(
# Patch CE Losses in transformers
def patch_loss_functions(torch_compile=True):
_patch_loss_functions(fast_cross_entropy_loss, torch_compile=torch_compile)
def patch_loss_functions(torch_compile = True):
_patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)


@ -33,7 +33,7 @@ try:
)
_flex_attention = torch.compile(
_flex_attention, dynamic=True, options=torch_compile_options
_flex_attention, dynamic = True, options = torch_compile_options
)
HAS_FLEX_ATTENTION = False
except:
@ -42,7 +42,7 @@ except:
if not HAS_FLEX_ATTENTION:
# Logit softcapping
@torch.compile(fullgraph=True, dynamic=True, options=torch_compile_options)
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.config.num_attention_heads
head_dim = self.head_dim
@ -62,13 +62,13 @@ if not HAS_FLEX_ATTENTION:
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
Q = Q * torch.tensor(s**-0.5, dtype=Q.dtype) # Follow Keras exactly
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch.matmul(Q, K.transpose(2, 3))
A = t * torch.tanh(A / t) # Logit softcapping
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch.nn.functional.softmax(A, dim=-1, dtype=torch.float32).to(Q.dtype)
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch.matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads * head_dim)
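For intuition, the softcapping step above squashes attention scores into the open interval (-t, t): t * tanh(A / t) is close to the identity for |A| much smaller than t and saturates near ±t for large scores. A tiny standalone illustration:

import torch

t = 50.0
scores = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
print(t * torch.tanh(scores / t))  # roughly [-49.97, -9.87, 0.00, 9.87, 49.97]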
@ -92,7 +92,7 @@ else:
return q_idx >= kv_idx
@functools.lru_cache
def sliding_window_masker(size=4096):
def sliding_window_masker(size = 4096):
def sliding_window(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
window_mask = q_idx - kv_idx <= size
@ -101,23 +101,23 @@ else:
return sliding_window
@functools.lru_cache
def create_block_mask(mask, n=128):
def create_block_mask(mask, n = 128):
return _create_block_mask(
mask,
1,
1,
n,
n,
BLOCK_SIZE=128,
_compile=True,
BLOCK_SIZE = 128,
_compile = True,
)
def create_flex_attention_causal_mask(max_seq_length=8192):
def create_flex_attention_causal_mask(max_seq_length = 8192):
causal_mask = create_block_mask(causal_masker, max_seq_length)
return causal_mask
def create_flex_attention_sliding_window_mask(
max_seq_length=8192, sliding_window=4096
max_seq_length = 8192, sliding_window = 4096
):
sliding_masker = sliding_window_masker(sliding_window)
causal_mask = create_block_mask(sliding_masker, max_seq_length)
@ -129,9 +129,9 @@ else:
score_mod = generate_tanh_softcap(t)
return functools.partial(
_flex_attention,
score_mod=score_mod,
scale=scale,
enable_gqa=True,
score_mod = score_mod,
scale = scale,
enable_gqa = True,
)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
@ -140,7 +140,7 @@ else:
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
fx = flex_attention(s, t)
A = fx(query=Q, key=K, value=V, block_mask=causal_mask)
A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads * head_dim)
return A
@ -170,17 +170,17 @@ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len)
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
Q = Q * torch.tensor(s**-0.5, dtype=Q.dtype) # Follow Keras exactly
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch_matmul(Q, K.transpose(2, 3))
# Logit softcapping
A /= t
torch_tanh(A, out=A)
torch_tanh(A, out = A)
A *= t
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32).to(Q.dtype)
A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch_matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads * head_dim)


@ -63,21 +63,21 @@ except:
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
pid_m = tl.program_id(axis = 0)
pid_n = tl.program_id(axis = 1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
x = tl.load(x_ptr + offs, mask = mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
tl.store(y_ptr + offs, y, mask = mask)
def weight_dequant_block(
x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16
x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype = torch.bfloat16
) -> torch.Tensor:
if not x.is_contiguous():
x = x.contiguous()
@ -85,16 +85,16 @@ def weight_dequant_block(
s = s.contiguous()
assert x.dim() == 2 and s.dim() == 2
M, N = x.size()
y = torch.empty_like(x, dtype=dtype)
y = torch.empty_like(x, dtype = dtype)
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_SIZE"]),
triton.cdiv(N, meta["BLOCK_SIZE"]),
)
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)
return y
def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
if s.shape[1] == 1:
# this is row quantized weight, just simple multiplication suffices
if x.shape[0] == s.shape[0]:
@ -108,13 +108,13 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
return y
else:
# this is block quantized weight
return weight_dequant_block(x, s, dtype=dtype)
return weight_dequant_block(x, s, dtype = dtype)
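For reference, the block-quantized branch above multiplies each 128 x 128 tile of the weight by one scale from s. A plain PyTorch sketch of that scaling logic (an illustrative reference using a float stand-in for the FP8 weight, not the Triton kernel used above):

import torch

def weight_dequant_reference(x, s, block_size = 128, dtype = torch.bfloat16):
    M, N = x.shape
    # Expand the per-block scales so each 128 x 128 tile sees its own scale.
    scales = s.repeat_interleave(block_size, dim = 0)[:M]
    scales = scales.repeat_interleave(block_size, dim = 1)[:, :N]
    return (x.to(torch.float32) * scales).to(dtype)

x = torch.randn(256, 200)    # stand-in for a weight already cast to float
s = torch.full((2, 2), 0.5)  # one scale per 128 x 128 block
print(weight_dequant_reference(x, s).shape)  # torch.Size([256, 200])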
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
pid = tl.program_id(axis = 0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.0
@ -134,13 +134,13 @@ def act_quant(
if not x.is_contiguous():
x = x.contiguous()
assert x.shape[-1] % block_size == 0
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
y = torch.empty_like(x, dtype = torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype = torch.float32)
def grid(meta):
return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size)
return y, s
@ -182,7 +182,7 @@ def _w8a8_block_fp8_matmul(
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
pid = tl.program_id(axis = 0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@ -202,10 +202,10 @@ def _w8a8_block_fp8_matmul(
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
a = tl.load(a_ptrs, mask = offs_k[None, :] < K - k * BLOCK_SIZE_K, other = 0.0)
b = tl.load(b_ptrs, mask = offs_k[:, None] < K - k * BLOCK_SIZE_K, other = 0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
@ -227,7 +227,7 @@ def _w8a8_block_fp8_matmul(
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c, mask = c_mask)
def w8a8_block_fp8_matmul_triton(
@ -267,7 +267,7 @@ def w8a8_block_fp8_matmul_triton(
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
C = A.new_empty(C_shape, dtype = output_dtype)
BLOCK_SIZE_M = 128
if M < BLOCK_SIZE_M:
@ -303,10 +303,10 @@ def w8a8_block_fp8_matmul_triton(
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=8,
BLOCK_SIZE_M = BLOCK_SIZE_M,
BLOCK_SIZE_N = BLOCK_SIZE_N,
BLOCK_SIZE_K = BLOCK_SIZE_K,
GROUP_SIZE_M = 8,
)
return C
@ -324,7 +324,7 @@ def torchao_block_matmul(
act_scale.contiguous(),
weight_q.contiguous(),
weight_scale.contiguous(),
block_size=block_size[1],
block_size = block_size[1],
)
return out.to(output_dtype)
@ -372,7 +372,7 @@ class FP8BlockQuantLinear(torch.autograd.Function):
scale,
weight_scale,
block_size,
output_dtype=X.dtype,
output_dtype = X.dtype,
)
ctx.weight = weight
ctx.weight_scale = weight_scale
@ -394,7 +394,7 @@ def fp8_torch_block_quant_forward(X, weight, weight_scale):
class FbgemmFp8Linear_matmul(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, weight_scale, bias=None):
def forward(ctx, x, weight, weight_scale, bias = None):
if weight.shape[0] == weight_scale.shape[0] and (
weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0
):
@ -409,7 +409,7 @@ class FbgemmFp8Linear_matmul(torch.autograd.Function):
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]).contiguous(),
scale_ub=getattr(weight, "input_scale_ub", None),
scale_ub = getattr(weight, "input_scale_ub", None),
)
# moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
@ -423,7 +423,7 @@ class FbgemmFp8Linear_matmul(torch.autograd.Function):
weight_scale = weight_scale.contiguous()
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum=True
x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum = True
)
output = output + bias if bias is not None else output
# Hacky for now, we move the output to the device of x
@ -459,13 +459,13 @@ class FbgemmFp8Linear_matmul(torch.autograd.Function):
@torch_compile
def fbgemm_fp8_linear(X, weight, weight_scale, bias=None):
def fbgemm_fp8_linear(X, weight, weight_scale, bias = None):
return FbgemmFp8Linear_matmul.apply(X, weight, weight_scale, bias)
class FP8_fbgemm_block_linear(torch.autograd.Function):
@staticmethod
def forward(ctx, X, weight, weight_scale, bias=None):
def forward(ctx, X, weight, weight_scale, bias = None):
orig_shape = X.shape
X = X.view(-1, X.shape[-1])
@ -516,7 +516,7 @@ class FP8_fbgemm_block_linear(torch.autograd.Function):
@torch_compile
def fp8_fbgemm_block_linear(X, weight, weight_scale, bias=None):
def fp8_fbgemm_block_linear(X, weight, weight_scale, bias = None):
return FP8_fbgemm_block_linear.apply(X, weight, weight_scale, bias)
@ -525,11 +525,11 @@ def test_has_fbgemm():
# For example RTX 5090 and RTX 4090 do not work
# [TODO] Investigate with TorchAO why FBGEMM fails on consumer GPUs
M, N, K = 128, 128, 128
xq = torch.ones(M, K, dtype=torch.float8_e4m3fn, device="cuda")
xq = torch.ones(M, K, dtype = torch.float8_e4m3fn, device = "cuda")
wq = xq
M, K = xq.shape
N, _ = wq.shape
block_scale = torch.ones(M // 128, K // 128, dtype=torch.float32, device="cuda")
block_scale = torch.ones(M // 128, K // 128, dtype = torch.float32, device = "cuda")
has_fbgemm = False
try:
out = torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, block_scale, block_scale)
@ -575,7 +575,7 @@ except:
@torch_compile
def fp8_linear(X, weight, weight_scale, bias=None):
def fp8_linear(X, weight, weight_scale, bias = None):
if weight_scale.ndim == 2 and weight_scale.shape[1] > 1:
# This is block quantized FP8 matmul
out = fp8_block_quant_linear(X, weight, weight_scale)
@ -585,7 +585,7 @@ def fp8_linear(X, weight, weight_scale, bias=None):
return out
def module_forward_patch(forward_function, scale_attr="weight_scale"):
def module_forward_patch(forward_function, scale_attr = "weight_scale"):
def patched_forward(self, X):
return forward_function(X, self.weight, getattr(self, scale_attr))


@ -49,22 +49,22 @@ def _exact_forward_kernel(
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
# h = f * up
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
f_row = f_row.to(g_row.dtype) # Exact copy from HF
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask=mask)
tl.store(h + offsets, h_row, mask = mask)
def geglu_exact_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=device)
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
with torch_gpu_device(device):
_exact_forward_kernel[grid](
@ -72,8 +72,8 @@ def geglu_exact_forward_kernel(gate, up):
up,
out,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return out
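As a sanity reference, the exact GEGLU forward above computes f = 0.5 * e * (1 + erf(e / sqrt(2))), which is exact GELU applied to the gate, then multiplies by the up projection. A minimal eager PyTorch sketch of the same math (illustrative, not the Triton kernel):

import torch

def geglu_exact_reference(gate, up):
    f = 0.5 * gate * (1.0 + torch.erf(gate * (2.0 ** -0.5)))  # exact GELU(gate)
    return f * up

gate, up = torch.randn(2, 3, 8), torch.randn(2, 3, 8)
print(geglu_exact_reference(gate, up).shape)  # torch.Size([2, 3, 8])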
@ -107,9 +107,9 @@ def _exact_backward_kernel(
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
# Break e_row away for re-use
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
@ -132,9 +132,9 @@ def _exact_backward_kernel(
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask=mask) # h = f * g
tl.store(e + offsets, df_row, mask=mask) # df = DW * f
tl.store(g + offsets, de_row, mask=mask) # de
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
def geglu_exact_backward_kernel(DW, e, g):
@ -147,8 +147,8 @@ def geglu_exact_backward_kernel(DW, e, g):
e,
g,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return DW, e, g
@ -177,8 +177,8 @@ def _approx_forward_kernel(
# h = f * up
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
f_row = (
0.5 * e_row * (triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0)
@ -187,14 +187,14 @@ def _approx_forward_kernel(
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask=mask)
tl.store(h + offsets, h_row, mask = mask)
def geglu_approx_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=device)
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
with torch_gpu_device(device):
_approx_forward_kernel[grid](
@ -202,8 +202,8 @@ def geglu_approx_forward_kernel(gate, up):
up,
out,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return out
@ -241,9 +241,9 @@ def _approx_backward_kernel(
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
# See https://www.desmos.com/calculator/nqprfoni6x
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
@ -269,9 +269,9 @@ def _approx_backward_kernel(
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask=mask) # h = f * g
tl.store(e + offsets, df_row, mask=mask) # df = DW * f
tl.store(g + offsets, de_row, mask=mask) # de
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
def geglu_approx_backward_kernel(DW, e, g):
@ -284,7 +284,7 @@ def geglu_approx_backward_kernel(DW, e, g):
e,
g,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return DW, e, g
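The exact GEGLU kernels above implement f = 1/2 * e * (1 + erf(e / sqrt(2))), h = f * up, and the derivative df/de = 1/2 * (1 + erf(e / sqrt(2))) + e * exp(-e^2 / 2) / sqrt(2*pi). A plain PyTorch reference of the same math, handy for checking the Triton kernels numerically (function names are mine, not from the repo):

import torch

def geglu_exact_reference(e, up):
    # f = 1/2 * e * (1 + erf(e / sqrt(2))); h = f * up, matching the kernel above.
    f = 0.5 * e.float() * (torch.erf(e.float() * 0.7071067811865476) + 1.0)
    return f.to(up.dtype) * up

def geglu_exact_grads(dh, e, up):
    e32 = e.float()
    cdf = 0.5 * (torch.erf(e32 * 0.7071067811865476) + 1.0)
    f = e32 * cdf
    # df/de = 1/2 * (1 + erf(e / sqrt(2))) + e * exp(-e^2 / 2) / sqrt(2*pi)
    dfde = cdf + e32 * torch.exp(-0.5 * e32 * e32) * 0.3989422804014327
    d_up = dh * f.to(dh.dtype)            # gradient w.r.t. the up projection
    d_gate = dh * up * dfde.to(dh.dtype)  # gradient w.r.t. the gate projection
    return d_gate, d_up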

View file

@ -47,19 +47,19 @@ def layernorm_forward(
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
# are in float32!
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask=mask, other=0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
mean_X = tl.sum(X_row, axis=0) / n_cols
mean_X = tl.sum(X_row, axis = 0) / n_cols
# (X[0] - mean) == -mean so we need to mask it out
XX = tl.where(mask, X_row - mean_X, 0)
row_var = tl.sum(XX * XX, axis=0) / n_cols
row_var = tl.sum(XX * XX, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
tl.store(mu, mean_X)
output = (XX * inv_var) * W_row + b_row
tl.store(Y + col_offsets, output, mask=mask)
tl.store(Y + col_offsets, output, mask = mask)
@triton.jit
@ -88,10 +88,10 @@ def layernorm_backward(
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
# are in float32!
dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask=mask, other=0).to(tl.float32)
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
inv_var = tl.load(r).to(tl.float32)
mean = tl.load(mu).to(tl.float32)
@ -99,11 +99,11 @@ def layernorm_backward(
dY_W = dY_row * W_row
dX_row = (
dY_W
- tl.sum(dY_W, axis=0) / n_cols
- normed * tl.sum(dY_W * normed, axis=0) / n_cols
- tl.sum(dY_W, axis = 0) / n_cols
- normed * tl.sum(dY_W * normed, axis = 0) / n_cols
)
dX_row = dX_row * inv_var
tl.store(dY + col_offsets, dX_row, mask=mask)
tl.store(dY + col_offsets, dX_row, mask = mask)
class Fast_Layernorm(torch.autograd.Function):
@ -115,9 +115,9 @@ class Fast_Layernorm(torch.autograd.Function):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
device = X.device
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
r = torch.empty(n_rows, dtype=torch.float32, device=device)
mu = torch.empty(n_rows, dtype=torch.float32, device=device)
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
r = torch.empty(n_rows, dtype = torch.float32, device = device)
mu = torch.empty(n_rows, dtype = torch.float32, device = device)
with torch_gpu_device(device):
layernorm_forward[(n_rows,)](
@ -131,8 +131,8 @@ class Fast_Layernorm(torch.autograd.Function):
mu,
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
@ -160,8 +160,8 @@ class Fast_Layernorm(torch.autograd.Function):
mu,
n_cols,
ctx.eps,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dX = dY.view(*shape)
return dX, None, None, None, None
@ -181,26 +181,26 @@ def fast_layernorm(layernorm, X):
def test_layernorm(
dim=1024,
eps=1e-5,
dtype=torch.float16,
bsz=21,
random_state=3407,
seqlen=3341,
dim = 1024,
eps = 1e-5,
dtype = torch.float16,
bsz = 21,
random_state = 3407,
seqlen = 3341,
):
from torch.nn import LayerNorm
layernorm = LayerNorm((dim,), eps=eps, device="cuda", dtype=dtype)
layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
torch.cuda.manual_seed(random_state)
torch.manual_seed(random_state)
torch.nn.init.uniform_(layernorm.weight)
torch.nn.init.uniform_(layernorm.bias)
X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda")
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
XX = X.clone()
X.requires_grad_(True)
XX.requires_grad_(True)
Y = layernorm(X)
YY = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda", requires_grad=True)
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
Y.backward(YY)
correct_grad = X.grad.clone()
# from unsloth.kernels import fast_layernorm
@ -212,14 +212,14 @@ def test_layernorm(
def testing_suite_layernorm():
for dim in [512, 1024, 2048]:
for dtype in [torch.float16, torch.bfloat16]:
with torch.autocast(device_type="cuda", dtype=dtype):
with torch.autocast(device_type = "cuda", dtype = dtype):
for seqlen in [3341, 2048, 349]:
for random_state in [3407, 42]:
test_layernorm(
dim=dim,
eps=1e-5,
dtype=dtype,
bsz=21,
random_state=random_state,
seqlen=seqlen,
dim = dim,
eps = 1e-5,
dtype = dtype,
bsz = 21,
random_state = random_state,
seqlen = seqlen,
)
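For reference, the layernorm kernel above computes the row mean, the biased variance and rsqrt in float32 before applying weight and bias. A minimal eager sketch plus a quick CPU check against torch.nn.LayerNorm (the Triton path itself requires a GPU; the helper name is mine):

import torch

def layernorm_reference(X, weight, bias, eps=1e-5):
    # Row-wise float32 math matching the kernel above.
    X32 = X.float()
    mean = X32.mean(dim=-1, keepdim=True)
    var = (X32 - mean).square().mean(dim=-1, keepdim=True)
    out = (X32 - mean) * torch.rsqrt(var + eps) * weight.float() + bias.float()
    return out.to(X.dtype)

X = torch.randn(4, 1024)
ln = torch.nn.LayerNorm(1024, eps=1e-5)
torch.testing.assert_close(layernorm_reference(X, ln.weight, ln.bias), ln(X), atol=1e-5, rtol=1e-5)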

View file

@ -44,7 +44,7 @@ def post_process_results(
dtype: torch.dtype,
autotune: bool,
):
df = KernelResult.to_dataframe(results, sort_by="speedup")
df = KernelResult.to_dataframe(results, sort_by = "speedup")
df = create_merged_results(df, mode, seqlen, dtype, autotune)
return df
@ -63,16 +63,16 @@ def save_results(
if not os.path.exists(save_dir):
os.makedirs(save_dir)
print(f"Saving results to {save_path}")
df.to_csv(save_path, index=False)
df.to_csv(save_path, index = False)
def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y: bool):
block_m_range = power_of_two_range(args.BLOCK_SIZE_M[0], args.BLOCK_SIZE_M[1])
block_n_range = power_of_two_range(args.BLOCK_SIZE_N[0], args.BLOCK_SIZE_N[1])
block_k_range = power_of_two_range(args.BLOCK_SIZE_K[0], args.BLOCK_SIZE_K[1])
num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step=2)
num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step = 2)
num_stages_range = multiples_of_range(
args.num_stages[0], args.num_stages[1], step=1
args.num_stages[0], args.num_stages[1], step = 1
)
mode = args.mode
@ -96,39 +96,39 @@ def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y:
):
if mode == "forward":
kernel_config = KernelConfigForward(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
num_warps=num_warps,
num_stages=num_stages,
use_tma_load_w=tma_load_a,
use_tma_load_x=tma_load_b,
permute_x=permute_x,
permute_y=permute_y,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
num_warps = num_warps,
num_stages = num_stages,
use_tma_load_w = tma_load_a,
use_tma_load_x = tma_load_b,
permute_x = permute_x,
permute_y = permute_y,
)
elif mode == "dW":
kernel_config = KernelConfigBackward_dW(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
num_warps=num_warps,
num_stages=num_stages,
use_tma_load_dy=tma_load_a,
use_tma_load_x=tma_load_b,
permute_x=permute_x,
permute_y=permute_y,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
num_warps = num_warps,
num_stages = num_stages,
use_tma_load_dy = tma_load_a,
use_tma_load_x = tma_load_b,
permute_x = permute_x,
permute_y = permute_y,
)
elif mode == "dX":
kernel_config = KernelConfigBackward_dX(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
num_warps=num_warps,
num_stages=num_stages,
use_tma_load_dy=tma_load_a,
use_tma_load_w=tma_load_b,
permute_x=permute_x,
permute_y=permute_y,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
num_warps = num_warps,
num_stages = num_stages,
use_tma_load_dy = tma_load_a,
use_tma_load_w = tma_load_b,
permute_x = permute_x,
permute_y = permute_y,
)
else:
raise ValueError(f"Invalid mode: {mode}")
@ -161,7 +161,7 @@ def power_of_two_range(start, end):
return [2**i for i in range(int(start), int(end) + 1)]
def multiples_of_range(start, end, step=1):
def multiples_of_range(start, end, step = 1):
return list(range(start, end + step, step))
@ -221,8 +221,8 @@ def postprocess_autotune_results(autotuner, mode, ref_time, fused_time, results_
print(f"{mode} {key}: {value.all_kwargs()}")
save_autotune_results(
autotuner.cache,
mode=mode,
ref_time=ref_time,
fused_time=fused_time,
results_dir=results_dir,
mode = mode,
ref_time = ref_time,
fused_time = fused_time,
results_dir = results_dir,
)
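The tuning script sweeps the Cartesian product of these ranges. A small illustration of how the two range helpers (copied from the diff above) expand a pair of bounds into a config grid; the concrete bounds below are invented:

import itertools

def power_of_two_range(start, end):
    return [2 ** i for i in range(int(start), int(end) + 1)]

def multiples_of_range(start, end, step=1):
    return list(range(start, end + step, step))

block_m = power_of_two_range(5, 7)       # [32, 64, 128]
num_warps = multiples_of_range(4, 8, 2)  # [4, 6, 8]
grid = list(itertools.product(block_m, num_warps))
print(len(grid), grid[:3])               # 9 combinations in total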

View file

@ -60,7 +60,7 @@ def get_per_device_per_stream_alloc_fn(device):
or _per_stream_tensors[stream].numel() < size
):
_per_stream_tensors[stream] = torch.empty(
size, device=device, dtype=torch.int8
size, device = device, dtype = torch.int8
)
_per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
return _per_stream_tensors[stream]
@ -160,7 +160,7 @@ def grouped_gemm_forward(
if use_tma or autotune:
def alloc_fn(size: int, alignment: int, stream: int):
return torch.empty(size, device="cuda", dtype=torch.int8)
return torch.empty(size, device = "cuda", dtype = torch.int8)
triton.set_allocator(alloc_fn)
@ -211,7 +211,7 @@ def grouped_gemm_forward(
f"DEBUG::GROUPED_GEMM {topk_weights.tolist()} {gather_indices.tolist()}"
)
y = torch.empty((total_tokens, N), device=X.device, dtype=X.dtype)
y = torch.empty((total_tokens, N), device = X.device, dtype = X.dtype)
if total_tokens == 0 or N == 0:
return y
@ -357,7 +357,7 @@ def grouped_gemm_dX(
def alloc_fn(size: int, alignment: int, stream: int):
# print(f"DEBUG::GROUPED_GEMM alloc_fn {size=} {alignment=} {stream=}")
return torch.empty(size, device="cuda", dtype=torch.int8)
return torch.empty(size, device = "cuda", dtype = torch.int8)
triton.set_allocator(alloc_fn)
@ -383,7 +383,7 @@ def grouped_gemm_dX(
# Note that the output shape is [NUM_TOKENS * TOPK, K] even when `permute_x` is True since we need to accumulate gradients across all experts chosen by the token.
# This is done in a post-processing reduction step.
output_shape = (total_tokens, K)
dX = torch.zeros(output_shape, device=dY.device, dtype=dY.dtype)
dX = torch.zeros(output_shape, device = dY.device, dtype = dY.dtype)
NUM_SMS = torch.cuda.get_device_properties(
"cuda"
@ -512,7 +512,7 @@ def grouped_gemm_dW(
if use_tma or autotune:
def alloc_fn(size: int, alignment: int, stream: int):
return torch.empty(size, device="cuda", dtype=torch.int8)
return torch.empty(size, device = "cuda", dtype = torch.int8)
triton.set_allocator(alloc_fn)
@ -538,7 +538,7 @@ def grouped_gemm_dW(
assert M_grad == total_tokens, f"dY M ({M_grad}) != total_tokens ({total_tokens})"
dW = torch.zeros((num_experts, N, K), device=X.device, dtype=X.dtype)
dW = torch.zeros((num_experts, N, K), device = X.device, dtype = X.dtype)
if not autotune:
BLOCK_SIZE_M = min(total_tokens, BLOCK_SIZE_M)
@ -663,17 +663,17 @@ class GroupedGemm(torch.autograd.Function):
fwd_config["use_tma_store"] = kernel_config_fwd.use_tma_store
return grouped_gemm_forward(
X=X,
W=W,
topk=topk,
m_sizes=m_sizes,
gather_indices=gather_indices,
topk_weights=topk_weights,
permute_x=permute_x,
permute_y=permute_y,
fuse_mul_post=fuse_mul_post,
X = X,
W = W,
topk = topk,
m_sizes = m_sizes,
gather_indices = gather_indices,
topk_weights = topk_weights,
permute_x = permute_x,
permute_y = permute_y,
fuse_mul_post = fuse_mul_post,
# Autotune -- this will override the manual kernel config if true
autotune=autotune,
autotune = autotune,
# Manual kernel config
**fwd_config,
)
@ -719,15 +719,15 @@ class GroupedGemm(torch.autograd.Function):
bwd_dW_config["num_stages"] = kernel_config_bwd_dW.num_stages
dW = grouped_gemm_dW(
X=X,
dY=dY,
m_sizes=m_sizes,
gather_indices=gather_indices,
topk=topk,
permute_x=permute_x,
permute_y=permute_y,
X = X,
dY = dY,
m_sizes = m_sizes,
gather_indices = gather_indices,
topk = topk,
permute_x = permute_x,
permute_y = permute_y,
# Autotune -- this will override the manual kernel config if true
autotune=autotune,
autotune = autotune,
# Manual kernel config
**bwd_dW_config,
)
@ -747,21 +747,21 @@ class GroupedGemm(torch.autograd.Function):
bwd_dX_config["num_stages"] = kernel_config_bwd_dX.num_stages
dX = grouped_gemm_dX(
dY=dY,
W=W,
m_sizes=m_sizes,
gather_indices=gather_indices,
topk=topk,
permute_x=permute_x,
permute_y=permute_y,
dY = dY,
W = W,
m_sizes = m_sizes,
gather_indices = gather_indices,
topk = topk,
permute_x = permute_x,
permute_y = permute_y,
# Autotune -- this will override the manual kernel config if true
autotune=autotune,
autotune = autotune,
# Manual kernel config
**bwd_dX_config,
)
if topk > 1 and permute_x:
dX = dX.view(X.shape[0], topk, -1).sum(dim=1)
dX = dX.view(X.shape[0], topk, -1).sum(dim = 1)
else:
dX = None
@ -866,8 +866,8 @@ def grouped_gemm(
gather_indices: torch.Tensor = None,
permute_x: bool = False,
permute_y: bool = False,
topk_weights=None,
fuse_mul_post=False,
topk_weights = None,
fuse_mul_post = False,
kernel_config_fwd: KernelConfigForward = None,
kernel_config_bwd_dX: KernelConfigBackward_dX = None,
kernel_config_bwd_dW: KernelConfigBackward_dW = None,
@ -908,31 +908,31 @@ def grouped_gemm(
check_valid_config_fwd(
permute_x,
permute_y,
use_tma_load_x=kernel_config_fwd.use_tma_load_x,
use_tma_load_w=kernel_config_fwd.use_tma_load_w,
use_tma_store=kernel_config_fwd.use_tma_store,
fuse_mul_post=fuse_mul_post,
is_first_gemm=is_first_gemm,
use_tma_load_x = kernel_config_fwd.use_tma_load_x,
use_tma_load_w = kernel_config_fwd.use_tma_load_w,
use_tma_store = kernel_config_fwd.use_tma_store,
fuse_mul_post = fuse_mul_post,
is_first_gemm = is_first_gemm,
)
if kernel_config_bwd_dW is not None and not dX_only:
check_valid_config_bwd_dW(
permute_x,
permute_y,
use_tma_load_dY=kernel_config_bwd_dW.use_tma_load_dy,
use_tma_load_x=kernel_config_bwd_dW.use_tma_load_x,
use_tma_store=kernel_config_bwd_dW.use_tma_store,
fuse_mul_post=fuse_mul_post,
is_first_gemm=is_first_gemm,
use_tma_load_dY = kernel_config_bwd_dW.use_tma_load_dy,
use_tma_load_x = kernel_config_bwd_dW.use_tma_load_x,
use_tma_store = kernel_config_bwd_dW.use_tma_store,
fuse_mul_post = fuse_mul_post,
is_first_gemm = is_first_gemm,
)
if kernel_config_bwd_dX is not None and not dW_only:
check_valid_config_bwd_dX(
permute_x,
permute_y,
use_tma_load_dY=kernel_config_bwd_dX.use_tma_load_dy,
use_tma_load_w=kernel_config_bwd_dX.use_tma_load_w,
use_tma_store=kernel_config_bwd_dX.use_tma_store,
fuse_mul_post=fuse_mul_post,
is_first_gemm=is_first_gemm,
use_tma_load_dY = kernel_config_bwd_dX.use_tma_load_dy,
use_tma_load_w = kernel_config_bwd_dX.use_tma_load_w,
use_tma_store = kernel_config_bwd_dX.use_tma_store,
fuse_mul_post = fuse_mul_post,
is_first_gemm = is_first_gemm,
)
if permute_x or permute_y:
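grouped_gemm expects the routing bookkeeping, m_sizes (token count per expert) and gather_indices (expert-sorted ordering of the flattened token*topk axis), to be precomputed. A sketch of one way to build them from the router's top-k output; the helper name and the exact index convention are illustrative, not necessarily the repo's:

import torch

def build_routing_metadata(selected_experts, num_experts):
    # selected_experts: (num_tokens, topk) expert ids from torch.topk on the router.
    flat = selected_experts.flatten()
    gather_indices = torch.argsort(flat, stable=True)  # rows sorted by expert id
    m_sizes = torch.bincount(flat, minlength=num_experts)
    return m_sizes, gather_indices

selected = torch.tensor([[0, 2], [1, 2], [0, 1]])  # 3 tokens, topk=2, 3 experts
m_sizes, gather_indices = build_routing_metadata(selected, num_experts=3)
print(m_sizes.tolist())         # [2, 2, 2] tokens per expert
print(gather_indices.tolist())  # original slot of each expert-sorted row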

View file

@ -68,25 +68,25 @@ def _grouped_gemm_forward_kernel(
if USE_TMA_LOAD_X:
x_desc = tl._experimental_make_tensor_descriptor(
x_ptr,
shape=[TOTAL_TOKENS, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
shape = [TOTAL_TOKENS, K],
strides = [K, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
)
if USE_TMA_LOAD_W:
expert_stride = N * K
w_desc = tl._experimental_make_tensor_descriptor(
w_ptr,
shape=[NUM_EXPERTS, N, K],
strides=[expert_stride, K, 1],
block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
shape = [NUM_EXPERTS, N, K],
strides = [expert_stride, K, 1],
block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
)
m_end = 0
processed_tiles = 0
m_block_range = tl.arange(0, BLOCK_SIZE_M)
for expert_idx in tl.range(NUM_EXPERTS, flatten=FLATTEN):
for expert_idx in tl.range(NUM_EXPERTS, flatten = FLATTEN):
m_start = m_end
m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
m_end = m_start + m_size
@ -102,9 +102,9 @@ def _grouped_gemm_forward_kernel(
if USE_TMA_STORE:
y_desc = tl._experimental_make_tensor_descriptor(
y_ptr, # + m_start * N,
shape=[m_end, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
shape = [m_end, N],
strides = [N, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
)
# Process tiles for this expert
@ -127,7 +127,7 @@ def _grouped_gemm_forward_kernel(
)
expert_token_idx = tl.load(
gather_indices_ptr + indices_to_gather,
mask=indices_to_gather < TOTAL_TOKENS,
mask = indices_to_gather < TOTAL_TOKENS,
)
expert_token_offsets = expert_token_idx[:, None]
@ -178,7 +178,7 @@ def _grouped_gemm_forward_kernel(
if SHOULD_FUSE_MUL:
topk_load_idx = expert_token_offsets
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = acc_dtype)
offs_k = tl.arange(0, BLOCK_SIZE_K)
@ -194,19 +194,19 @@ def _grouped_gemm_forward_kernel(
for k_offset in range(0, K, BLOCK_SIZE_K):
if not USE_TMA_LOAD_X:
x = tl.load(x_ptrs, mask=row_mask)
x = tl.load(x_ptrs, mask = row_mask)
else:
x = x_desc.load([m_start + off_am, k_offset])
if FUSE_MUL_PRE:
# Check for correct broadcasting
topk_weights = tl.load(
topk_weights_ptr + topk_load_idx, mask=row_mask
topk_weights_ptr + topk_load_idx, mask = row_mask
)
x *= topk_weights.to(x.dtype)
if not USE_TMA_LOAD_W:
w = tl.load(w_ptrs, mask=offs_bn[:, None] < N)
w = tl.load(w_ptrs, mask = offs_bn[:, None] < N)
else:
w = w_desc.load(
[expert_idx, tile_n_idx * BLOCK_SIZE_N, k_offset]
@ -228,7 +228,7 @@ def _grouped_gemm_forward_kernel(
if FUSE_MUL_POST:
# Check for correct broadcasting
topk_weights = tl.load(
topk_weights_ptr + topk_load_idx, mask=row_mask
topk_weights_ptr + topk_load_idx, mask = row_mask
)
y *= topk_weights.to(output_dtype)
@ -243,7 +243,7 @@ def _grouped_gemm_forward_kernel(
tl.store(
y_ptr + store_idx + offs_bn[None, :],
y,
mask=store_mask,
mask = store_mask,
)
tidx += NUM_SMS
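A rough CPU sketch of the tile schedule implied by the loop above: each expert owns a contiguous run of m_sizes[expert] rows, split into BLOCK_SIZE_M x BLOCK_SIZE_N tiles, and tiles are striped across programs in steps of NUM_SMS. Toy numbers; the real kernel's TMA and masking details are omitted:

import math

def schedule_tiles(m_sizes, N, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, NUM_SMS=4):
    assignments = {sm: [] for sm in range(NUM_SMS)}
    processed_tiles = 0
    m_start = 0
    for expert, m_size in enumerate(m_sizes):
        num_tiles = math.ceil(m_size / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N)
        for t in range(num_tiles):
            assignments[(processed_tiles + t) % NUM_SMS].append((expert, m_start, t))
        processed_tiles += num_tiles
        m_start += m_size
    return assignments

print(schedule_tiles([40, 0, 100], N=64))  # experts with zero tokens contribute no tiles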
@ -251,9 +251,9 @@ def _grouped_gemm_forward_kernel(
_autotuned_grouped_gemm_forward_kernel = triton.autotune(
configs=get_forward_configs(),
prune_configs_by={"early_config_prune": prune_kernel_configs_fwd},
key=[
configs = get_forward_configs(),
prune_configs_by = {"early_config_prune": prune_kernel_configs_fwd},
key = [
"NUM_EXPERTS",
"NUM_TOKENS",
"N",

View file

@ -109,9 +109,9 @@ class KernelResult:
def to_dict(self):
return OrderedDict(
**asdict(self.kernel_config),
torch_time=self.torch_time,
triton_time=self.triton_time,
speedup=self.speedup,
torch_time = self.torch_time,
triton_time = self.triton_time,
speedup = self.speedup,
)
@staticmethod
@ -119,7 +119,7 @@ class KernelResult:
results: list["KernelResult"], sort_by: str = "speedup", ascending: bool = False
):
df = pd.DataFrame([result.to_dict() for result in results])
df = df.sort_values(by=sort_by, ascending=ascending)
df = df.sort_values(by = sort_by, ascending = ascending)
return df
@staticmethod
@ -130,7 +130,7 @@ class KernelResult:
filename: str = "results.csv",
):
df = KernelResult.to_dataframe(results, sort_by, ascending)
df.to_csv(filename, index=False)
df.to_csv(filename, index = False)
@staticmethod
def print_table(
@ -140,17 +140,17 @@ class KernelResult:
num_results: int = 10,
):
df = KernelResult.to_dataframe(results, sort_by, ascending)
print(df.head(num_results).to_string(index=False))
print(df.head(num_results).to_string(index = False))
def get_kernel_configs(
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
num_warps=DEFAULT_NUM_WARPS,
num_stages=DEFAULT_NUM_STAGES,
use_tma_loads=BOOLS,
fuse_permute=BOOLS,
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
num_warps = DEFAULT_NUM_WARPS,
num_stages = DEFAULT_NUM_STAGES,
use_tma_loads = BOOLS,
fuse_permute = BOOLS,
):
kernel_configs_fwd = []
kernel_configs_backward_dW = []
@ -160,44 +160,44 @@ def get_kernel_configs(
):
kernel_configs_fwd.append(
KernelConfigForward(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
num_warps=w,
num_stages=s,
use_tma_load_x=use_tma_load,
use_tma_load_w=use_tma_load,
use_tma_store=False,
permute_x=permute,
permute_y=permute,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
num_warps = w,
num_stages = s,
use_tma_load_x = use_tma_load,
use_tma_load_w = use_tma_load,
use_tma_store = False,
permute_x = permute,
permute_y = permute,
)
)
kernel_configs_backward_dW.append(
KernelConfigBackward_dW(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
num_warps=w,
num_stages=s,
use_tma_load_dy=use_tma_load,
use_tma_load_x=use_tma_load,
use_tma_store=False,
permute_x=permute,
permute_y=permute,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
num_warps = w,
num_stages = s,
use_tma_load_dy = use_tma_load,
use_tma_load_x = use_tma_load,
use_tma_store = False,
permute_x = permute,
permute_y = permute,
)
)
kernel_configs_backward_dX.append(
KernelConfigBackward_dX(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
num_warps=w,
num_stages=s,
use_tma_load_dy=use_tma_load,
use_tma_load_w=use_tma_load,
use_tma_store=False,
permute_x=permute,
permute_y=permute,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
num_warps = w,
num_stages = s,
use_tma_load_dy = use_tma_load,
use_tma_load_w = use_tma_load,
use_tma_store = False,
permute_x = permute,
permute_y = permute,
)
)
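KernelResult rows carry torch_time, triton_time and speedup and are sorted before printing. A toy pandas example of the same post-processing; the timing numbers are invented purely for illustration:

import pandas as pd

rows = [
    {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "torch_time": 1.20, "triton_time": 0.40},
    {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "torch_time": 1.20, "triton_time": 0.80},
]
df = pd.DataFrame(rows)
df["speedup"] = df["torch_time"] / df["triton_time"]
df = df.sort_values(by="speedup", ascending=False)
print(df.head(10).to_string(index=False))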

View file

@ -78,8 +78,8 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
self.gate = torch.nn.Parameter(gate)
# experts
self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad=True)
self.down_proj = torch.nn.Parameter(down_proj, requires_grad=True)
self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad = True)
self.down_proj = torch.nn.Parameter(down_proj, requires_grad = True)
self.act_fn = ACT2FN[config.hidden_act]
@staticmethod
@ -90,17 +90,17 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
gate = moe_block.gate.weight.data
gate_proj = torch.stack(
[moe_block.experts[i].gate_proj.weight.data for i in range(num_experts)],
dim=0,
dim = 0,
)
up_proj = torch.stack(
[moe_block.experts[i].up_proj.weight.data for i in range(num_experts)],
dim=0,
dim = 0,
)
down_proj = torch.stack(
[moe_block.experts[i].down_proj.weight.data for i in range(num_experts)],
dim=0,
dim = 0,
)
gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
gate_up_proj = torch.cat([gate_proj, up_proj], dim = 1)
return gate, gate_up_proj, down_proj
@classmethod
@ -117,7 +117,7 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
moe_block.experts[i].gate_proj.weight.data,
moe_block.experts[i].up_proj.weight.data,
],
dim=0,
dim = 0,
)
)
assert self.down_proj[i].equal(moe_block.experts[i].down_proj.weight.data)
@ -132,12 +132,12 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
# router_logits: (batch * sequence_length, n_experts)
router_logits = torch.nn.functional.linear(hidden_states, self.gate)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
routing_weights, self.top_k, dim = -1
)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
@ -177,13 +177,13 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
# Start expert computation
first_gemm = torch_grouped_gemm(
X=hidden_states, W=self.gate_up_proj, m_sizes=token_counts_by_expert
X = hidden_states, W = self.gate_up_proj, m_sizes = token_counts_by_expert
)
assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
intermediate = self.act_and_mul(first_gemm)
assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
second_gemm = torch_grouped_gemm(
X=intermediate, W=self.down_proj, m_sizes=token_counts_by_expert
X = intermediate, W = self.down_proj, m_sizes = token_counts_by_expert
)
assert second_gemm.shape == (total_tokens, hidden_dim)
@ -197,19 +197,19 @@ class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
* routing_weights[..., None]
)
hidden_states = hidden_states.sum(dim=1)
hidden_states = hidden_states.sum(dim = 1)
assert hidden_states.shape == (num_tokens, hidden_dim)
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
return GroupedGEMMResult(
token_counts_by_expert=token_counts_by_expert,
gather_indices=gather_indices,
topk_weights=routing_weights,
first_gemm=first_gemm,
intermediate=intermediate,
second_gemm=second_gemm,
hidden_states_unpermute=hidden_states_unpermute,
hidden_states=hidden_states,
token_counts_by_expert = token_counts_by_expert,
gather_indices = gather_indices,
topk_weights = routing_weights,
first_gemm = first_gemm,
intermediate = intermediate,
second_gemm = second_gemm,
hidden_states_unpermute = hidden_states_unpermute,
hidden_states = hidden_states,
), router_logits
@ -267,14 +267,14 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
gate,
gate_up_proj,
down_proj,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
dW_only=dW_only,
dX_only=dX_only,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
dW_only = dW_only,
dX_only = dX_only,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -299,37 +299,37 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
hidden_states = permute(hidden_states, gather_indices, self.top_k)
# Start expert computation
hidden_states = grouped_gemm(
X=hidden_states,
W=self.gate_up_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=self.permute_x,
permute_y=False, # output of first grouped gemm should never be permuted
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=True,
dW_only=self.dW_only,
dX_only=self.dX_only,
X = hidden_states,
W = self.gate_up_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = self.permute_x,
permute_y = False, # output of first grouped gemm should never be permuted
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = True,
dW_only = self.dW_only,
dX_only = self.dX_only,
)
hidden_states = self.act_and_mul(hidden_states)
hidden_states = grouped_gemm(
X=hidden_states,
W=self.down_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=False,
permute_y=self.permute_y,
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=False,
dW_only=self.dW_only,
dX_only=self.dX_only,
X = hidden_states,
W = self.down_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = False,
permute_y = self.permute_y,
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = False,
dW_only = self.dW_only,
dX_only = self.dX_only,
)
# Post-processing
@ -342,7 +342,7 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
hidden_states.view(num_tokens, self.top_k, hidden_dim)
* routing_weights[..., None]
)
hidden_states = hidden_states.sum(dim=1)
hidden_states = hidden_states.sum(dim = 1)
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
return hidden_states, router_logits
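The routing step used by both blocks above (softmax over router logits, top-k selection, optional renormalization) in isolation, as a small runnable sketch; the helper name is mine:

import torch
import torch.nn.functional as F

def route_tokens(hidden_states, gate_weight, top_k, norm_topk_prob=True):
    # hidden_states: (num_tokens, hidden); gate_weight: (num_experts, hidden)
    router_logits = F.linear(hidden_states, gate_weight)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights.to(hidden_states.dtype), selected_experts

x = torch.randn(5, 16)
w = torch.randn(8, 16)
weights, experts = route_tokens(x, w, top_k=2)
print(weights.shape, experts.shape)  # (5, 2) weights per token, (5, 2) expert ids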

View file

@ -18,7 +18,7 @@ from grouped_gemm.kernels.tuning import (
)
def print_delimiter(char="-", length=80):
def print_delimiter(char = "-", length = 80):
print(char * length)
@ -29,28 +29,28 @@ def delimiter_context():
print_delimiter()
def make_inputs(M, N, K, E, topk, dtype, requires_grad=False):
def make_inputs(M, N, K, E, topk, dtype, requires_grad = False):
X1 = (
torch.randn((M, K), device="cuda", dtype=dtype, requires_grad=requires_grad)
torch.randn((M, K), device = "cuda", dtype = dtype, requires_grad = requires_grad)
/ 10
)
X2 = (
torch.randn(
(M * topk, N), device="cuda", dtype=dtype, requires_grad=requires_grad
(M * topk, N), device = "cuda", dtype = dtype, requires_grad = requires_grad
)
/ 10
)
W1 = (
torch.randn(
(E, 2 * N, K), device="cuda", dtype=dtype, requires_grad=requires_grad
(E, 2 * N, K), device = "cuda", dtype = dtype, requires_grad = requires_grad
)
/ 10
)
W2 = (
torch.randn((E, K, N), device="cuda", dtype=dtype, requires_grad=requires_grad)
torch.randn((E, K, N), device = "cuda", dtype = dtype, requires_grad = requires_grad)
/ 10
)
score = torch.randn((M, E), device="cuda", dtype=dtype, requires_grad=requires_grad)
score = torch.randn((M, E), device = "cuda", dtype = dtype, requires_grad = requires_grad)
if requires_grad:
X1.retain_grad()
X2.retain_grad()
@ -60,7 +60,7 @@ def make_inputs(M, N, K, E, topk, dtype, requires_grad=False):
return X1, X2, W1, W2, score
@dataclass(kw_only=True)
@dataclass(kw_only = True)
class DataConfig:
seq_len: int
dtype: torch.dtype
@ -68,7 +68,7 @@ class DataConfig:
bs: int = 1
@dataclass(kw_only=True)
@dataclass(kw_only = True)
class ModelConfig:
hidden_size: int
intermediate_size: int
@ -77,13 +77,13 @@ class ModelConfig:
use_sigmoid: bool
renormalize: bool
pre_mul: bool = False
post_mul: bool = field(init=False)
post_mul: bool = field(init = False)
def __post_init__(self):
self.post_mul = not self.pre_mul
@dataclass(kw_only=True)
@dataclass(kw_only = True)
class GroupedGEMMTestConfig:
name: str = "test"
data_config: DataConfig
@ -105,7 +105,7 @@ def assert_equal(ref, tri):
assert ref == tri, f"ref not equal to tri {ref} != {tri}"
def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
def assert_close(ref, tri, maxtol = None, rmstol = None, description = "--", verbose = True):
if tri.dtype.itemsize == 1:
ref_as_type = ref.to(tri.dtype)
if ref.dtype == tri.dtype:
@ -182,11 +182,11 @@ def assert_indx_equal(ref, tri):
def get_kernel_test_configs(
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
num_warps=4,
num_stages=2,
BLOCK_SIZE_M = 32,
BLOCK_SIZE_N = 32,
BLOCK_SIZE_K = 32,
num_warps = 4,
num_stages = 2,
) -> list[KernelConfig]:
configs_fwd = []
configs_bwd_dX = []
@ -199,44 +199,44 @@ def get_kernel_test_configs(
for use_tma_store in [True, False]:
configs_fwd.append(
KernelConfigForward(
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
num_warps=num_warps,
num_stages=num_stages,
use_tma_load_w=use_tma_load_w,
use_tma_load_x=use_tma_load_x,
use_tma_store=use_tma_store,
permute_x=permute_x,
permute_y=permute_y,
BLOCK_SIZE_M = BLOCK_SIZE_M,
BLOCK_SIZE_N = BLOCK_SIZE_N,
BLOCK_SIZE_K = BLOCK_SIZE_K,
num_warps = num_warps,
num_stages = num_stages,
use_tma_load_w = use_tma_load_w,
use_tma_load_x = use_tma_load_x,
use_tma_store = use_tma_store,
permute_x = permute_x,
permute_y = permute_y,
)
)
configs_bwd_dX.append(
KernelConfigBackward_dX(
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
num_warps=num_warps,
num_stages=num_stages,
use_tma_load_dy=use_tma_load_x,
use_tma_load_w=use_tma_load_w,
permute_x=permute_x,
permute_y=permute_y,
use_tma_store=use_tma_store,
BLOCK_SIZE_M = BLOCK_SIZE_M,
BLOCK_SIZE_N = BLOCK_SIZE_N,
BLOCK_SIZE_K = BLOCK_SIZE_K,
num_warps = num_warps,
num_stages = num_stages,
use_tma_load_dy = use_tma_load_x,
use_tma_load_w = use_tma_load_w,
permute_x = permute_x,
permute_y = permute_y,
use_tma_store = use_tma_store,
)
)
configs_bwd_dW.append(
KernelConfigBackward_dW(
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
num_warps=num_warps,
num_stages=num_stages,
use_tma_load_dy=use_tma_load_w,
use_tma_load_x=use_tma_load_x,
permute_x=permute_x,
permute_y=permute_y,
use_tma_store=use_tma_store,
BLOCK_SIZE_M = BLOCK_SIZE_M,
BLOCK_SIZE_N = BLOCK_SIZE_N,
BLOCK_SIZE_K = BLOCK_SIZE_K,
num_warps = num_warps,
num_stages = num_stages,
use_tma_load_dy = use_tma_load_w,
use_tma_load_x = use_tma_load_x,
permute_x = permute_x,
permute_y = permute_y,
use_tma_store = use_tma_store,
)
)
configs_fwd = prune_kernel_configs_fwd(configs_fwd)
@ -289,39 +289,39 @@ TEST_MODEL_SIZES = [
SMALL_MODEL_CONFIGS = [
ModelConfig(
topk=topk,
num_experts=num_experts,
hidden_size=model_size[0],
intermediate_size=model_size[1],
use_sigmoid=False,
renormalize=False,
topk = topk,
num_experts = num_experts,
hidden_size = model_size[0],
intermediate_size = model_size[1],
use_sigmoid = False,
renormalize = False,
)
for topk, num_experts, model_size in itertools.product(
TOPK, NUM_EXPERTS, TEST_MODEL_SIZES
)
]
LLAMA_MODEL_CONFIG = ModelConfig(
topk=1,
num_experts=16,
hidden_size=5120,
intermediate_size=8192,
use_sigmoid=True,
renormalize=False,
topk = 1,
num_experts = 16,
hidden_size = 5120,
intermediate_size = 8192,
use_sigmoid = True,
renormalize = False,
)
QWEN_MODEL_CONFIG = ModelConfig(
topk=8,
num_experts=128,
hidden_size=2048,
intermediate_size=768,
use_sigmoid=False,
renormalize=False,
topk = 8,
num_experts = 128,
hidden_size = 2048,
intermediate_size = 768,
use_sigmoid = False,
renormalize = False,
)
SEQLENS = [128, 1024]
DTYPE = [torch.bfloat16]
DATA_CONFIGS = [
DataConfig(seq_len=seq_len, dtype=dtype)
DataConfig(seq_len = seq_len, dtype = dtype)
for seq_len, dtype in itertools.product(SEQLENS, DTYPE)
]
KERNEL_CONFIGS_FWD, KERNEL_CONFIGS_BWD_dX, KERNEL_CONFIGS_BWD_dW = (
@ -331,6 +331,6 @@ KERNEL_CONFIGS_FWD, KERNEL_CONFIGS_BWD_dX, KERNEL_CONFIGS_BWD_dW = (
if __name__ == "__main__":
print(
KERNEL_CONFIGS_BWD_dX[0].to_string(
include_tuning_params=False, include_tma=False
include_tuning_params = False, include_tma = False
)
)
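assert_close above gates on max and RMS error with dtype-dependent tolerances. A simplified sketch of those two metrics (the repo's exact thresholds and dtype handling differ):

import torch

def max_and_rms_error(ref, tri):
    ref32, tri32 = ref.float(), tri.float()
    err = (ref32 - tri32).abs()
    max_rel = (err / ref32.abs().clamp_min(1e-6)).max().item()
    rms = torch.sqrt((err * err).mean()).item()
    return max_rel, rms

a = torch.randn(1024)
b = a + 1e-4 * torch.randn(1024)
print(max_and_rms_error(a, b))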

View file

@ -33,13 +33,13 @@ def rebind_experts_to_shared_buffer(
dtype = moe_block.experts[0].down_proj.weight.dtype
buffer_up = torch.empty(
num_experts, interm_size, hidden_size, device=device, dtype=dtype
num_experts, interm_size, hidden_size, device = device, dtype = dtype
)
buffer_gate = torch.empty(
num_experts, interm_size, hidden_size, device=device, dtype=dtype
num_experts, interm_size, hidden_size, device = device, dtype = dtype
)
buffer_down = torch.empty(
num_experts, hidden_size, interm_size, device=device, dtype=dtype
num_experts, hidden_size, interm_size, device = device, dtype = dtype
)
# Step 2: Copy existing expert weights into buffers
@ -114,7 +114,7 @@ def check_down_proj_grad(
test_grad = grouped_gemm_block.down_proj.grad[i]
assert test_grad is not None
diff = (ref_grad - test_grad).abs().max()
if not torch.allclose(ref_grad, test_grad, atol=atol, rtol=rtol):
if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
print(f"expert {i} down_proj_grad_diff: {diff.detach().cpu().item():.6f}")
@ -152,12 +152,12 @@ def check_gate_up_proj_grad(
# Check gradients
diff = (ref_gate_proj_grad - test_gate_proj_grad).abs().max()
if not torch.allclose(
ref_gate_proj_grad, test_gate_proj_grad, atol=atol, rtol=rtol
ref_gate_proj_grad, test_gate_proj_grad, atol = atol, rtol = rtol
):
print(f"expert {i} gate_proj_grad_diff: {diff.detach().cpu().item():.6f}")
diff = (ref_up_proj_grad - test_up_proj_grad).abs().max()
if not torch.allclose(
ref_up_proj_grad, test_up_proj_grad, atol=atol, rtol=rtol
ref_up_proj_grad, test_up_proj_grad, atol = atol, rtol = rtol
):
print(f"expert {i} up_proj_grad_diff: {diff.detach().cpu().item():.6f}")
@ -173,7 +173,7 @@ def check_gate_grad(
test_grad = grouped_gemm_block.gate.grad
assert test_grad is not None
diff = (ref_grad - test_grad).abs().max()
if not torch.allclose(ref_grad, test_grad, atol=atol, rtol=rtol):
if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
print(f"gate_grad_diff: {diff.detach().cpu().item():.6f}")
@ -200,7 +200,7 @@ def check_tensor_allclose(
if verbose:
print(f"{name} diff: {diff.detach().cpu().item():.6f}")
assert torch.allclose(
X_ref, X_test, atol=atol, rtol=rtol
X_ref, X_test, atol = atol, rtol = rtol
), f"{name} diff: {diff.detach().cpu().item():.6f}"
@ -227,7 +227,7 @@ def check_expert_grads(
test_grad = test_grads[i]
diff = (ref_grad - test_grad).abs().max()
assert torch.allclose(
ref_grad, test_grad, atol=atol, rtol=rtol
ref_grad, test_grad, atol = atol, rtol = rtol
), f"{field}[{i}] diff: {diff.detach().cpu().item():.6f}"
# Test all experts
@ -235,7 +235,7 @@ def check_expert_grads(
if verbose:
print(f"{field} diff: {diff.detach().cpu().item():.6f}")
assert torch.allclose(
ref_grads, test_grads, atol=atol, rtol=rtol
ref_grads, test_grads, atol = atol, rtol = rtol
), f"{field} diff: {diff.detach().cpu().item():.6f}"
@ -269,7 +269,7 @@ def check_fwd(
if verbose:
print(f"output diff: {diff.detach().cpu().item():.6f}")
assert torch.allclose(
ref_output, test_output, atol=atol, rtol=rtol
ref_output, test_output, atol = atol, rtol = rtol
), f"output diff: {diff.detach().cpu().item():.6f}"
# Check router logits
@ -279,7 +279,7 @@ def check_fwd(
if verbose:
print(f"router_logits diff: {diff.detach().cpu().item():.6f}")
assert torch.allclose(
ref_router_logits, test_router_logits, atol=atol, rtol=rtol
ref_router_logits, test_router_logits, atol = atol, rtol = rtol
), f"router_logits diff: {diff.detach().cpu().item():.6f}"
@ -305,7 +305,7 @@ def check_grouped_gemm_results(
print(f"{field.name} diff: {diff.detach().cpu().item():.6f}")
assert torch.allclose(
ref_value, test_value, atol=atol, rtol=rtol
ref_value, test_value, atol = atol, rtol = rtol
), f"{field.name} diff: {diff.detach().cpu().item():.6f}"
@ -314,13 +314,13 @@ def run_forward(model: nn.Module, X: torch.Tensor, is_grouped_gemm: bool = False
output, router_logits = model(X)
if is_grouped_gemm:
result = ForwardResult(
output=output.hidden_states,
router_logits=router_logits,
X=X,
grouped_gemm_result=output,
output = output.hidden_states,
router_logits = router_logits,
X = X,
grouped_gemm_result = output,
)
else:
result = ForwardResult(output=output, router_logits=router_logits, X=X)
result = ForwardResult(output = output, router_logits = router_logits, X = X)
return result
@ -344,16 +344,16 @@ def run_backward(
)
elif isinstance(model, Qwen3MoeGroupedGEMMBlock):
gate_grad = model.gate.grad
gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim=1)
gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim = 1)
down_proj_grad = model.down_proj.grad
else:
raise ValueError(f"Unsupported model type: {type(model)}")
return BackwardResult(
X_grad=X.grad,
gate_grad=gate_grad,
gate_proj_grad=gate_proj_grad,
up_proj_grad=up_proj_grad,
down_proj_grad=down_proj_grad,
X_grad = X.grad,
gate_grad = gate_grad,
gate_proj_grad = gate_proj_grad,
up_proj_grad = up_proj_grad,
down_proj_grad = down_proj_grad,
)
@ -414,12 +414,12 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
gate,
gate_up_proj,
down_proj,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
)
def forward(self, hidden_states: torch.Tensor, debug: bool = False) -> torch.Tensor:
@ -446,35 +446,35 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
# Start expert computation
first_gemm = grouped_gemm(
X=hidden_states,
W=self.gate_up_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=self.permute_x,
permute_y=False, # output of first grouped gemm should never be permuted
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=True,
X = hidden_states,
W = self.gate_up_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = self.permute_x,
permute_y = False, # output of first grouped gemm should never be permuted
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = True,
)
assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
intermediate = self.act_and_mul(first_gemm)
assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
second_gemm = grouped_gemm(
X=intermediate,
W=self.down_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=False,
permute_y=self.permute_y,
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=False,
X = intermediate,
W = self.down_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = False,
permute_y = self.permute_y,
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = False,
)
assert second_gemm.shape == (total_tokens, hidden_dim)
@ -491,17 +491,17 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
* routing_weights[..., None]
)
hidden_states = hidden_states.sum(dim=1)
hidden_states = hidden_states.sum(dim = 1)
assert hidden_states.shape == (num_tokens, hidden_dim)
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
return GroupedGEMMResult(
token_counts_by_expert=token_counts_by_expert,
gather_indices=gather_indices,
topk_weights=routing_weights,
first_gemm=first_gemm,
intermediate=intermediate,
second_gemm=second_gemm,
hidden_states_unpermute=hidden_states_unpermute,
hidden_states=hidden_states,
token_counts_by_expert = token_counts_by_expert,
gather_indices = gather_indices,
topk_weights = routing_weights,
first_gemm = first_gemm,
intermediate = intermediate,
second_gemm = second_gemm,
hidden_states_unpermute = hidden_states_unpermute,
hidden_states = hidden_states,
), router_logits
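After the second grouped GEMM, the expert-sorted rows are scattered back to token order and combined with the routing weights, as in the tail of forward above. A compact sketch of that unpermute-and-combine step; the index convention follows the routing sketch earlier in this diff and may differ from the repo's permute/unpermute helpers:

import torch

def unpermute_and_combine(second_gemm, gather_indices, routing_weights, num_tokens, top_k):
    # second_gemm: (num_tokens * top_k, hidden) rows in expert-sorted order;
    # gather_indices[i] is the original (token * top_k + k) slot of sorted row i.
    hidden = second_gemm.shape[-1]
    unpermuted = torch.empty_like(second_gemm)
    unpermuted[gather_indices] = second_gemm
    out = unpermuted.view(num_tokens, top_k, hidden) * routing_weights[..., None]
    return out.sum(dim=1)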

File diff suppressed because it is too large

View file

@ -68,18 +68,18 @@ TOLERANCES = {
}
@pytest.fixture(scope="module")
@pytest.fixture(scope = "module")
def model_id():
return "Qwen/Qwen3-30B-A3B"
@pytest.fixture(scope="module")
@pytest.fixture(scope = "module")
def config(model_id: str):
return AutoConfig.from_pretrained(model_id)
@contextmanager
def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
print(char * num_chars)
print(prelude)
yield
@ -96,16 +96,16 @@ NUM_AUTOTUNE_CONFIGS = 50
@pytest.mark.parametrize(
"permute_y", [True], ids=lambda x: "permute_y" if x else "no_permute_y"
"permute_y", [True], ids = lambda x: "permute_y" if x else "no_permute_y"
)
@pytest.mark.parametrize(
"permute_x", [True], ids=lambda x: "permute_x" if x else "no_permute_x"
"permute_x", [True], ids = lambda x: "permute_x" if x else "no_permute_x"
)
@pytest.mark.parametrize(
"autotune", [True], ids=lambda x: "autotune" if x else "manual"
"autotune", [True], ids = lambda x: "autotune" if x else "manual"
)
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}")
@pytest.mark.parametrize("dtype", DTYPES, ids=str)
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
@pytest.mark.parametrize("dtype", DTYPES, ids = str)
def test_qwen3_moe(
config: Qwen3MoeConfig,
seqlen: int,
@ -157,36 +157,36 @@ def test_qwen3_moe(
# Triton kernel grouped gemm version of MoE Block -- this is what we're testing
fused_gemm_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
moe_block,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
).to(device, dtype)
fused_gemm_block.check_weights(moe_block)
X = torch.randn(
bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
)
# Forward
ref_result = run_forward(moe_block, X, is_grouped_gemm=False)
grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm=True)
fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm=True)
ref_result = run_forward(moe_block, X, is_grouped_gemm = False)
grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm = True)
fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm = True)
with annotated_context(
"Testing forward pass",
epilogue="Passed forward tests!",
char="=",
num_chars=100,
epilogue = "Passed forward tests!",
char = "=",
num_chars = 100,
):
# Sanity checks
with annotated_context(
"Checking HF vs torch grouped gemm MoE forward outputs..."
):
check_fwd(ref_result, grouped_result, atol, rtol, verbose=False)
check_fwd(ref_result, grouped_result, atol, rtol, verbose = False)
with annotated_context(
"Checking torch grouped gemm MoE vs fused grouped gemm MoE forward outputs..."
@ -195,42 +195,42 @@ def test_qwen3_moe(
check_grouped_gemm_results(
grouped_result.grouped_gemm_result,
fused_result.grouped_gemm_result,
permute_y=permute_y,
atol=atol,
rtol=rtol,
verbose=False,
permute_y = permute_y,
atol = atol,
rtol = rtol,
verbose = False,
)
# Actual test
with annotated_context(
"Checking HF vs fused grouped gemm MoE forward outputs..."
):
check_fwd(ref_result, fused_result, atol, rtol, verbose=True)
check_fwd(ref_result, fused_result, atol, rtol, verbose = True)
# Backward
grad_output = torch.randn_like(ref_result.output)
ref_backward_result = run_backward(
moe_block, grad_output, output=ref_result.output, X=ref_result.X
moe_block, grad_output, output = ref_result.output, X = ref_result.X
)
grouped_backward_result = run_backward(
grouped_gemm_block,
grad_output,
output=grouped_result.output,
X=grouped_result.X,
output = grouped_result.output,
X = grouped_result.X,
)
fused_backward_result = run_backward(
fused_gemm_block, grad_output, output=fused_result.output, X=fused_result.X
fused_gemm_block, grad_output, output = fused_result.output, X = fused_result.X
)
with annotated_context(
"Testing backward pass",
epilogue="Passed backward tests!",
char="=",
num_chars=100,
epilogue = "Passed backward tests!",
char = "=",
num_chars = 100,
):
# Sanity checks
with annotated_context("Checking HF vs torch grouped gemm MoE grads..."):
check_grads(
ref_backward_result, grouped_backward_result, atol, rtol, verbose=False
ref_backward_result, grouped_backward_result, atol, rtol, verbose = False
)
with annotated_context(
"Checking torch grouped gemm MoE vs fused grouped gemm MoE grads..."
@ -240,25 +240,25 @@ def test_qwen3_moe(
fused_backward_result,
atol,
rtol,
verbose=False,
verbose = False,
)
# Actual test
with annotated_context("Checking HF vs fused grouped gemm MoE grads..."):
check_grads(
ref_backward_result, fused_backward_result, atol, rtol, verbose=True
ref_backward_result, fused_backward_result, atol, rtol, verbose = True
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--seqlen", type=int, default=1024)
parser.add_argument("--seqlen", type = int, default = 1024)
parser.add_argument(
"--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
"--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
)
parser.add_argument("--permute_x", action="store_true")
parser.add_argument("--permute_y", action="store_true")
parser.add_argument("--autotune", action="store_true")
parser.add_argument("--permute_x", action = "store_true")
parser.add_argument("--permute_y", action = "store_true")
parser.add_argument("--autotune", action = "store_true")
args = parser.parse_args()
args.dtype = getattr(torch, args.dtype)
args_dict = vars(args)
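The __main__ block converts the --dtype string into a real torch dtype with getattr before forwarding the parsed args. A tiny self-contained illustration of that conversion:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16")
args = parser.parse_args([])             # defaults; the real script parses sys.argv
args.dtype = getattr(torch, args.dtype)  # "bfloat16" -> torch.bfloat16
print(args.dtype)                        # torch.bfloat16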

View file

@ -54,10 +54,10 @@ except:
CohereFlashAttention2 = CohereAttention
def fast_layernorm_inference(self, X, out_weight=None):
XX = X.to(torch.float32, copy=True)
XX -= X.mean(-1, keepdim=True)
variance = XX.square().mean(-1, keepdim=True)
def fast_layernorm_inference(self, X, out_weight = None):
XX = X.to(torch.float32, copy = True)
XX -= X.mean(-1, keepdim = True)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
out_weight[:] = self.weight
@ -120,8 +120,8 @@ def CohereAttention_fast_forward(
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim=2)
V = torch.cat([past_key_value[1], V], dim=2)
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@ -143,14 +143,14 @@ def CohereAttention_fast_forward(
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
else:
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
A = flash_attn_func(Q, K, V, causal=True)
A = flash_attn_func(Q, K, V, causal = True)
else:
# Grouped query attention
if n_groups != 1:
@ -169,7 +169,7 @@ def CohereAttention_fast_forward(
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(
Q, K, V, attn_mask=attention_mask, is_causal=False
Q, K, V, attn_mask = attention_mask, is_causal = False
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@ -198,7 +198,7 @@ def CohereDecoderLayer_fast_forward(
self, "_flag_for_generation"
): # past_key_value is not None:
out_weight = torch.empty(
self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0"
self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
)
# Self Attention
@ -207,14 +207,14 @@ def CohereDecoderLayer_fast_forward(
self.input_layernorm, hidden_states, out_weight
)
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
)
# Fully Connected
@ -226,14 +226,14 @@ def CohereDecoderLayer_fast_forward(
residual = hidden_states
hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
)
# Fully Connected
@ -260,8 +260,8 @@ def CohereAttention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill=False,
attention_mask=None,
do_prefill = False,
attention_mask = None,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
@ -284,45 +284,45 @@ def CohereAttention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
dtype=dtype,
device="cuda:0",
dtype = dtype,
device = "cuda:0",
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
(2, bsz, 1, attention_size), dtype=dtype, device="cuda:0"
(2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0"
)
self.temp_KV = torch.empty(
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device="cuda:0"
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = "cuda:0"
)
self.RH_Q = torch.empty(
(bsz, n_heads, 1, head_dim), dtype=dtype, device="cuda:0"
(bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0"
)
# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
self.temp_O = torch.empty(
(1, bsz, hidden_size), dtype=dtype, device="cuda:0"
(1, bsz, hidden_size), dtype = dtype, device = "cuda:0"
)
else:
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
self.attention = torch.empty(
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len),
dtype=dtype,
device="cuda:0",
dtype = dtype,
device = "cuda:0",
)
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
# Cohere has QK layernorms
if self.use_qk_norm:
self.q_norm_out_weight = torch.empty(
self.q_norm.weight.shape, dtype=torch.float32, device="cuda:0"
self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0"
)
self.k_norm_out_weight = torch.empty(
self.k_norm.weight.shape, dtype=torch.float32, device="cuda:0"
self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0"
)
else:
self.q_norm_out_weight = None
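# A minimal standalone sketch of the prefill buffer pattern above: one tensor holds
# both the K and V caches, pre-sized with KV_CACHE_INCREMENT extra slots so later
# decode steps can write new keys/values in place instead of reallocating. The
# shapes and the increment of 256 below are illustrative assumptions, not values
# taken from this diff.
import torch

KV_CACHE_INCREMENT = 256
bsz, n_kv_heads, seq_len, head_dim = 1, 8, 16, 64
dtype, device = torch.float32, "cpu"

K1 = torch.randn(bsz, n_kv_heads, seq_len, head_dim, dtype = dtype, device = device)
V1 = torch.randn(bsz, n_kv_heads, seq_len, head_dim, dtype = dtype, device = device)

paged_attention = torch.empty(
    (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
    dtype = dtype, device = device,
)
paged_attention_K = paged_attention[:, 0]  # views into the same buffer, no copies
paged_attention_V = paged_attention[:, 1]
# (bsz, n_kv_heads, seq_len, head_dim) -> (seq_len, bsz, n_kv_heads, head_dim)
paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
print(paged_attention.shape)  # torch.Size([273, 2, 1, 8, 64])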
@ -343,9 +343,9 @@ def CohereAttention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@ -363,7 +363,7 @@ def CohereAttention_fast_forward_inference(
RH_Q = self.RH_Q
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
@ -372,7 +372,7 @@ def CohereAttention_fast_forward_inference(
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
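# A minimal sketch of the in-place rotate-half RoPE update used above:
# rotate_half(q) = [-q2, q1], then q <- q * cos + rotate_half(q) * sin.
# Shapes are illustrative and cos/sin would normally come from the rotary cache.
import torch

bsz, n_heads, head_dim = 1, 4, 8
h = head_dim // 2
Qn  = torch.randn(bsz, n_heads, 1, head_dim)
cos = torch.randn(1, 1, 1, head_dim)
sin = torch.randn(1, 1, 1, head_dim)

RH_Q = torch.empty_like(Qn)
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])

expected = Qn * cos + RH_Q * sin  # reference result
Qn *= cos
Qn.addcmul_(RH_Q, sin)            # fused in-place q*cos + rotate_half(q)*sin
assert torch.allclose(Qn, expected)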
@ -414,20 +414,20 @@ def CohereAttention_fast_forward_inference(
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(
Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
)
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(
A, dim=-1, dtype=torch.float32
A, dim = -1, dtype = torch.float32
) # .to(A.dtype)
A = torch_matmul(A, Vnn, out=Qn)
A = torch_matmul(A, Vnn, out = Qn)
else:
A = scaled_dot_product_attention(
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=False
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
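# A small illustration of the overflow note above: in float16 the dot product
# Q . K can exceed the ~65504 limit before the 1/sqrt(head_dim) scale is applied,
# whereas scaling Q first keeps every partial product in range. The magnitudes
# are picked only to trigger the overflow and mirror one row of (Q * scalar) @ K.
import math, torch

head_dim = 64
scale = 1.0 / math.sqrt(head_dim)
q = torch.full((head_dim,), 40.0, dtype = torch.float16)
k = torch.full((head_dim,), 40.0, dtype = torch.float16)

unscaled = (q * k).sum(dtype = torch.float16) * scale    # 64 * 1600 overflows -> inf
scaled   = ((q * scale) * k).sum(dtype = torch.float16)  # 64 * 200 = 12800, finite
print(unscaled.item(), scaled.item())                    # inf 12800.0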
@ -438,13 +438,13 @@ def CohereModel_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
attention_mask=None,
attention_mask = None,
):
out_weights = tuple(
torch.empty_like(
self.model.layers[0].input_layernorm.weight,
dtype=torch.float32,
device=torch.device(x),
dtype = torch.float32,
device = torch.device(x),
)
for x in range(DEVICE_COUNT)
)
@ -459,7 +459,7 @@ def CohereModel_fast_forward_inference(
(bsz, q_len),
hidden_states,
seq_len,
sliding_window=getattr(self.config, "sliding_window", None),
sliding_window = getattr(self.config, "sliding_window", None),
)
else:
attention_mask = None
@ -477,11 +477,11 @@ def CohereModel_fast_forward_inference(
hidden_states_attention, present_key_value = (
CohereAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states=hidden_states,
past_key_value=past_key_values[idx],
position_ids=position_ids,
attention_mask=attention_mask,
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
)
@ -496,10 +496,10 @@ def CohereModel_fast_forward_inference(
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=[],
attentions=[],
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
@ -507,10 +507,10 @@ class FastCohereModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="cohere",
rope_module=LlamaRotaryEmbedding,
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
attention_module=CohereAttention,
model_name = "cohere",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = CohereAttention,
)
if init_name is not None:
exec(function, globals())

View file

@ -121,7 +121,7 @@ def FalconH1Attention_fast_forward(
else:
# Extend RoPE dynamically to fit in VRAM
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
device_index = Q.device.index
if position_ids is None:
@ -132,8 +132,8 @@ def FalconH1Attention_fast_forward(
Q, K = fast_rope_embedding(Q, K, cos, sin)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim=2)
V = torch.cat([past_key_value[1], V], dim=2)
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@ -157,7 +157,7 @@ def FalconH1Attention_fast_forward(
# Xformers does support the forward pass though
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
@ -166,7 +166,7 @@ def FalconH1Attention_fast_forward(
V = V.transpose(1, 2)
sw = kv_seq_len
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
A = flash_attn_func(Q, K, V, causal=True, window_size=window)
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
else:
# Grouped query attention
# if n_groups != 1:
@ -181,7 +181,7 @@ def FalconH1Attention_fast_forward(
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(
Q, K, V, attn_mask=attention_mask, is_causal=False
Q, K, V, attn_mask = attention_mask, is_causal = False
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@ -200,8 +200,8 @@ def FalconH1Attention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill=False,
attention_mask=None,
do_prefill = False,
attention_mask = None,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
@ -253,29 +253,29 @@ def FalconH1Attention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
dtype=dtype,
device=device,
dtype = dtype,
device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
(2, bsz, 1, attention_size), dtype=dtype, device=device
(2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
else:
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
self.attention = torch.empty(
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@ -295,10 +295,10 @@ def FalconH1Attention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Kn = Kn * self.config.key_multiplier
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(
bsz, 1, n_heads, head_dim
) # .transpose(1, 2) # we will transpose after normalisation
@ -375,25 +375,25 @@ def FalconH1Attention_fast_forward_inference(
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(
Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
)
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(
A, dim=-1, dtype=torch.float32
A, dim = -1, dtype = torch.float32
) # .to(A.dtype)
A = torch_matmul(A, Vnn, out=Qn)
A = torch_matmul(A, Vnn, out = Qn)
else:
if SDPA_HAS_GQA:
A = scaled_dot_product_attention(
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=False, enable_gqa=True
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True
)
else:
A = scaled_dot_product_attention(
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=False
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@ -401,7 +401,7 @@ def FalconH1Attention_fast_forward_inference(
def FalconH1DecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask=None,
causal_mask = None,
attention_mask: Optional[torch.Tensor] = None,
mamba_attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -433,23 +433,23 @@ def FalconH1DecoderLayer_fast_forward(
self.input_layernorm, hidden_states
)
attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings=position_embeddings,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
mamba_hidden_states = self.mamba(
hidden_states=hidden_states,
cache_params=past_key_value,
cache_position=cache_position,
attention_mask=mamba_attention_mask,
hidden_states = hidden_states,
cache_params = past_key_value,
cache_position = cache_position,
attention_mask = mamba_attention_mask,
)
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
@ -469,23 +469,23 @@ def FalconH1DecoderLayer_fast_forward(
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
mamba_hidden_states = self.mamba(
hidden_states=hidden_states,
cache_params=past_key_value,
cache_position=cache_position,
attention_mask=mamba_attention_mask,
hidden_states = hidden_states,
cache_params = past_key_value,
cache_position = cache_position,
attention_mask = mamba_attention_mask,
)
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings=position_embeddings,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
@ -509,8 +509,8 @@ def FalconH1DecoderLayer_fast_forward(
def _FalconH1_fast_forward_inference(
attention_fast_forward_inference=FalconH1Attention_fast_forward_inference,
mlp_fast_forward_inference=fast_swiglu_inference,
attention_fast_forward_inference = FalconH1Attention_fast_forward_inference,
mlp_fast_forward_inference = fast_swiglu_inference,
):
# This makes the attention and MLP customisable.
# Now for models like qwen3 or cohere which use custom attention operations, we can use this function
@ -519,9 +519,9 @@ def _FalconH1_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
cache_position=None,
attention_mask=None,
mamba_attention_mask=None,
cache_position = None,
attention_mask = None,
mamba_attention_mask = None,
):
input_ids = input_ids[:, : self.max_seq_length]
bsz, q_len = input_ids.shape
@ -536,11 +536,11 @@ def _FalconH1_fast_forward_inference(
bsz, q_len, hd = X.shape
assert q_len == 1
# Get saved buffers to reduce memory movement
residual = torch.empty((bsz, q_len, hd), dtype=torch.float32, device="cuda:0")
_XX = torch.empty((2, bsz, q_len, hd), dtype=torch.float32, device="cuda:0")
residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
_XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
XX, XX2 = _XX[0], _XX[1]
variance = torch.empty((bsz, q_len, 1), dtype=torch.float32, device="cuda:0")
temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype=X.dtype, device="cuda:0")
variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0")
temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
temp_gate, temp_up = temp_mlp[0], temp_mlp[1]
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
@ -549,7 +549,7 @@ def _FalconH1_fast_forward_inference(
(bsz, q_len),
X,
seq_len,
sliding_window=getattr(self.config, "sliding_window", None),
sliding_window = getattr(self.config, "sliding_window", None),
)
else:
attention_mask = None
@ -561,28 +561,28 @@ def _FalconH1_fast_forward_inference(
X = fast_rms_layernorm_inference(
decoder_layer.input_layernorm,
X,
XX=XX,
XX2=XX2,
variance=variance,
XX = XX,
XX2 = XX2,
variance = variance,
)
attention_hidden_states, present_key_value = (
attention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states=X * decoder_layer.attention_in_multiplier,
past_key_value=past_key_values[idx],
position_ids=position_ids,
attention_mask=attention_mask,
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
hidden_states = X * decoder_layer.attention_in_multiplier,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
)
attention_hidden_states = (
attention_hidden_states * decoder_layer.attn_out_multiplier
)
mamba_hidden_states = decoder_layer.mamba(
hidden_states=X,
cache_params=present_key_value,
cache_position=cache_position,
attention_mask=mamba_attention_mask,
hidden_states = X,
cache_params = present_key_value,
cache_position = cache_position,
attention_mask = mamba_attention_mask,
)
mamba_hidden_states = mamba_hidden_states * decoder_layer.ssm_out_multiplier
X = mamba_hidden_states + attention_hidden_states
@ -593,17 +593,17 @@ def _FalconH1_fast_forward_inference(
X = fast_rms_layernorm_inference(
decoder_layer.pre_ff_layernorm,
X,
XX=XX,
XX2=XX2,
variance=variance,
XX = XX,
XX2 = XX2,
variance = variance,
)
X = mlp_fast_forward_inference(
decoder_layer.feed_forward,
X,
temp_gate=temp_gate,
temp_up=temp_up,
gate_multiplier=gate_multiplier,
down_multiplier=down_multiplier,
temp_gate = temp_gate,
temp_up = temp_up,
gate_multiplier = gate_multiplier,
down_multiplier = down_multiplier,
)
X += residual
@ -611,16 +611,16 @@ def _FalconH1_fast_forward_inference(
X = fast_rms_layernorm_inference(
self.model.final_layernorm,
X,
XX=XX,
XX2=XX2,
variance=variance,
XX = XX,
XX2 = XX2,
variance = variance,
)
return BaseModelOutputWithPast(
last_hidden_state=X,
past_key_values=next_decoder_cache,
hidden_states=[],
attentions=[],
last_hidden_state = X,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
return FalconH1Model_fast_forward_inference_custom
@ -630,12 +630,12 @@ def _FalconH1_fast_forward_inference(
def _fast_prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
past_key_values = None,
attention_mask = None,
inputs_embeds = None,
cache_position = None,
position_ids = None,
use_cache = True,
**kwargs,
):
# Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
@ -707,10 +707,10 @@ class FastFalconH1Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="FalconH1",
rope_module=LlamaRotaryEmbedding,
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
attention_module=FalconH1Attention,
model_name = "FalconH1",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = FalconH1Attention,
)
if init_name is not None:
exec(function, globals())
@ -738,30 +738,30 @@ class FastFalconH1Model(FastLlamaModel):
@staticmethod
def from_pretrained( # TODO: Change after release
model_name="Qwen/FalconH1-7B",
max_seq_length=4096,
dtype=None,
load_in_4bit=True,
token=None,
device_map="sequential",
rope_scaling=None,
fix_tokenizer=True,
model_patcher=None,
tokenizer_name=None,
trust_remote_code=False,
model_name = "Qwen/FalconH1-7B",
max_seq_length = 4096,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
token=token,
device_map=device_map,
rope_scaling=rope_scaling,
fix_tokenizer=fix_tokenizer,
model_patcher=FastFalconH1Model,
tokenizer_name=tokenizer_name,
trust_remote_code=trust_remote_code,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastFalconH1Model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)

View file

@ -67,11 +67,11 @@ def fast_geglu_inference(self, X):
gate = fast_linear_forward(self.gate_proj, X) # , out = temp[0])
up = fast_linear_forward(self.up_proj, X) # , out = temp[1])
gate = torch_nn_functional_gelu(gate, approximate="tanh")
gate = torch_nn_functional_gelu(gate, approximate = "tanh")
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out=up[:, :, :hd])
down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])
return down
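# A minimal standalone sketch of the GeGLU feed-forward computed above, using
# plain nn.Linear layers in place of the fused fast_linear_forward path; the
# sizes below are illustrative assumptions.
import torch
import torch.nn.functional as F

hd, mlp = 16, 64
X = torch.randn(2, 1, hd)
gate_proj = torch.nn.Linear(hd, mlp, bias = False)
up_proj   = torch.nn.Linear(hd, mlp, bias = False)
down_proj = torch.nn.Linear(mlp, hd, bias = False)

gate = F.gelu(gate_proj(X), approximate = "tanh")  # GELU(tanh) on the gate branch
out  = down_proj(gate * up_proj(X))                # elementwise gate * up, then project back
print(out.shape)  # torch.Size([2, 1, 16])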
@ -93,7 +93,7 @@ def GemmaDecoderLayer_fast_forward(
self, "_flag_for_generation"
): # past_key_value is not None:
out_weight = torch.empty(
self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0"
self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
)
# Self Attention
@ -102,14 +102,14 @@ def GemmaDecoderLayer_fast_forward(
self.input_layernorm, hidden_states, out_weight
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
)
hidden_states += residual
@ -123,24 +123,24 @@ def GemmaDecoderLayer_fast_forward(
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(
self.input_layernorm, hidden_states, gemma=True
self.input_layernorm, hidden_states, gemma = True
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(
self.post_attention_layernorm, hidden_states, gemma=True
self.post_attention_layernorm, hidden_states, gemma = True
)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
@ -163,13 +163,13 @@ def GemmaModel_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
attention_mask=None,
attention_mask = None,
):
out_weights = tuple(
torch.empty_like(
self.model.layers[0].input_layernorm.weight,
dtype=torch.float32,
device=torch.device(x),
dtype = torch.float32,
device = torch.device(x),
)
for x in range(DEVICE_COUNT)
)
@ -179,7 +179,7 @@ def GemmaModel_fast_forward_inference(
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(
math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype
math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
)
bsz, q_len, hd = hidden_states.shape
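# A quick check of the precision comment above: sqrt(3072) is 55.4256... when
# computed in float32/float64, but rounds to 55.5 once cast to bfloat16, which
# is why the scale is materialised as a tensor at the hidden_states dtype.
import math, torch

hidden_size = 3072
exact = math.sqrt(hidden_size)                        # 55.42562584220407
as_bf16 = torch.tensor(exact, dtype = torch.bfloat16)
print(exact, as_bf16.item())                          # 55.42562584220407 55.5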
@ -205,11 +205,11 @@ def GemmaModel_fast_forward_inference(
)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states=hidden_states,
past_key_value=past_key_values[idx],
position_ids=position_ids,
attention_mask=attention_mask,
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states += residual
@ -228,10 +228,10 @@ def GemmaModel_fast_forward_inference(
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=[],
attentions=[],
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
@ -243,11 +243,11 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
config=None, # [TODO] Hack to pass in config - need to remove later
dim = None,
max_position_embeddings = 2048,
base = 10000,
device = None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
@ -274,17 +274,17 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
# Build here to make `torch.jit.trace` work.
for device in range(DEVICE_COUNT):
self._set_cos_sin_cache(
seq_len=self.current_rope_size,
device=torch.device(device),
dtype=torch.get_default_dtype(),
seq_len = self.current_rope_size,
device = torch.device(device),
dtype = torch.get_default_dtype(),
)
# dummy so that patch_utils doesn't fail for now
self.cos_cached = torch.empty(
1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()
1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
)
self.sin_cached = torch.empty(
1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()
1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
@ -294,27 +294,27 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
# The difference is we do division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(
self.current_rope_size, device="cpu", dtype=torch.int64
self.current_rope_size, device = "cpu", dtype = torch.int64
).float()
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim=-1)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device=device, non_blocking=True) # , dtype = dtype)
sin = emb.sin().to(device=device, non_blocking=True) # , dtype = dtype)
cos = emb.cos().to(device = device, non_blocking = True) # , dtype = dtype)
sin = emb.sin().to(device = device, non_blocking = True) # , dtype = dtype)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
def forward(self, x, position_ids=None, seq_len=None):
def forward(self, x, position_ids = None, seq_len = None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len is not None and seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
device_index = x.device.index
@ -323,7 +323,7 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
self.multi_gpu_sin_cached[device_index][:seq_len],
)
def get_cached(self, seq_len=None, device_index=None):
def get_cached(self, seq_len = None, device_index = None):
if device_index is None:
device_index = torch.cuda.current_device()
return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
@ -337,7 +337,7 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
self.current_rope_size = math.ceil(seq_len / 8192) * 8192
for device in range(DEVICE_COUNT):
self._set_cos_sin_cache(
self.current_rope_size, device=torch.device(device), dtype=x.dtype
self.current_rope_size, device = torch.device(device), dtype = x.dtype
)
@ -349,20 +349,20 @@ class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
config=None, # [TODO] Hack to pass in config - need to remove later
dim = None,
max_position_embeddings = 2048,
base = 10000,
device = None,
scaling_factor = 1.0,
config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(
dim=dim,
max_position_embeddings=max_position_embeddings,
base=base,
device=device,
config=config,
dim = dim,
max_position_embeddings = max_position_embeddings,
base = base,
device = device,
config = config,
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
@ -372,20 +372,20 @@ class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
# The difference is we do division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(
self.current_rope_size, device="cpu", dtype=torch.int64
self.current_rope_size, device = "cpu", dtype = torch.int64
).float()
positions = positions / self.scaling_factor
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim=-1)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device=device, non_blocking=True) # , dtype = dtype)
sin = emb.sin().to(device=device, non_blocking=True) # , dtype = dtype)
cos = emb.cos().to(device = device, non_blocking = True) # , dtype = dtype)
sin = emb.sin().to(device = device, non_blocking = True) # , dtype = dtype)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
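# A standalone sketch of the cos/sin cache built above, assuming dim = 8,
# base = 10000, a short sequence and scaling_factor = 2.0; the division is done
# explicitly (t / timescale) and the angles are duplicated across both halves
# before taking cos and sin in float32.
import torch

dim, base, rope_size, scaling_factor = 8, 10000, 16, 2.0
freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype = torch.int64).float()
timescale = base ** freq_exponents
positions = torch.arange(rope_size, dtype = torch.int64).float() / scaling_factor
radians = positions[..., None] / timescale[None, :]
emb = torch.cat((radians, radians), dim = -1)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([16, 8])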
@ -395,10 +395,10 @@ class FastGemmaModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="gemma",
rope_module=GemmaFixedRotaryEmbedding,
scaled_rope_module=GemmaFixedLinearScalingRotaryEmbedding,
attention_module=GemmaAttention,
model_name = "gemma",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = GemmaAttention,
)
if init_name is not None:
exec(function, globals())
@ -430,7 +430,7 @@ class FastGemmaModel(FastLlamaModel):
def post_patch(model, tokenizer):
# Gemma does not downcast RoPE
model, tokenizer = patch_model_and_tokenizer(
model, tokenizer, downcast_rope=False
model, tokenizer, downcast_rope = False
)
# Add 1 to weight

View file

@ -47,13 +47,13 @@ BAD_MAPPINGS = {
def __get_model_name(
model_name,
load_in_4bit=True,
INT_TO_FLOAT_MAPPER=None,
FLOAT_TO_INT_MAPPER=None,
MAP_TO_UNSLOTH_16bit=None,
load_in_fp8=False,
FLOAT_TO_FP8_BLOCK_MAPPER=None,
FLOAT_TO_FP8_ROW_MAPPER=None,
load_in_4bit = True,
INT_TO_FLOAT_MAPPER = None,
FLOAT_TO_INT_MAPPER = None,
MAP_TO_UNSLOTH_16bit = None,
load_in_fp8 = False,
FLOAT_TO_FP8_BLOCK_MAPPER = None,
FLOAT_TO_FP8_ROW_MAPPER = None,
):
model_name = str(model_name)
lower_model_name = model_name.lower()
@ -116,7 +116,7 @@ def _get_new_mapper():
import requests
new_mapper = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py"
with requests.get(new_mapper, timeout=3) as new_mapper:
with requests.get(new_mapper, timeout = 3) as new_mapper:
new_mapper = new_mapper.text
new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER") :]
new_mapper = (
@ -135,17 +135,17 @@ def _get_new_mapper():
return {}, {}, {}
def get_model_name(model_name, load_in_4bit=True, load_in_fp8=False):
def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
assert load_in_fp8 in (True, False, "block")
new_model_name = __get_model_name(
model_name=model_name,
load_in_4bit=load_in_4bit,
INT_TO_FLOAT_MAPPER=INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER=FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit=MAP_TO_UNSLOTH_16bit,
load_in_fp8=load_in_fp8,
FLOAT_TO_FP8_BLOCK_MAPPER=FLOAT_TO_FP8_BLOCK_MAPPER,
FLOAT_TO_FP8_ROW_MAPPER=FLOAT_TO_FP8_ROW_MAPPER,
model_name = model_name,
load_in_4bit = load_in_4bit,
INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
load_in_fp8 = load_in_fp8,
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
)
# In the rare case, we convert bad model names to other names
# For example, too large dynamic quants or MoEs
@ -166,14 +166,14 @@ def get_model_name(model_name, load_in_4bit=True, load_in_fp8=False):
_get_new_mapper()
)
upgraded_model_name = __get_model_name(
model_name=model_name,
load_in_4bit=load_in_4bit,
INT_TO_FLOAT_MAPPER=NEW_INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER=NEW_FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit=NEW_MAP_TO_UNSLOTH_16bit,
load_in_fp8=load_in_fp8,
FLOAT_TO_FP8_BLOCK_MAPPER=FLOAT_TO_FP8_BLOCK_MAPPER,
FLOAT_TO_FP8_ROW_MAPPER=FLOAT_TO_FP8_ROW_MAPPER,
model_name = model_name,
load_in_4bit = load_in_4bit,
INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
load_in_fp8 = load_in_fp8,
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
)
if upgraded_model_name is not None:
raise NotImplementedError(
@ -207,8 +207,8 @@ def _get_torchao_fp8_config(fp8_mode: str):
raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'")
return Float8DynamicActivationFloat8WeightConfig(
granularity=granularity,
activation_value_lb=1e-12,
granularity = granularity,
activation_value_lb = 1e-12,
)
@ -255,12 +255,12 @@ def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
auto_processor = AutoProcessor if is_vlm else AutoTokenizer
model = auto_model.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
quantization_config=qconfig,
torch_dtype = "auto",
device_map = "auto",
quantization_config = qconfig,
)
tokenizer = auto_processor.from_pretrained(model_name)
model.save_pretrained(new_model_name, safe_serialization=False)
model.save_pretrained(new_model_name, safe_serialization = False)
del model
for _ in range(2):
torch.cuda.empty_cache()
@ -276,8 +276,8 @@ def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):
try:
base_config = _get_torchao_fp8_config(fp8_mode)
model.torchao_config = TorchAOConfig(
qat_scheme=None,
base_config_and_filter_fns=[(base_config, None)],
qat_scheme = None,
base_config_and_filter_fns = [(base_config, None)],
)
except:
pass

View file

@ -82,7 +82,7 @@ def MistralAttention_fast_forward(
kv_seq_len += past_key_value[0].shape[-2]
# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)
if position_ids is None:
@ -91,8 +91,8 @@ def MistralAttention_fast_forward(
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim=2)
V = torch.cat([past_key_value[1], V], dim=2)
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@ -128,7 +128,7 @@ def MistralAttention_fast_forward(
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
@ -138,7 +138,7 @@ def MistralAttention_fast_forward(
sw = getattr(self.config, "sliding_window", None)
sw = kv_seq_len if (sw is None or sw == "null") else sw
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
A = flash_attn_func(Q, K, V, causal=True, window_size=window)
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
else:
# Grouped query attention
# if n_groups != 1:
@ -153,7 +153,7 @@ def MistralAttention_fast_forward(
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(
Q, K, V, attn_mask=attention_mask, is_causal=False
Q, K, V, attn_mask = attention_mask, is_causal = False
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@ -199,7 +199,7 @@ def MistralForCausalLM_fast_forward(
else:
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask.from_seqlens(
[q_len] * bsz
).make_local_attention(window_size=sliding_window)
).make_local_attention(window_size = sliding_window)
# If attention_mask exists, it will be handled in the attention forward
@ -213,13 +213,13 @@ def MistralForCausalLM_fast_forward(
):
# Fully causal mask
causal_mask_values = torch.triu(
torch.full((q_len, q_len), -torch.inf, device=input_ids.device),
diagonal=1,
torch.full((q_len, q_len), -torch.inf, device = input_ids.device),
diagonal = 1,
)
else:
# Sliding window attention
q_indices = torch.arange(q_len, device=input_ids.device).view(-1, 1)
k_indices = torch.arange(q_len, device=input_ids.device).view(1, -1)
q_indices = torch.arange(q_len, device = input_ids.device).view(-1, 1)
k_indices = torch.arange(q_len, device = input_ids.device).view(1, -1)
causal_bool_mask = k_indices <= q_indices
window_bool_mask = (q_indices - k_indices) < sliding_window
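# A minimal sketch of the sliding-window causal mask built above, assuming a toy
# sequence length and window size; positions in the future or outside the window
# get -inf so the softmax ignores them.
import torch

q_len, sliding_window = 6, 3
q_indices = torch.arange(q_len).view(-1, 1)
k_indices = torch.arange(q_len).view(1, -1)
causal_bool_mask = k_indices <= q_indices                     # no attending to future positions
window_bool_mask = (q_indices - k_indices) < sliding_window   # only the last `sliding_window` keys
keep = causal_bool_mask & window_bool_mask
mask = torch.zeros(q_len, q_len).masked_fill(~keep, float("-inf"))
print(mask)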
@ -243,7 +243,7 @@ def MistralForCausalLM_fast_forward(
attention_mask = attention_mask + causal_mask_values[None, None, :, :]
attention_mask = attention_mask.to(
dtype=_get_dtype(dtype_from_config(self.config))
dtype = _get_dtype(dtype_from_config(self.config))
)
output_attentions = (
@ -268,21 +268,21 @@ def MistralForCausalLM_fast_forward(
self,
input_ids,
past_key_values,
position_ids=position_ids,
attention_mask=attention_mask,
position_ids = position_ids,
attention_mask = attention_mask,
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
input_ids = input_ids,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_values = past_key_values,
inputs_embeds = inputs_embeds,
use_cache = use_cache,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
)
hidden_states = outputs[0]
@ -302,11 +302,11 @@ def MistralForCausalLM_fast_forward(
if num_logits_to_keep != 0:
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
return CausalLMOutputWithPast(
loss=None,
logits=hidden_states,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
loss = None,
logits = hidden_states,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
if bsz == 1 and q_len == 1:
@ -337,28 +337,28 @@ def MistralForCausalLM_fast_forward(
# logit_softcapping = logit_softcapping,
# )
loss = unsloth_fused_ce_loss(
trainer=None,
hidden_states=hidden_states,
lm_head_weight=lm_head,
lm_head_bias=None,
labels=labels,
mask=None,
n_items=n_items,
scaling=getattr(self, "accelerator_scaler", None),
target_gb=None,
torch_compile=True,
logit_softcapping=logit_softcapping,
trainer = None,
hidden_states = hidden_states,
lm_head_weight = lm_head,
lm_head_bias = None,
labels = labels,
mask = None,
n_items = n_items,
scaling = getattr(self, "accelerator_scaler", None),
target_gb = None,
torch_compile = True,
logit_softcapping = logit_softcapping,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
output = CausalLMOutputWithPast(
loss=loss,
logits=EMPTY_LOGITS,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
loss = loss,
logits = EMPTY_LOGITS,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
return output
pass
@ -377,9 +377,9 @@ def MistralForCausalLM_fast_forward(
shift_labels[..., :-1] = labels[..., 1:]
shift_labels[..., -1] = -100
loss = fast_cross_entropy_loss(
logits=shift_logits,
labels=shift_labels,
n_items=kwargs.get("num_items_in_batch", None)
logits = shift_logits,
labels = shift_labels,
n_items = kwargs.get("num_items_in_batch", None)
or kwargs.get("n_items", None),
)
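# A small sketch of the label shifting above: the token at position t is trained
# to predict the token at t+1, and the final position gets -100 so the loss
# ignores it. The toy labels are illustrative.
import torch

labels = torch.tensor([[10, 11, 12, 13]])
shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]
shift_labels[..., -1] = -100
print(shift_labels)  # tensor([[ 11,  12,  13, -100]])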
@ -388,11 +388,11 @@ def MistralForCausalLM_fast_forward(
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
loss = loss,
logits = logits,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
@ -417,10 +417,10 @@ class FastMistralModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="mistral",
rope_module=LlamaRotaryEmbedding,
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
attention_module=MistralAttention,
model_name = "mistral",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = MistralAttention,
)
# Just for Mistral Nemo models!
if function is not None and init_name is not None:
@ -451,30 +451,30 @@ class FastMistralModel(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name="unsloth/mistral-7b-bnb-4bit",
max_seq_length=None,
dtype=None,
load_in_4bit=True,
token=None,
device_map="sequential",
rope_scaling=None, # Mistral does not support RoPE scaling
fix_tokenizer=True,
model_patcher=None,
tokenizer_name=None,
trust_remote_code=False,
model_name = "unsloth/mistral-7b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None, # Mistral does not support RoPE scaling
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
token=token,
device_map=device_map,
rope_scaling=rope_scaling,
fix_tokenizer=fix_tokenizer,
model_patcher=FastMistralModel,
tokenizer_name=tokenizer_name,
trust_remote_code=trust_remote_code,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastMistralModel,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)

View file

@ -59,9 +59,7 @@ def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = Non
self.gate_proj, X, out = temp_gate
) # pretty much the only change from transformers implementation.
routing_weights = torch_nn_functional_softmax(
router_logits, dim = -1, dtype = torch.float32
)
routing_weights = torch_nn_functional_softmax(router_logits, dim = -1, dtype = torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
# we cast back to the input dtype
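# A minimal sketch of the router step above: softmax over the expert logits,
# keep the top_k experts per token, then renormalise their weights so they sum
# to one. The sizes below are illustrative assumptions.
import torch

num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = torch.nn.functional.softmax(router_logits, dim = -1, dtype = torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim = -1)
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
print(selected_experts.shape, routing_weights.sum(dim = -1))  # each row sums to 1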

View file

@ -329,7 +329,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
try:
trainer = eval(f"trl.trainer.{trainer_file}")
except Exception as error:
print(f"Unsloth: Could not import trl.trainer.{trainer_file}: {error}")
return
# Get SFTTrainer and SFTConfig names
@ -348,14 +347,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
and trainer_file.split("_")[0] in x.lower()
]
if len(name) != 1:
print(
f"Unsloth: Could not find Trainer class in trl.trainer.{trainer_file}. Found: {name}"
)
return
if len(config) != 1:
print(
f"Unsloth: Could not find Config class in trl.trainer.{trainer_file}. Found: {config}"
)
return
# Get SFTTrainer, SFTConfig
@ -364,24 +357,16 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
try:
RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
except:
print(
f"Unsloth: Could not load {RLTrainer_name} from trl.trainer.{trainer_file}"
)
return
try:
RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}")
except:
print(
f"Unsloth: Could not load {RLConfig_name} from trl.trainer.{trainer_file}"
)
return
# Check name
if RLTrainer.__name__.startswith("Unsloth"):
print(f"Unsloth: {RLTrainer.__name__} is already patched.")
return
if RLConfig.__name__.startswith("Unsloth"):
print(f"Unsloth: {RLConfig.__name__} is already patched.")
return
# Get old source
@ -1306,11 +1291,7 @@ def patch_trl_rl_trainers():
import trl.trainer
all_trainers = dir(trl.trainer)
all_trainers = [
x
for x in all_trainers
if x.islower() and x.endswith("_trainer") and x != "base_trainer"
]
all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")]
for trainer in all_trainers:
_patch_trl_rl_trainers(trainer)
return

View file

@ -105,7 +105,7 @@ def sft_trainer_prepare_dataset(function_name, function):
matched = re.match(
r"[\s]{0,}def _prepare_dataset\(.*?" + params + r".*?\)",
function,
flags=re.MULTILINE | re.DOTALL,
flags = re.MULTILINE | re.DOTALL,
)
if matched:
# Use fast version!
@ -147,7 +147,7 @@ def sft_trainer_prepare_dataset(function_name, function):
replacer = re.findall(
r"def " + function_name + r"\(.*?\).*?\:\n",
function,
flags=re.MULTILINE | re.DOTALL,
flags = re.MULTILINE | re.DOTALL,
)
if len(replacer) != 0:
replacer = replacer[0]
@ -175,13 +175,13 @@ def sft_trainer_compute_loss(function_name, function):
return function
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
self, model, inputs, return_outputs = False, num_items_in_batch = None
):
outputs = super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
return_outputs = return_outputs,
num_items_in_batch = num_items_in_batch,
)
return outputs
@ -286,7 +286,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"
r"\2if self\.use_vllm:)",
function,
flags=re.DOTALL | re.MULTILINE,
flags = re.DOTALL | re.MULTILINE,
)
if len(found) != 0:
replace_part, spacing = found[0]
@ -441,7 +441,7 @@ def grpo_trainer__get_per_token_logps(function_name, function):
return function
def _get_per_token_logps(
self, model, input_ids, attention_mask, logits_to_keep, compute_efficient=False
self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False
):
if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
return None # Unsloth efficient GRPO
@ -456,12 +456,12 @@ def grpo_trainer__get_per_token_logps(function_name, function):
self._autocast_dtype = torch.float16
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
with torch.amp.autocast(device_type=DEVICE_TYPE, dtype=self._autocast_dtype):
with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep + 1,
input_ids = input_ids,
attention_mask = attention_mask,
logits_to_keep = logits_to_keep + 1,
).logits
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return logits
@ -500,9 +500,9 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
input_ids,
attention_mask,
logits_to_keep,
batch_size=None,
compute_entropy=False,
compute_efficient=False,
batch_size = None,
compute_entropy = False,
compute_efficient = False,
*args,
**kwargs,
):
@ -533,33 +533,33 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
unwrapped_model = self.accelerator.unwrap_model(
model, keep_fp32_wrapper=False
model, keep_fp32_wrapper = False
)
with torch.amp.autocast(device_type="cuda", dtype=self._autocast_dtype):
with torch.amp.autocast(device_type = "cuda", dtype = self._autocast_dtype):
with _get_inference_mode_context_manager(model):
if pixel_values is None:
attention_mask = input_ids != self.processing_class.pad_token_id
attention_mask = attention_mask.to(attention_mask.dtype)
# We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = unwrapped_model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
input_ids = input_ids,
attention_mask = attention_mask,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
pixel_attention_mask = pixel_attention_mask,
image_sizes = image_sizes,
# logits_to_keep = logits_to_keep + 1,
).logits
else:
logits = unwrapped_model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
logits_to_keep=logits_to_keep + 1,
input_ids = input_ids,
attention_mask = attention_mask,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
pixel_attention_mask = pixel_attention_mask,
image_sizes = image_sizes,
logits_to_keep = logits_to_keep + 1,
).logits
entropies = None
@ -615,7 +615,7 @@ def grpo_trainer_compute_loss(function_name, function):
return function
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
self, model, inputs, return_outputs = False, num_items_in_batch = None
):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
@ -639,9 +639,9 @@ def grpo_trainer_compute_loss(function_name, function):
current_gradient_accumulation_steps = self.current_gradient_accumulation_steps
num_processes = self.accelerator.num_processes
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
input_ids = torch.cat([prompt_ids, completion_ids], dim = 1)
bsz, qlen = input_ids.shape
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1)
# attention_mask = None
logits_to_keep = completion_ids.size(
1
@ -654,9 +654,9 @@ def grpo_trainer_compute_loss(function_name, function):
input_ids,
attention_mask,
logits_to_keep,
batch_size=None,
compute_entropy=False,
compute_efficient=False: self._get_per_token_logps(
batch_size = None,
compute_entropy = False,
compute_efficient = False: self._get_per_token_logps(
model, input_ids, attention_mask, logits_to_keep, compute_efficient
)
if hasattr(self, "_get_per_token_logps")
@ -672,7 +672,7 @@ def grpo_trainer_compute_loss(function_name, function):
) # logps
per_token_logps = get_logps_func(
model, input_ids, attention_mask, logits_to_keep, compute_efficient=True
model, input_ids, attention_mask, logits_to_keep, compute_efficient = True
)
# Compute the KL divergence between the model and the reference model
# _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
@ -726,71 +726,71 @@ def grpo_trainer_compute_loss(function_name, function):
completion_mask,
self.beta,
advantages,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
loss_type=self.args.loss_type,
importance_sampling_level=self.importance_sampling_level,
epsilon_low=self.epsilon_low,
epsilon_high=self.epsilon_high,
max_completion_length=self.args.max_completion_length,
delta=self.args.delta,
temperature=self.args.temperature,
logit_softcapping=logit_softcapping,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
num_items_in_batch=num_items_in_batch,
current_gradient_accumulation_steps=current_gradient_accumulation_steps,
num_processes=num_processes,
sampling_per_token_logps=sampling_per_token_logps,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
loss_type = self.args.loss_type,
importance_sampling_level = self.importance_sampling_level,
epsilon_low = self.epsilon_low,
epsilon_high = self.epsilon_high,
max_completion_length = self.args.max_completion_length,
delta = self.args.delta,
temperature = self.args.temperature,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
num_items_in_batch = num_items_in_batch,
current_gradient_accumulation_steps = current_gradient_accumulation_steps,
num_processes = num_processes,
sampling_per_token_logps = sampling_per_token_logps,
)
)
else:
if hasattr(self.args, "loss_type"):
loss, completion_length, mean_kl, delta, flat_is_ratio = (
grpo_accumulated_loss(
trainer=self,
input_ids=_input_ids,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
logits_to_keep=logits_to_keep,
completion_mask=completion_mask,
advantages=advantages,
old_hidden_states=old_hidden_states,
ref_hidden_states=ref_hidden_states,
n_chunks=self.args.unsloth_num_chunks,
loss_type=self.args.loss_type,
importance_sampling_level=self.importance_sampling_level,
epsilon_low=self.epsilon_low,
epsilon_high=self.epsilon_high,
max_completion_length=self.args.max_completion_length,
delta=self.args.delta,
temperature=self.args.temperature,
logit_softcapping=logit_softcapping,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
attention_mask=attention_mask,
num_items_in_batch=num_items_in_batch,
current_gradient_accumulation_steps=current_gradient_accumulation_steps,
num_processes=num_processes,
sampling_per_token_logps=sampling_per_token_logps,
trainer = self,
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
logits_to_keep = logits_to_keep,
completion_mask = completion_mask,
advantages = advantages,
old_hidden_states = old_hidden_states,
ref_hidden_states = ref_hidden_states,
n_chunks = self.args.unsloth_num_chunks,
loss_type = self.args.loss_type,
importance_sampling_level = self.importance_sampling_level,
epsilon_low = self.epsilon_low,
epsilon_high = self.epsilon_high,
max_completion_length = self.args.max_completion_length,
delta = self.args.delta,
temperature = self.args.temperature,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
attention_mask = attention_mask,
num_items_in_batch = num_items_in_batch,
current_gradient_accumulation_steps = current_gradient_accumulation_steps,
num_processes = num_processes,
sampling_per_token_logps = sampling_per_token_logps,
)
)
else:
# to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
loss, completion_length, mean_kl = grpo_accumulated_loss(
trainer=self,
input_ids=_input_ids,
logits_to_keep=logits_to_keep,
completion_mask=completion_mask,
advantages=advantages,
old_hidden_states=old_hidden_states,
ref_hidden_states=ref_hidden_states,
n_chunks=self.args.unsloth_num_chunks,
temperature=self.args.temperature,
logit_softcapping=logit_softcapping,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
attention_mask=attention_mask,
trainer = self,
input_ids = _input_ids,
logits_to_keep = logits_to_keep,
completion_mask = completion_mask,
advantages = advantages,
old_hidden_states = old_hidden_states,
ref_hidden_states = ref_hidden_states,
n_chunks = self.args.unsloth_num_chunks,
temperature = self.args.temperature,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
attention_mask = attention_mask,
)
if "train" in self._metrics:
@ -805,12 +805,12 @@ def grpo_trainer_compute_loss(function_name, function):
mean_delta = (
torch.mean(delta)
if delta.numel() > 0
else torch.tensor(0.0, device=self.model.device)
else torch.tensor(0.0, device = self.model.device)
)
max_delta = (
torch.max(delta)
if delta.numel() > 0
else torch.tensor(0.0, device=self.model.device)
else torch.tensor(0.0, device = self.model.device)
)
self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
self.accelerator.gather(mean_delta).mean().item()
@ -822,17 +822,17 @@ def grpo_trainer_compute_loss(function_name, function):
min_importance_sampling_ratio = (
torch.min(flat_is_ratio)
if flat_is_ratio.numel() > 0
else torch.tensor(0.0, device=self.model.device)
else torch.tensor(0.0, device = self.model.device)
)
mean_importance_sampling_ratio = (
torch.mean(flat_is_ratio)
if flat_is_ratio.numel() > 0
else torch.tensor(0.0, device=self.model.device)
else torch.tensor(0.0, device = self.model.device)
)
max_importance_sampling_ratio = (
torch.max(flat_is_ratio)
if flat_is_ratio.numel() > 0
else torch.tensor(0.0, device=self.model.device)
else torch.tensor(0.0, device = self.model.device)
)
self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()

View file

@ -115,9 +115,9 @@ HAS_TORCH_DTYPE = "torch_dtype" in PretrainedConfig.__doc__
from transformers import GenerationConfig, CompileConfig, HybridCache
_compile_config = CompileConfig(
fullgraph=False,
dynamic=None,
mode="reduce-overhead",
fullgraph = False,
dynamic = None,
mode = "reduce-overhead",
)
_compile_config.disable = True # Must set manually
@ -219,10 +219,10 @@ def unsloth_base_fast_generate(
# Mixed precision autocast
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
autocaster = torch.autocast(device_type=DEVICE_TYPE_TORCH, dtype=torch.float16)
autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = torch.float16)
dtype = torch.float16
else:
autocaster = torch.autocast(device_type=DEVICE_TYPE_TORCH, dtype=dtype)
autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype)
# Prepare LoRA
# state_dict = convert_lora_modules(self, dtype = dtype)
@@ -316,34 +316,34 @@ def unsloth_base_fast_generate(
class FastBaseModel:
@staticmethod
def from_pretrained(
model_name="unsloth/Llama-3.2-1B-Instruct",
max_seq_length=2048,
dtype=None,
load_in_4bit=True,
load_in_8bit=False,
load_in_16bit=False,
full_finetuning=False,
token=None,
device_map="sequential",
trust_remote_code=False,
model_types=None,
tokenizer_name=None,
auto_model=AutoModelForVision2Seq,
use_gradient_checkpointing="unsloth",
supports_sdpa=True,
whisper_language=None,
whisper_task=None,
auto_config=None,
offload_embedding=False,
float32_mixed_precision=None, # Forces float32 mixed precision
model_name = "unsloth/Llama-3.2-1B-Instruct",
max_seq_length = 2048,
dtype = None,
load_in_4bit = True,
load_in_8bit = False,
load_in_16bit = False,
full_finetuning = False,
token = None,
device_map = "sequential",
trust_remote_code = False,
model_types = None,
tokenizer_name = None,
auto_model = AutoModelForVision2Seq,
use_gradient_checkpointing = "unsloth",
supports_sdpa = True,
whisper_language = None,
whisper_task = None,
auto_config = None,
offload_embedding = False,
float32_mixed_precision = None, # Forces float32 mixed precision
# vLLM parameters
fast_inference=False,
gpu_memory_utilization=0.5,
float8_kv_cache=False,
random_state=3407,
max_lora_rank=64,
disable_log_stats=False,
unsloth_vllm_standby=False,
fast_inference = False,
gpu_memory_utilization = 0.5,
float8_kv_cache = False,
random_state = 3407,
max_lora_rank = 64,
disable_log_stats = False,
unsloth_vllm_standby = False,
**kwargs,
):
if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") != "1":
@@ -539,16 +539,16 @@ class FastBaseModel:
)
if load_in_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=bnb_compute_dtype,
llm_int8_skip_modules=SKIP_QUANTIZATION_MODULES.copy(),
load_in_4bit = True,
bnb_4bit_use_double_quant = True,
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = bnb_compute_dtype,
llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
)
elif load_in_8bit:
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_skip_modules=SKIP_QUANTIZATION_MODULES.copy(),
load_in_8bit = True,
llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
)
elif load_in_16bit:
bnb_config = None
@@ -597,8 +597,8 @@ class FastBaseModel:
if auto_config is None:
auto_config = AutoConfig.from_pretrained(
model_name,
token=token,
trust_remote_code=trust_remote_code,
token = token,
trust_remote_code = trust_remote_code,
)
if hasattr(auto_config, "quantization_config"):
from transformers.quantizers.auto import (
@@ -648,9 +648,9 @@ class FastBaseModel:
model_config = AutoConfig.from_pretrained(
model_name,
token=token,
attn_implementation="sdpa" if supports_sdpa else "eager",
trust_remote_code=trust_remote_code,
token = token,
attn_implementation = "sdpa" if supports_sdpa else "eager",
trust_remote_code = trust_remote_code,
)
verify_fp8_support_if_applicable(model_config)
@@ -660,11 +660,11 @@ class FastBaseModel:
load_in_fp8 = kwargs.pop("load_in_fp8", None)
model = auto_model.from_pretrained(
model_name,
device_map=device_map,
device_map = device_map,
# torch_dtype = torch_dtype, # Transformers removed torch_dtype
# quantization_config = bnb_config,
token=token,
trust_remote_code=trust_remote_code,
token = token,
trust_remote_code = trust_remote_code,
# attn_implementation = attn_implementation,
**kwargs,
)
@@ -715,18 +715,18 @@ class FastBaseModel:
allowed_args = inspect.getfullargspec(load_vllm).args
load_vllm_kwargs = dict(
model_name=model_name,
config=model_config,
gpu_memory_utilization=gpu_memory_utilization,
max_seq_length=max_seq_length,
dtype=dtype,
float8_kv_cache=float8_kv_cache,
enable_lora=vllm_enable_lora,
max_lora_rank=max_lora_rank,
disable_log_stats=disable_log_stats,
use_bitsandbytes=load_in_4bit,
unsloth_vllm_standby=unsloth_vllm_standby,
is_vision_model=is_vlm,
model_name = model_name,
config = model_config,
gpu_memory_utilization = gpu_memory_utilization,
max_seq_length = max_seq_length,
dtype = dtype,
float8_kv_cache = float8_kv_cache,
enable_lora = vllm_enable_lora,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
use_bitsandbytes = load_in_4bit,
unsloth_vllm_standby = unsloth_vllm_standby,
is_vision_model = is_vlm,
)
for allowed_arg in allowed_args:
if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
@@ -738,15 +738,15 @@ class FastBaseModel:
# Convert to HF format
_, quant_state_dict = get_vllm_state_dict(
llm,
config=model_config,
is_vision_model=is_vlm,
config = model_config,
is_vision_model = is_vlm,
)
model = convert_vllm_to_huggingface(
quant_state_dict,
model_config,
dtype,
bnb_config,
is_vision_model=is_vlm,
is_vision_model = is_vlm,
)
model.vllm_engine = llm
model.fast_generate = model.vllm_engine.generate
@@ -788,26 +788,26 @@ class FastBaseModel:
):
tokenizer = auto_processor.from_pretrained(
tokenizer_name,
padding_side="left",
token=token,
language=whisper_language,
task=whisper_task,
trust_remote_code=trust_remote_code,
padding_side = "left",
token = token,
language = whisper_language,
task = whisper_task,
trust_remote_code = trust_remote_code,
)
else:
try:
tokenizer = auto_processor.from_pretrained(
tokenizer_name,
padding_side="left",
token=token,
trust_remote_code=trust_remote_code,
padding_side = "left",
token = token,
trust_remote_code = trust_remote_code,
)
except:
tokenizer = get_auto_processor(
tokenizer_name,
padding_side="left",
token=token,
trust_remote_code=trust_remote_code,
padding_side = "left",
token = token,
trust_remote_code = trust_remote_code,
)
if hasattr(tokenizer, "tokenizer"):
__tokenizer = tokenizer.tokenizer
@@ -827,10 +827,10 @@ class FastBaseModel:
model, tokenizer = patch_model_and_tokenizer(
model,
tokenizer,
downcast_rope=False,
fix_embeddings=False,
do_forced_float32=do_forced_float32,
correct_dtype=correct_dtype,
downcast_rope = False,
fix_embeddings = False,
do_forced_float32 = do_forced_float32,
correct_dtype = correct_dtype,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
model = post_patch_loss_function(model)
@@ -838,13 +838,13 @@ class FastBaseModel:
# Log Unsloth version for future fastpaths for inference
if hasattr(model, "config"):
model.config.update({"unsloth_version": __version__})
patch_saving_functions(model, vision=True)
patch_saving_functions(model, vision = True)
if tokenizer is None:
del model
raise RuntimeError(
"Unsloth: The tokenizer is weirdly not loaded? Please check if there is one."
)
patch_saving_functions(tokenizer, vision=True)
patch_saving_functions(tokenizer, vision = True)
# Fix gradient accumulation
from transformers.trainer import Trainer
@@ -882,11 +882,11 @@ class FastBaseModel:
# Post patches
model = FastBaseModel.post_patch_model(
model,
use_gradient_checkpointing=use_gradient_checkpointing,
trust_remote_code=trust_remote_code,
model_type=model_type_arch,
tokenizer=tokenizer,
float32_mixed_precision=float32_mixed_precision,
use_gradient_checkpointing = use_gradient_checkpointing,
trust_remote_code = trust_remote_code,
model_type = model_type_arch,
tokenizer = tokenizer,
float32_mixed_precision = float32_mixed_precision,
)
# Clear deleted GPU items
for _ in range(3):
@@ -900,27 +900,27 @@ class FastBaseModel:
@staticmethod
def get_peft_model(
model,
r=16,
target_modules=None,
lora_alpha=16,
lora_dropout=0.0,
bias="none",
finetune_vision_layers=True,
finetune_language_layers=True,
finetune_attention_modules=True,
finetune_mlp_modules=True,
layers_to_transform=None,
layers_pattern=None,
use_gradient_checkpointing="unsloth",
random_state=3407,
max_seq_length=2048, # not used anymore
use_rslora=False,
modules_to_save=None,
init_lora_weights=True,
loftq_config={},
task_type=TaskType.CAUSAL_LM,
temporary_location="_unsloth_temporary_saved_buffers",
qat_scheme=None,
r = 16,
target_modules = None,
lora_alpha = 16,
lora_dropout = 0.0,
bias = "none",
finetune_vision_layers = True,
finetune_language_layers = True,
finetune_attention_modules = True,
finetune_mlp_modules = True,
layers_to_transform = None,
layers_pattern = None,
use_gradient_checkpointing = "unsloth",
random_state = 3407,
max_seq_length = 2048, # not used anymore
use_rslora = False,
modules_to_save = None,
init_lora_weights = True,
loftq_config = {},
task_type = TaskType.CAUSAL_LM,
temporary_location = "_unsloth_temporary_saved_buffers",
qat_scheme = None,
**kwargs,
):
if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
@@ -948,10 +948,10 @@ class FastBaseModel:
if target_modules is None or target_modules == "all-linear":
target_modules = get_peft_regex(
model,
finetune_vision_layers=finetune_vision_layers,
finetune_language_layers=finetune_language_layers,
finetune_attention_modules=finetune_attention_modules,
finetune_mlp_modules=finetune_mlp_modules,
finetune_vision_layers = finetune_vision_layers,
finetune_language_layers = finetune_language_layers,
finetune_attention_modules = finetune_attention_modules,
finetune_mlp_modules = finetune_mlp_modules,
)
else:
assert type(target_modules) in (
@@ -1013,7 +1013,7 @@ class FastBaseModel:
)
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing = use_gradient_checkpointing,
)
model = _get_peft_model(model, lora_config)
# Apply QAT + LoRA if specified
@@ -1027,8 +1027,8 @@ class FastBaseModel:
trust_remote_code = getattr(model, "_unsloth_trust_remote_code", False)
model = FastBaseModel.post_patch_model(
model,
use_gradient_checkpointing=use_gradient_checkpointing,
trust_remote_code=trust_remote_code,
use_gradient_checkpointing = use_gradient_checkpointing,
trust_remote_code = trust_remote_code,
)
model.max_seq_length = max_seq_length
# Save to modules as well
@@ -1041,7 +1041,7 @@ class FastBaseModel:
torch.cuda.empty_cache()
elif DEVICE_TYPE == "xpu":
torch.xpu.empty_cache()
patch_saving_functions(model, vision=True)
patch_saving_functions(model, vision = True)
patch_peft_fast_inference(model)
# Add for_inference and for_training
@@ -1057,11 +1057,11 @@ class FastBaseModel:
@staticmethod
def post_patch_model(
model,
use_gradient_checkpointing=True,
trust_remote_code=False,
model_type=None,
tokenizer=None,
float32_mixed_precision=None,
use_gradient_checkpointing = True,
trust_remote_code = False,
model_type = None,
tokenizer = None,
float32_mixed_precision = None,
):
full_finetuning = os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1"
@@ -1079,14 +1079,14 @@ class FastBaseModel:
model = prepare_model_for_training(
model,
use_gradient_checkpointing=use_gradient_checkpointing,
use_reentrant=True,
full_finetuning=full_finetuning,
train_layernorms=full_finetuning,
train_embedding=full_finetuning,
train_lm_head=full_finetuning,
float32_mixed_precision=float32_mixed_precision,
patch_modules_to_save=True,
use_gradient_checkpointing = use_gradient_checkpointing,
use_reentrant = True,
full_finetuning = full_finetuning,
train_layernorms = full_finetuning,
train_embedding = full_finetuning,
train_lm_head = full_finetuning,
float32_mixed_precision = float32_mixed_precision,
patch_modules_to_save = True,
)
from transformers.trainer import Trainer
@@ -1096,7 +1096,7 @@ class FastBaseModel:
and trust_remote_code == False
):
raise RuntimeError("Unsloth: Unsuccessfully patched inner_training_loop")
patch_saving_functions(model, vision=True)
patch_saving_functions(model, vision = True)
# Patch tokenizer to pad to the left
m = model
@@ -1194,11 +1194,11 @@ class FastBaseModel:
os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
# Turn off skip guards and set stance to default
if torch_compiler_set_stance is not None:
torch_compiler_set_stance(stance="default", skip_guard_eval_unsafe=False)
torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
return model
@staticmethod
def for_training(model, use_gradient_checkpointing=True):
def for_training(model, use_gradient_checkpointing = True):
if not hasattr(model, "parameters"):
raise TypeError(
"Unsloth: I think you're passing a tokenizer, not the model to for_training!"
@@ -1250,5 +1250,5 @@ class FastBaseModel:
os.environ["UNSLOTH_RETURN_LOGITS"] = "0"
# Turn off skip guards and set stance to default
if torch_compiler_set_stance is not None:
torch_compiler_set_stance(stance="default", skip_guard_eval_unsafe=False)
torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
return model
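Editor's note: as a usage-level illustration of the signatures in this file, a minimal sketch of loading with the vLLM-related options and then attaching LoRA adapters. It assumes FastLanguageModel forwards these keyword arguments to FastBaseModel.from_pretrained and get_peft_model; all names and defaults are taken from the signatures above.

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-1B-Instruct",
    max_seq_length = 2048,
    load_in_4bit = True,
    fast_inference = True,          # start an in-process vLLM engine
    gpu_memory_utilization = 0.5,   # fraction of VRAM handed to vLLM
    max_lora_rank = 64,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    lora_alpha = 16,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)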

View file

@@ -25,50 +25,50 @@ class LlamaVisionModelInfo(ModelInfo):
# Llama 3.1
LlamaMeta_3_1 = ModelMeta(
org="meta-llama",
base_name="Llama",
instruct_tags=[None, "Instruct"],
model_version="3.1",
model_sizes=["8"],
model_info_cls=LlamaModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
org = "meta-llama",
base_name = "Llama",
instruct_tags = [None, "Instruct"],
model_version = "3.1",
model_sizes = ["8"],
model_info_cls = LlamaModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Llama 3.2 Base Models
LlamaMeta_3_2_Base = ModelMeta(
org="meta-llama",
base_name="Llama",
instruct_tags=[None],
model_version="3.2",
model_sizes=["1", "3"],
model_info_cls=LlamaModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
org = "meta-llama",
base_name = "Llama",
instruct_tags = [None],
model_version = "3.2",
model_sizes = ["1", "3"],
model_info_cls = LlamaModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Llama 3.2 Instruction Tuned Models
LlamaMeta_3_2_Instruct = ModelMeta(
org="meta-llama",
base_name="Llama",
instruct_tags=["Instruct"],
model_version="3.2",
model_sizes=["1", "3"],
model_info_cls=LlamaModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
org = "meta-llama",
base_name = "Llama",
instruct_tags = ["Instruct"],
model_version = "3.2",
model_sizes = ["1", "3"],
model_info_cls = LlamaModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
)
# Llama 3.2 Vision
LlamaMeta_3_2_Vision = ModelMeta(
org="meta-llama",
base_name="Llama",
instruct_tags=[None, "Instruct"],
model_version="3.2",
model_sizes=["11", "90"],
model_info_cls=LlamaVisionModelInfo,
is_multimodal=True,
quant_types={
org = "meta-llama",
base_name = "Llama",
instruct_tags = [None, "Instruct"],
model_version = "3.2",
model_sizes = ["11", "90"],
model_info_cls = LlamaVisionModelInfo,
is_multimodal = True,
quant_types = {
"11": [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
"90": [QuantType.NONE],
},
@@ -79,7 +79,7 @@ def register_llama_3_1_models(include_original_model: bool = False):
global _IS_LLAMA_3_1_REGISTERED
if _IS_LLAMA_3_1_REGISTERED:
return
_register_models(LlamaMeta_3_1, include_original_model=include_original_model)
_register_models(LlamaMeta_3_1, include_original_model = include_original_model)
_IS_LLAMA_3_1_REGISTERED = True
@@ -87,9 +87,9 @@ def register_llama_3_2_models(include_original_model: bool = False):
global _IS_LLAMA_3_2_REGISTERED
if _IS_LLAMA_3_2_REGISTERED:
return
_register_models(LlamaMeta_3_2_Base, include_original_model=include_original_model)
_register_models(LlamaMeta_3_2_Base, include_original_model = include_original_model)
_register_models(
LlamaMeta_3_2_Instruct, include_original_model=include_original_model
LlamaMeta_3_2_Instruct, include_original_model = include_original_model
)
_IS_LLAMA_3_2_REGISTERED = True
@@ -99,15 +99,15 @@ def register_llama_3_2_vision_models(include_original_model: bool = False):
if _IS_LLAMA_3_2_VISION_REGISTERED:
return
_register_models(
LlamaMeta_3_2_Vision, include_original_model=include_original_model
LlamaMeta_3_2_Vision, include_original_model = include_original_model
)
_IS_LLAMA_3_2_VISION_REGISTERED = True
def register_llama_models(include_original_model: bool = False):
register_llama_3_1_models(include_original_model=include_original_model)
register_llama_3_2_models(include_original_model=include_original_model)
register_llama_3_2_vision_models(include_original_model=include_original_model)
register_llama_3_1_models(include_original_model = include_original_model)
register_llama_3_2_models(include_original_model = include_original_model)
register_llama_3_2_vision_models(include_original_model = include_original_model)
if __name__ == "__main__":
@@ -115,7 +115,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
register_llama_models(include_original_model=True)
register_llama_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)
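Editor's note: a small hedged sketch of exercising the registration helpers above, mirroring the __main__ block in this file; the import paths are assumptions and may differ in the repository.

from unsloth.registry.registry import MODEL_REGISTRY   # assumed path
from unsloth.registry._llama import register_llama_models  # assumed path

# Register every Llama 3.1 / 3.2 / 3.2-Vision variant, then list the ids.
register_llama_models(include_original_model = True)
for model_id in MODEL_REGISTRY:
    print(model_id)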

View file

@@ -15,26 +15,26 @@ class PhiModelInfo(ModelInfo):
# Phi Model Meta
PhiMeta4 = ModelMeta(
org="microsoft",
base_name="phi",
instruct_tags=[None],
model_version="4",
model_sizes=["1"], # Assuming only one size
model_info_cls=PhiModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
org = "microsoft",
base_name = "phi",
instruct_tags = [None],
model_version = "4",
model_sizes = ["1"], # Assuming only one size
model_info_cls = PhiModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Phi Instruct Model Meta
PhiInstructMeta4 = ModelMeta(
org="microsoft",
base_name="phi",
instruct_tags=["mini-instruct"],
model_version="4",
model_sizes=["1"], # Assuming only one size
model_info_cls=PhiModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
org = "microsoft",
base_name = "phi",
instruct_tags = ["mini-instruct"],
model_version = "4",
model_sizes = ["1"], # Assuming only one size
model_info_cls = PhiModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
)
@@ -42,7 +42,7 @@ def register_phi_4_models(include_original_model: bool = False):
global _IS_PHI_4_REGISTERED
if _IS_PHI_4_REGISTERED:
return
_register_models(PhiMeta4, include_original_model=include_original_model)
_register_models(PhiMeta4, include_original_model = include_original_model)
_IS_PHI_4_REGISTERED = True
@@ -50,13 +50,13 @@ def register_phi_4_instruct_models(include_original_model: bool = False):
global _IS_PHI_4_INSTRUCT_REGISTERED
if _IS_PHI_4_INSTRUCT_REGISTERED:
return
_register_models(PhiInstructMeta4, include_original_model=include_original_model)
_register_models(PhiInstructMeta4, include_original_model = include_original_model)
_IS_PHI_4_INSTRUCT_REGISTERED = True
def register_phi_models(include_original_model: bool = False):
register_phi_4_models(include_original_model=include_original_model)
register_phi_4_instruct_models(include_original_model=include_original_model)
register_phi_4_models(include_original_model = include_original_model)
register_phi_4_instruct_models(include_original_model = include_original_model)
if __name__ == "__main__":
@@ -64,7 +64,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
register_phi_models(include_original_model=True)
register_phi_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)

File diff suppressed because it is too large.

View file

@@ -71,7 +71,7 @@ KAGGLE_TMP = "/tmp"
del keynames
def try_fix_tokenizer(tokenizer, prepend=True):
def try_fix_tokenizer(tokenizer, prepend = True):
if hasattr(tokenizer, "_tokenizer"):
converted_tokenizer = tokenizer._tokenizer
else:
@@ -130,7 +130,7 @@ def get_sorted_dict(dictionary):
def convert_to_fast_tokenizer(
slow_tokenizer,
temporary_location="_unsloth_sentencepiece_temp",
temporary_location = "_unsloth_sentencepiece_temp",
):
is_fast = getattr(slow_tokenizer, "is_fast", False)
if is_fast:
@@ -152,20 +152,20 @@ def convert_to_fast_tokenizer(
# Get all arguments (bos_token, etc)
docs = FastTokenizer.__doc__
docs = docs[docs.find("Args:") :]
args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags=re.MULTILINE)
args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args = [x for x in args if not x.endswith("_file")]
# Also some missing maybe!
docs = PreTrainedTokenizerFast.__doc__
docs = docs[docs.find("Args:") :]
args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags=re.MULTILINE)
args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args2 = [x for x in args2 if not x.endswith("_file")]
args = list(set(args + args2))
kwargs = {}
for arg in args:
kwargs[arg] = getattr(slow_tokenizer, arg, None)
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend=True)
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = True)
fast_tokenizer = FastTokenizer(**kwargs)
# Check if they're similar!
@@ -184,7 +184,7 @@ def convert_to_fast_tokenizer(
# Now confirm if they match
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Maybe remove prepending of __apple?
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend=False)
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = False)
fast_tokenizer = FastTokenizer(**kwargs)
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Failure :(
@@ -347,7 +347,7 @@ def fix_sentencepiece_tokenizer(
old_tokenizer,
new_tokenizer,
token_mapping,
temporary_location="_unsloth_sentencepiece_temp",
temporary_location = "_unsloth_sentencepiece_temp",
):
# From https://github.com/google/sentencepiece/issues/121
# We need to manually edit the sentencepiece tokenizer!
@@ -390,7 +390,7 @@ def fix_sentencepiece_tokenizer(
# Now correct the old tokenizer's .model file
for old_token, new_token in token_mapping.items():
ids = old_tokenizer([old_token], add_special_tokens=False).input_ids
ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
ids = ids[0]
if len(ids) != 1:
# Skip this token!
@@ -416,8 +416,8 @@ def fix_sentencepiece_tokenizer(
tokenizer = AutoTokenizer.from_pretrained(
temporary_location,
eos_token=new_tokenizer.eos_token,
pad_token=new_tokenizer.pad_token,
eos_token = new_tokenizer.eos_token,
pad_token = new_tokenizer.pad_token,
)
return tokenizer
@@ -453,13 +453,13 @@ def fix_sentencepiece_gguf(saved_location):
# Load added_tokens_json
if not os.path.isfile(f"{saved_location}/added_tokens.json"):
return
with open(f"{saved_location}/added_tokens.json", "r", encoding="utf-8") as file:
with open(f"{saved_location}/added_tokens.json", "r", encoding = "utf-8") as file:
added_tokens_json = json.load(file)
if len(added_tokens_json) == 0:
return
added_tokens_json = dict(
sorted(added_tokens_json.items(), key=lambda item: item[1])
sorted(added_tokens_json.items(), key = lambda item: item[1])
)
new_size = sentence_piece_size + len(added_tokens_json)
@@ -496,12 +496,12 @@ def fix_sentencepiece_gguf(saved_location):
def _load_correct_tokenizer(
tokenizer_name,
model_max_length=None,
padding_side="right",
token=None,
trust_remote_code=False,
cache_dir="huggingface_tokenizers_cache",
fix_tokenizer=True,
model_max_length = None,
padding_side = "right",
token = None,
trust_remote_code = False,
cache_dir = "huggingface_tokenizers_cache",
fix_tokenizer = True,
):
if IS_COLAB_ENVIRONMENT:
cache_dir = cache_dir
@@ -518,15 +518,15 @@ def _load_correct_tokenizer(
try:
slow_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
model_max_length=model_max_length,
padding_side=padding_side,
token=token,
trust_remote_code=trust_remote_code,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
# Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
use_fast=False,
legacy=False,
from_slow=True,
cache_dir=cache_dir,
use_fast = False,
legacy = False,
from_slow = True,
cache_dir = cache_dir,
)
except:
slow_tokenizer = None
@@ -540,11 +540,11 @@ def _load_correct_tokenizer(
fast_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
model_max_length=model_max_length,
padding_side=padding_side,
token=token,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
cache_dir = cache_dir,
)
if not fix_tokenizer or tokenizer_name in IGNORED_TOKENIZER_NAMES:
@@ -580,21 +580,21 @@ def _load_correct_tokenizer(
def load_correct_tokenizer(
tokenizer_name,
model_max_length=None,
padding_side="right",
token=None,
trust_remote_code=False,
cache_dir="huggingface_tokenizers_cache",
fix_tokenizer=True,
model_max_length = None,
padding_side = "right",
token = None,
trust_remote_code = False,
cache_dir = "huggingface_tokenizers_cache",
fix_tokenizer = True,
):
tokenizer = _load_correct_tokenizer(
tokenizer_name=tokenizer_name,
model_max_length=model_max_length,
padding_side=padding_side,
token=token,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir,
fix_tokenizer=fix_tokenizer,
tokenizer_name = tokenizer_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
cache_dir = cache_dir,
fix_tokenizer = fix_tokenizer,
)
### 1. Fixup tokenizer's chat_template
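Editor's note: a hedged call sketch for load_correct_tokenizer, using the keyword names from the signature above; the import path and the concrete argument values are illustrative assumptions.

from unsloth.tokenizer_utils import load_correct_tokenizer  # assumed path

tokenizer = load_correct_tokenizer(
    tokenizer_name = "unsloth/llama-2-7b-bnb-4bit",
    model_max_length = 4096,
    padding_side = "right",
    fix_tokenizer = True,
)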
@@ -683,7 +683,7 @@ def fix_chat_template(tokenizer):
{"role": "user", "content": "Who are you?"},
]
tokenizer.apply_chat_template(
messages, add_generation_prompt=False, tokenize=False
messages, add_generation_prompt = False, tokenize = False
)
is_sharegpt = False
except:
@@ -692,7 +692,7 @@ def fix_chat_template(tokenizer):
{"from": "human", "value": "Who are you?"},
]
tokenizer.apply_chat_template(
messages, add_generation_prompt=False, tokenize=False
messages, add_generation_prompt = False, tokenize = False
)
is_sharegpt = True
except:
@@ -709,10 +709,10 @@ def fix_chat_template(tokenizer):
else {"from": "human", "value": "Who are you?"}
]
no = tokenizer.apply_chat_template(
messages, add_generation_prompt=False, tokenize=False
messages, add_generation_prompt = False, tokenize = False
)
yes = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
messages, add_generation_prompt = True, tokenize = False
)
if no == yes:
@@ -750,11 +750,11 @@ def fix_chat_template(tokenizer):
def check_tokenizer(
model,
tokenizer,
model_name="unsloth/llama-2-7b-bnb-4bit",
model_max_length=4096,
padding_side="right",
token=None,
_reload=True,
model_name = "unsloth/llama-2-7b-bnb-4bit",
model_max_length = 4096,
padding_side = "right",
token = None,
_reload = True,
):
# Checks tokenizer for out of bounds ids.
# Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
@@ -861,23 +861,23 @@ def check_tokenizer(
# Try slow tokenizer which can fix things!
tokenizer = AutoTokenizer.from_pretrained(
model_name,
model_max_length=model_max_length,
padding_side=padding_side,
token=token,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
# Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
use_fast=False,
legacy=False,
from_slow=True,
cache_dir=cache_dir,
use_fast = False,
legacy = False,
from_slow = True,
cache_dir = cache_dir,
)
return check_tokenizer(
model=model,
tokenizer=tokenizer,
model_name=model_name,
model_max_length=model_max_length,
padding_side=padding_side,
token=token,
_reload=False,
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
_reload = False,
)
break
except:
@@ -993,7 +993,7 @@ def patch_sft_trainer_tokenizer():
replacer = re.findall(
f"def {function_name}" + r"\(.*?\).*?\:\n",
function,
flags=re.MULTILINE | re.DOTALL,
flags = re.MULTILINE | re.DOTALL,
)
if len(replacer) == 0:
continue