Revert "[FIX] Vllm guided decoding params (#3662)"

This reverts commit fb4f0fdf56.
This commit is contained in:
Daniel Han 2025-12-01 05:43:45 -08:00
parent fb4f0fdf56
commit ba2897a318
51 changed files with 2649 additions and 2698 deletions

View file

@ -110,7 +110,6 @@ from unsloth_zoo.device_type import (
from .import_fixes import (
fix_xformers_performance_issue,
fix_vllm_aimv2_issue,
fix_vllm_guided_decoding_params,
ignore_logger_messages,
patch_ipykernel_hf_xet,
patch_trackio,
@ -119,14 +118,13 @@ from .import_fixes import (
fix_xformers_performance_issue()
fix_vllm_aimv2_issue()
fix_vllm_guided_decoding_params()
ignore_logger_messages()
patch_ipykernel_hf_xet()
patch_trackio()
patch_datasets()
del fix_xformers_performance_issue
del patch_vllm_imports
del fix_vllm_aimv2_issue
del ignore_logger_messages
del patch_ipykernel_hf_xet
del patch_trackio

View file

@ -114,16 +114,6 @@ def fix_xformers_performance_issue():
print(f"Unsloth: Failed patching Xformers with error = {str(e)}")
def fix_vllm_aimv2_issue():
if importlib.util.find_spec("vllm") is None:
return
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
vllm_version = importlib_version("vllm")
if Version(vllm_version) < Version("0.10.1"):
vllm_version = importlib.util.find_spec("vllm").origin
vllm_version = os.path.split(vllm_version)[0]
ovis_config = Path(vllm_version) / "transformers_utils" / "configs" / "ovis.py"
try:
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
def fix_vllm_aimv2_issue():
if importlib.util.find_spec("vllm") is None:
@ -165,22 +155,6 @@ def fix_vllm_aimv2_issue():
print(f"Unsloth: Failed patching vLLM with error = {str(e)}")
def fix_vllm_guided_decoding_params():
    """Re-expose ``GuidedDecodingParams`` on ``vllm.sampling_params``.

    ``GuidedDecodingParams`` was renamed to ``StructuredOutputsParams`` in
    vLLM (https://github.com/vllm-project/vllm/pull/22772), but trl still
    imports the old name.  This is a temporary compatibility shim until trl
    updates; it is a no-op when vLLM is not installed or when the old name
    still exists.
    """
    if importlib.util.find_spec("vllm") is None:
        return
    import vllm
    try:
        # Old vLLM: name exists, nothing to patch.
        from vllm.sampling_params import GuidedDecodingParams  # noqa: F401
    except ImportError:
        # New vLLM: alias the renamed class under the old name. Guard the
        # attribute access so an unexpectedly old/odd vLLM build (with
        # neither name) does not raise AttributeError out of a patch helper.
        new_cls = getattr(vllm.sampling_params, "StructuredOutputsParams", None)
        if new_cls is not None:
            vllm.sampling_params.GuidedDecodingParams = new_cls
def ignore_logger_messages():
# Ignore Environment variable `HF_TOKEN` is set
try:

View file

@ -59,9 +59,7 @@ def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = Non
self.gate_proj, X, out = temp_gate
) # pretty much the only change from transformers implementation.
routing_weights = torch_nn_functional_softmax(
router_logits, dim = -1, dtype = torch.float32
)
routing_weights = torch_nn_functional_softmax(router_logits, dim = -1, dtype = torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
# we cast back to the input dtype

View file

@ -329,7 +329,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
try:
trainer = eval(f"trl.trainer.{trainer_file}")
except Exception as error:
print(f"Unsloth: Could not import trl.trainer.{trainer_file}: {error}")
return
# Get SFTTrainer and SFTConfig names
@ -348,14 +347,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
and trainer_file.split("_")[0] in x.lower()
]
if len(name) != 1:
print(
f"Unsloth: Could not find Trainer class in trl.trainer.{trainer_file}. Found: {name}"
)
return
if len(config) != 1:
print(
f"Unsloth: Could not find Config class in trl.trainer.{trainer_file}. Found: {config}"
)
return
# Get SFTTrainer, SFTConfig
@ -364,24 +357,16 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
try:
RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
except:
print(
f"Unsloth: Could not load {RLTrainer_name} from trl.trainer.{trainer_file}"
)
return
try:
RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}")
except:
print(
f"Unsloth: Could not load {RLConfig_name} from trl.trainer.{trainer_file}"
)
return
# Check name
if RLTrainer.__name__.startswith("Unsloth"):
print(f"Unsloth: {RLTrainer.__name__} is already patched.")
return
if RLConfig.__name__.startswith("Unsloth"):
print(f"Unsloth: {RLConfig.__name__} is already patched.")
return
# Get old source
@ -1306,11 +1291,7 @@ def patch_trl_rl_trainers():
import trl.trainer
all_trainers = dir(trl.trainer)
all_trainers = [
x
for x in all_trainers
if x.islower() and x.endswith("_trainer") and x != "base_trainer"
]
all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")]
for trainer in all_trainers:
_patch_trl_rl_trainers(trainer)
return