mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-28 03:19:57 +00:00
* fix: avoid PerceptionEncoder ImportError blocking trust_remote_code model loads * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update config class retrieval in _utils.py Refactor config class retrieval logic to use model mapping. * resolve_model_class: restore _extra_content fallback The previous fallback iterated mapping.items(), which transformers' _LazyAutoMapping defines as _model_mapping entries + _extra_content entries. The PR's per-key loop covers only _model_mapping, so subclasses of configs registered via AutoModel.register(cfg, model) silently resolve to None. Add a safe isinstance pass over _extra_content (no lazy loads, no crash risk) before giving up. * Add tests for resolve_model_class fallback * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
137 lines
3.3 KiB
Python
137 lines
3.3 KiB
Python
from unsloth.models._utils import resolve_model_class
|
|
|
|
|
|
class _AutoModelLike:
|
|
def __init__(self, mapping):
|
|
self._model_mapping = mapping
|
|
|
|
|
|
class _FakeLazyMapping:
|
|
def __init__(self, entries, extra_content = None, broken_keys = ()):
|
|
self._entries = dict(entries)
|
|
self._config_mapping = {k: f"Cfg_{k}" for k in self._entries}
|
|
self._model_mapping = {k: f"Mdl_{k}" for k in self._entries}
|
|
self._extra_content = dict(extra_content or {})
|
|
self._broken_keys = set(broken_keys)
|
|
|
|
def __getitem__(self, key):
|
|
if key in self._extra_content:
|
|
return self._extra_content[key]
|
|
for k, (cfg_cls, mdl_cls) in self._entries.items():
|
|
if k in self._broken_keys:
|
|
raise ValueError(f"broken entry {k}")
|
|
if cfg_cls is key:
|
|
return mdl_cls
|
|
raise KeyError(key)
|
|
|
|
def _load_attr_from_module(self, key, attr):
|
|
if key in self._broken_keys:
|
|
raise ValueError(f"broken entry {key}")
|
|
cfg_cls, mdl_cls = self._entries[key]
|
|
if attr == self._config_mapping[key]:
|
|
return cfg_cls
|
|
if attr == self._model_mapping[key]:
|
|
return mdl_cls
|
|
raise KeyError(attr)
|
|
|
|
|
|
class CfgA:
    """Config type registered directly (key ``"a"``) in the fixtures."""
|
|
|
|
|
|
class CfgB:
    """Config type registered directly; parent of the unregistered ``CfgBChild``."""
|
|
|
|
|
|
class CfgBChild(CfgB):
    """Subclass of ``CfgB`` that never appears as a key in any fixture."""
|
|
|
|
|
|
class ModelA:
    """Model class paired with ``CfgA`` in the fixtures."""
|
|
|
|
|
|
class ModelB:
    """Model class paired with ``CfgB`` in the fixtures."""
|
|
|
|
|
|
class RegBase:
    """Config type registered only through the fake mapping's ``_extra_content``."""
|
|
|
|
|
|
class RegChild(RegBase):
    """Subclass of ``RegBase`` with no registration of its own."""
|
|
|
|
|
|
class RegModel:
    """Model class registered for ``RegBase`` in ``_extra_content``."""
|
|
|
|
|
|
class UnknownCfg:
    """Config type unrelated to every fixture entry; expected to resolve to None."""
|
|
|
|
|
|
def test_fast_path_exact_match():
    """CfgA is registered directly, so it should resolve to ModelA."""
    auto_model = _AutoModelLike(
        _FakeLazyMapping({"a": (CfgA, ModelA), "b": (CfgB, ModelB)})
    )
    assert resolve_model_class(auto_model, CfgA()) is ModelA
|
|
|
|
|
|
def test_fallback_subclass_match_via_lazy_mapping():
    """An unregistered config subclass resolves via its registered parent.

    CfgBChild has no entry of its own; the expected answer is ModelB, the
    model paired with its base class CfgB.
    """
    lazy = _FakeLazyMapping({"a": (CfgA, ModelA), "b": (CfgB, ModelB)})
    config = CfgBChild()
    assert resolve_model_class(_AutoModelLike(lazy), config) is ModelB
|
|
|
|
|
|
def test_broken_lazy_entry_does_not_crash():
    """An entry that raises on access must not abort resolution.

    The "broken" key raises ValueError whenever the fake mapping touches it;
    CfgBChild is still expected to reach ModelB through the healthy "b" entry.
    """
    entries = {"broken": (CfgA, ModelA), "b": (CfgB, ModelB)}
    lazy = _FakeLazyMapping(entries, broken_keys = ("broken",))
    auto_model = _AutoModelLike(lazy)
    assert resolve_model_class(auto_model, CfgBChild()) is ModelB
|
|
|
|
|
|
def test_unknown_config_returns_none():
    """A config type unrelated to any entry yields None rather than raising."""
    auto_model = _AutoModelLike(_FakeLazyMapping({"a": (CfgA, ModelA)}))
    resolved = resolve_model_class(auto_model, UnknownCfg())
    assert resolved is None
|
|
|
|
|
|
def test_extra_content_subclass_fallback():
    """Subclasses of configs registered only in ``_extra_content`` still resolve.

    RegBase -> RegModel lives solely in ``_extra_content`` (where
    ``AutoModel.register`` registrations end up); RegChild has no entry
    anywhere and is expected to pick up RegBase's model.
    """
    registered = {RegBase: RegModel}
    lazy = _FakeLazyMapping({"a": (CfgA, ModelA)}, extra_content = registered)
    auto_model = _AutoModelLike(lazy)
    assert resolve_model_class(auto_model, RegChild()) is RegModel
|
|
|
|
|
|
def test_extra_content_exact_match_fast_path():
    """A config class registered in ``_extra_content`` resolves on exact match."""
    registered = {RegBase: RegModel}
    lazy = _FakeLazyMapping({"a": (CfgA, ModelA)}, extra_content = registered)
    auto_model = _AutoModelLike(lazy)
    assert resolve_model_class(auto_model, RegBase()) is RegModel
|
|
|
|
|
|
def test_broken_entry_with_extra_content_subclass():
    """A broken lazy entry must not hide the ``_extra_content`` subclass fallback.

    The only regular entry ("broken") raises ValueError when touched; RegChild
    is still expected to resolve to RegModel via its base RegBase, which is
    registered in ``_extra_content``.
    """
    lazy = _FakeLazyMapping(
        {"broken": (CfgA, ModelA)},
        extra_content = {RegBase: RegModel},
        broken_keys = ("broken",),
    )
    assert resolve_model_class(_AutoModelLike(lazy), RegChild()) is RegModel
|
|
|
|
|
|
def test_plain_dict_mapping_is_not_required():
    """The lazy-mapping type is not required: a plain dict works for exact hits."""
    plain = {CfgA: ModelA}
    assert resolve_model_class(_AutoModelLike(plain), CfgA()) is ModelA
|
|
|
|
|
|
def test_tuple_result_unwrapped():
    """A tuple mapping value is reduced to its first element.

    The fake maps CfgA to ``(ModelA, "extra")``; the expected result is
    ModelA alone.
    """
    lazy = _FakeLazyMapping({"a": (CfgA, (ModelA, "extra"))})
    auto_model = _AutoModelLike(lazy)
    assert resolve_model_class(auto_model, CfgA()) is ModelA
|