fix: guard resolve_model_class fallback against unresolvable transformers AutoModel entries (#5155)

* fix: avoid PerceptionEncoder ImportError blocking trust_remote_code model loads

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update config class retrieval in _utils.py

Refactor config class retrieval logic to use model mapping.

* resolve_model_class: restore _extra_content fallback

The previous fallback iterated mapping.items(), which transformers'
_LazyAutoMapping defines as _model_mapping entries + _extra_content
entries. The PR's per-key loop covers only _model_mapping, so
subclasses of configs registered via AutoModel.register(cfg, model)
silently resolve to None. Add a safe isinstance pass over
_extra_content (no lazy loads, no crash risk) before giving up.

* Add tests for resolve_model_class fallback

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
This commit is contained in:
Etherll 2026-04-24 15:59:17 +03:00 committed by GitHub
parent 06ed94da0d
commit 0326577b82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 161 additions and 6 deletions

View file

@ -0,0 +1,137 @@
from unsloth.models._utils import resolve_model_class
class _AutoModelLike:
def __init__(self, mapping):
self._model_mapping = mapping
class _FakeLazyMapping:
def __init__(self, entries, extra_content = None, broken_keys = ()):
self._entries = dict(entries)
self._config_mapping = {k: f"Cfg_{k}" for k in self._entries}
self._model_mapping = {k: f"Mdl_{k}" for k in self._entries}
self._extra_content = dict(extra_content or {})
self._broken_keys = set(broken_keys)
def __getitem__(self, key):
if key in self._extra_content:
return self._extra_content[key]
for k, (cfg_cls, mdl_cls) in self._entries.items():
if k in self._broken_keys:
raise ValueError(f"broken entry {k}")
if cfg_cls is key:
return mdl_cls
raise KeyError(key)
def _load_attr_from_module(self, key, attr):
if key in self._broken_keys:
raise ValueError(f"broken entry {key}")
cfg_cls, mdl_cls = self._entries[key]
if attr == self._config_mapping[key]:
return cfg_cls
if attr == self._model_mapping[key]:
return mdl_cls
raise KeyError(attr)
class CfgA:
pass
class CfgB:
pass
class CfgBChild(CfgB):
pass
class ModelA:
pass
class ModelB:
pass
class RegBase:
pass
class RegChild(RegBase):
pass
class RegModel:
pass
class UnknownCfg:
pass
def test_fast_path_exact_match():
m = _FakeLazyMapping({"a": (CfgA, ModelA), "b": (CfgB, ModelB)})
am = _AutoModelLike(m)
assert resolve_model_class(am, CfgA()) is ModelA
def test_fallback_subclass_match_via_lazy_mapping():
m = _FakeLazyMapping({"a": (CfgA, ModelA), "b": (CfgB, ModelB)})
am = _AutoModelLike(m)
assert resolve_model_class(am, CfgBChild()) is ModelB
def test_broken_lazy_entry_does_not_crash():
m = _FakeLazyMapping(
{"broken": (CfgA, ModelA), "b": (CfgB, ModelB)},
broken_keys = ("broken",),
)
am = _AutoModelLike(m)
assert resolve_model_class(am, CfgBChild()) is ModelB
def test_unknown_config_returns_none():
m = _FakeLazyMapping({"a": (CfgA, ModelA)})
am = _AutoModelLike(m)
assert resolve_model_class(am, UnknownCfg()) is None
def test_extra_content_subclass_fallback():
m = _FakeLazyMapping(
{"a": (CfgA, ModelA)},
extra_content = {RegBase: RegModel},
)
am = _AutoModelLike(m)
assert resolve_model_class(am, RegChild()) is RegModel
def test_extra_content_exact_match_fast_path():
m = _FakeLazyMapping(
{"a": (CfgA, ModelA)},
extra_content = {RegBase: RegModel},
)
am = _AutoModelLike(m)
assert resolve_model_class(am, RegBase()) is RegModel
def test_broken_entry_with_extra_content_subclass():
m = _FakeLazyMapping(
{"broken": (CfgA, ModelA)},
extra_content = {RegBase: RegModel},
broken_keys = ("broken",),
)
am = _AutoModelLike(m)
assert resolve_model_class(am, RegChild()) is RegModel
def test_plain_dict_mapping_is_not_required():
am = _AutoModelLike({CfgA: ModelA})
assert resolve_model_class(am, CfgA()) is ModelA
def test_tuple_result_unwrapped():
m = _FakeLazyMapping({"a": (CfgA, (ModelA, "extra"))})
am = _AutoModelLike(m)
assert resolve_model_class(am, CfgA()) is ModelA

View file

@ -415,13 +415,31 @@ def resolve_model_class(auto_model, config):
try:
result = mapping[config.__class__]
except Exception:
for config_class, model_class in mapping.items():
if isinstance(config, config_class):
result = model_class
break
else:
result = None
for key in list(getattr(mapping, "_model_mapping", {})):
try:
config_class = mapping._load_attr_from_module(
key, mapping._config_mapping[key]
)
if isinstance(config, config_class):
result = mapping._load_attr_from_module(
key, mapping._model_mapping[key]
)
break
except Exception:
continue
if result is None:
for extra_cls, extra_model in getattr(
mapping, "_extra_content", {}
).items():
try:
if isinstance(config, extra_cls):
result = extra_model
break
except Exception:
continue
if result is None:
return None
return result[0] if isinstance(result, (list, tuple)) else result