mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-26 10:31:03 +00:00
fix: guard resolve_model_class fallback against unresolvable transformers AutoModel entries (#5155)
* fix: avoid PerceptionEncoder ImportError blocking trust_remote_code model loads * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update config class retrieval in _utils.py Refactor config class retrieval logic to use model mapping. * resolve_model_class: restore _extra_content fallback The previous fallback iterated mapping.items(), which transformers' _LazyAutoMapping defines as _model_mapping entries + _extra_content entries. The PR's per-key loop covers only _model_mapping, so subclasses of configs registered via AutoModel.register(cfg, model) silently resolve to None. Add a safe isinstance pass over _extra_content (no lazy loads, no crash risk) before giving up. * Add tests for resolve_model_class fallback * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
This commit is contained in:
parent
06ed94da0d
commit
0326577b82
2 changed files with 161 additions and 6 deletions
137
tests/test_resolve_model_class.py
Normal file
137
tests/test_resolve_model_class.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
from unsloth.models._utils import resolve_model_class
|
||||
|
||||
|
||||
class _AutoModelLike:
|
||||
def __init__(self, mapping):
|
||||
self._model_mapping = mapping
|
||||
|
||||
|
||||
class _FakeLazyMapping:
|
||||
def __init__(self, entries, extra_content = None, broken_keys = ()):
|
||||
self._entries = dict(entries)
|
||||
self._config_mapping = {k: f"Cfg_{k}" for k in self._entries}
|
||||
self._model_mapping = {k: f"Mdl_{k}" for k in self._entries}
|
||||
self._extra_content = dict(extra_content or {})
|
||||
self._broken_keys = set(broken_keys)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self._extra_content:
|
||||
return self._extra_content[key]
|
||||
for k, (cfg_cls, mdl_cls) in self._entries.items():
|
||||
if k in self._broken_keys:
|
||||
raise ValueError(f"broken entry {k}")
|
||||
if cfg_cls is key:
|
||||
return mdl_cls
|
||||
raise KeyError(key)
|
||||
|
||||
def _load_attr_from_module(self, key, attr):
|
||||
if key in self._broken_keys:
|
||||
raise ValueError(f"broken entry {key}")
|
||||
cfg_cls, mdl_cls = self._entries[key]
|
||||
if attr == self._config_mapping[key]:
|
||||
return cfg_cls
|
||||
if attr == self._model_mapping[key]:
|
||||
return mdl_cls
|
||||
raise KeyError(attr)
|
||||
|
||||
|
||||
class CfgA:
|
||||
pass
|
||||
|
||||
|
||||
class CfgB:
|
||||
pass
|
||||
|
||||
|
||||
class CfgBChild(CfgB):
|
||||
pass
|
||||
|
||||
|
||||
class ModelA:
|
||||
pass
|
||||
|
||||
|
||||
class ModelB:
|
||||
pass
|
||||
|
||||
|
||||
class RegBase:
|
||||
pass
|
||||
|
||||
|
||||
class RegChild(RegBase):
|
||||
pass
|
||||
|
||||
|
||||
class RegModel:
|
||||
pass
|
||||
|
||||
|
||||
class UnknownCfg:
|
||||
pass
|
||||
|
||||
|
||||
def test_fast_path_exact_match():
|
||||
m = _FakeLazyMapping({"a": (CfgA, ModelA), "b": (CfgB, ModelB)})
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, CfgA()) is ModelA
|
||||
|
||||
|
||||
def test_fallback_subclass_match_via_lazy_mapping():
|
||||
m = _FakeLazyMapping({"a": (CfgA, ModelA), "b": (CfgB, ModelB)})
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, CfgBChild()) is ModelB
|
||||
|
||||
|
||||
def test_broken_lazy_entry_does_not_crash():
|
||||
m = _FakeLazyMapping(
|
||||
{"broken": (CfgA, ModelA), "b": (CfgB, ModelB)},
|
||||
broken_keys = ("broken",),
|
||||
)
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, CfgBChild()) is ModelB
|
||||
|
||||
|
||||
def test_unknown_config_returns_none():
|
||||
m = _FakeLazyMapping({"a": (CfgA, ModelA)})
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, UnknownCfg()) is None
|
||||
|
||||
|
||||
def test_extra_content_subclass_fallback():
|
||||
m = _FakeLazyMapping(
|
||||
{"a": (CfgA, ModelA)},
|
||||
extra_content = {RegBase: RegModel},
|
||||
)
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, RegChild()) is RegModel
|
||||
|
||||
|
||||
def test_extra_content_exact_match_fast_path():
|
||||
m = _FakeLazyMapping(
|
||||
{"a": (CfgA, ModelA)},
|
||||
extra_content = {RegBase: RegModel},
|
||||
)
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, RegBase()) is RegModel
|
||||
|
||||
|
||||
def test_broken_entry_with_extra_content_subclass():
|
||||
m = _FakeLazyMapping(
|
||||
{"broken": (CfgA, ModelA)},
|
||||
extra_content = {RegBase: RegModel},
|
||||
broken_keys = ("broken",),
|
||||
)
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, RegChild()) is RegModel
|
||||
|
||||
|
||||
def test_plain_dict_mapping_is_not_required():
|
||||
am = _AutoModelLike({CfgA: ModelA})
|
||||
assert resolve_model_class(am, CfgA()) is ModelA
|
||||
|
||||
|
||||
def test_tuple_result_unwrapped():
|
||||
m = _FakeLazyMapping({"a": (CfgA, (ModelA, "extra"))})
|
||||
am = _AutoModelLike(m)
|
||||
assert resolve_model_class(am, CfgA()) is ModelA
|
||||
|
|
@ -415,13 +415,31 @@ def resolve_model_class(auto_model, config):
|
|||
try:
|
||||
result = mapping[config.__class__]
|
||||
except Exception:
|
||||
for config_class, model_class in mapping.items():
|
||||
result = None
|
||||
for key in list(getattr(mapping, "_model_mapping", {})):
|
||||
try:
|
||||
config_class = mapping._load_attr_from_module(
|
||||
key, mapping._config_mapping[key]
|
||||
)
|
||||
if isinstance(config, config_class):
|
||||
result = model_class
|
||||
result = mapping._load_attr_from_module(
|
||||
key, mapping._model_mapping[key]
|
||||
)
|
||||
break
|
||||
else:
|
||||
except Exception:
|
||||
continue
|
||||
if result is None:
|
||||
for extra_cls, extra_model in getattr(
|
||||
mapping, "_extra_content", {}
|
||||
).items():
|
||||
try:
|
||||
if isinstance(config, extra_cls):
|
||||
result = extra_model
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result[0] if isinstance(result, (list, tuple)) else result
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue