mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-28 03:19:57 +00:00
Revert "[FIX] Vllm guided decoding params (#3662)"
This reverts commit fb4f0fdf56.
This commit is contained in:
parent
fb4f0fdf56
commit
ba2897a318
51 changed files with 2649 additions and 2698 deletions
|
|
@ -110,7 +110,6 @@ from unsloth_zoo.device_type import (
|
|||
from .import_fixes import (
|
||||
fix_xformers_performance_issue,
|
||||
fix_vllm_aimv2_issue,
|
||||
fix_vllm_guided_decoding_params,
|
||||
ignore_logger_messages,
|
||||
patch_ipykernel_hf_xet,
|
||||
patch_trackio,
|
||||
|
|
@ -119,14 +118,13 @@ from .import_fixes import (
|
|||
|
||||
fix_xformers_performance_issue()
|
||||
fix_vllm_aimv2_issue()
|
||||
fix_vllm_guided_decoding_params()
|
||||
ignore_logger_messages()
|
||||
patch_ipykernel_hf_xet()
|
||||
patch_trackio()
|
||||
patch_datasets()
|
||||
|
||||
del fix_xformers_performance_issue
|
||||
del patch_vllm_imports
|
||||
del fix_vllm_aimv2_issue
|
||||
del ignore_logger_messages
|
||||
del patch_ipykernel_hf_xet
|
||||
del patch_trackio
|
||||
|
|
|
|||
|
|
@ -114,16 +114,6 @@ def fix_xformers_performance_issue():
|
|||
print(f"Unsloth: Failed patching Xformers with error = {str(e)}")
|
||||
|
||||
|
||||
def fix_vllm_aimv2_issue():
|
||||
if importlib.util.find_spec("vllm") is None:
|
||||
return
|
||||
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
|
||||
vllm_version = importlib_version("vllm")
|
||||
if Version(vllm_version) < Version("0.10.1"):
|
||||
vllm_version = importlib.util.find_spec("vllm").origin
|
||||
vllm_version = os.path.split(vllm_version)[0]
|
||||
ovis_config = Path(vllm_version) / "transformers_utils" / "configs" / "ovis.py"
|
||||
try:
|
||||
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
|
||||
def fix_vllm_aimv2_issue():
|
||||
if importlib.util.find_spec("vllm") is None:
|
||||
|
|
@ -165,22 +155,6 @@ def fix_vllm_aimv2_issue():
|
|||
print(f"Unsloth: Failed patching vLLM with error = {str(e)}")
|
||||
|
||||
|
||||
def fix_vllm_guided_decoding_params():
|
||||
if importlib.util.find_spec("vllm") is None:
|
||||
return
|
||||
# GuidedDecodingParams is renamed to StructuredOutputsParams in vLLM
|
||||
# https://github.com/vllm-project/vllm/pull/22772/files
|
||||
# trl still wants to use GuidedDecodingParams. This is a temporary patch till trl updates
|
||||
import vllm
|
||||
|
||||
try:
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
except ImportError:
|
||||
vllm.sampling_params.GuidedDecodingParams = (
|
||||
vllm.sampling_params.StructuredOutputsParams
|
||||
)
|
||||
|
||||
|
||||
def ignore_logger_messages():
|
||||
# Ignore Environment variable `HF_TOKEN` is set
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -59,9 +59,7 @@ def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = Non
|
|||
self.gate_proj, X, out = temp_gate
|
||||
) # pretty much the only change from transformers implementation.
|
||||
|
||||
routing_weights = torch_nn_functional_softmax(
|
||||
router_logits, dim = -1, dtype = torch.float32
|
||||
)
|
||||
routing_weights = torch_nn_functional_softmax(router_logits, dim = -1, dtype = torch.float32)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
|
||||
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
|
||||
# we cast back to the input dtype
|
||||
|
|
|
|||
|
|
@ -329,7 +329,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
try:
|
||||
trainer = eval(f"trl.trainer.{trainer_file}")
|
||||
except Exception as error:
|
||||
print(f"Unsloth: Could not import trl.trainer.{trainer_file}: {error}")
|
||||
return
|
||||
|
||||
# Get SFTTrainer and SFTConfig names
|
||||
|
|
@ -348,14 +347,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
and trainer_file.split("_")[0] in x.lower()
|
||||
]
|
||||
if len(name) != 1:
|
||||
print(
|
||||
f"Unsloth: Could not find Trainer class in trl.trainer.{trainer_file}. Found: {name}"
|
||||
)
|
||||
return
|
||||
if len(config) != 1:
|
||||
print(
|
||||
f"Unsloth: Could not find Config class in trl.trainer.{trainer_file}. Found: {config}"
|
||||
)
|
||||
return
|
||||
|
||||
# Get SFTTrainer, SFTConfig
|
||||
|
|
@ -364,24 +357,16 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
try:
|
||||
RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
|
||||
except:
|
||||
print(
|
||||
f"Unsloth: Could not load {RLTrainer_name} from trl.trainer.{trainer_file}"
|
||||
)
|
||||
return
|
||||
try:
|
||||
RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}")
|
||||
except:
|
||||
print(
|
||||
f"Unsloth: Could not load {RLConfig_name} from trl.trainer.{trainer_file}"
|
||||
)
|
||||
return
|
||||
|
||||
# Check name
|
||||
if RLTrainer.__name__.startswith("Unsloth"):
|
||||
print(f"Unsloth: {RLTrainer.__name__} is already patched.")
|
||||
return
|
||||
if RLConfig.__name__.startswith("Unsloth"):
|
||||
print(f"Unsloth: {RLConfig.__name__} is already patched.")
|
||||
return
|
||||
|
||||
# Get old source
|
||||
|
|
@ -1306,11 +1291,7 @@ def patch_trl_rl_trainers():
|
|||
import trl.trainer
|
||||
|
||||
all_trainers = dir(trl.trainer)
|
||||
all_trainers = [
|
||||
x
|
||||
for x in all_trainers
|
||||
if x.islower() and x.endswith("_trainer") and x != "base_trainer"
|
||||
]
|
||||
all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")]
|
||||
for trainer in all_trainers:
|
||||
_patch_trl_rl_trainers(trainer)
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue