from __future__ import annotations

from typing import Any, Iterable, TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Tensor

from .base import ModelBase, TextModel, gguf


@ModelBase.register("OpenELMForCausalLM")
class OpenELMModel(TextModel):
    """Converter for ``OpenELMForCausalLM`` checkpoints (GGUF arch ``OPENELM``).

    Query-head counts, KV-head counts and feed-forward widths are held as
    lists with one entry per block; ``set_gguf_parameters`` asserts that each
    list's length equals ``self.block_count`` before writing them out.
    """

    model_arch = gguf.MODEL_ARCH.OPENELM

    @staticmethod
    def _make_divisible(v: float | int, divisor: int) -> int:
        """Round ``v`` to a multiple of ``divisor``, never below ``divisor``.

        The value is rounded to the nearest multiple (halves round up). If
        that result is smaller than 90% of ``v``, one more ``divisor`` is
        added.
        """
        # ref: https://huggingface.co/apple/OpenELM-270M-Instruct/blob/eb111ff2e6724348e5b905984063d4064d4bc579/configuration_openelm.py#L34-L38
        nearest_multiple = int(v + divisor / 2) // divisor * divisor
        rounded = max(divisor, nearest_multiple)
        # Make sure that round down does not go down by more than 10%.
        if rounded < 0.9 * v:
            return rounded + divisor
        return rounded

    def __init__(self, *args, **kwargs):
        """Initialise the base model, then cache the per-layer OpenELM settings.

        Reads ``ffn_multipliers``, ``ffn_dim_divisor``, ``model_dim``,
        ``num_kv_heads`` and ``num_query_heads`` from ``self.hparams`` and
        derives one feed-forward width per multiplier.
        """
        super().__init__(*args, **kwargs)

        hp = self.hparams
        multipliers: list[float] = hp["ffn_multipliers"]
        divisor: int = hp["ffn_dim_divisor"]
        self._n_embd: int = hp["model_dim"]
        self._num_kv_heads: list[int] = hp["num_kv_heads"]
        self._num_query_heads: list[int] = hp["num_query_heads"]

        # Each width is (multiplier * model_dim) snapped to the configured divisor.
        model_dim = self._n_embd
        self._ffn_dims: list[int] = [
            OpenELMModel._make_divisible(scale * model_dim, divisor)
            for scale in multipliers
        ]

        # Both head-count settings must be lists of ints (first element is probed).
        for per_layer_heads in (self._num_kv_heads, self._num_query_heads):
            assert isinstance(per_layer_heads, list) and isinstance(per_layer_heads[0], int)

    def set_vocab(self):
        """Set up the vocabulary, preferring the SentencePiece loader.

        Uses the tokenizer from meta-llama/Llama-2-7b-hf.
        """
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            # The SentencePiece path raised FileNotFoundError: use the
            # built-in "llama-spm" vocab sized by the model's vocab_size.
            self._set_vocab_builtin("llama-spm", self.hparams["vocab_size"])

    def set_gguf_parameters(self):
        """Write the OpenELM hyperparameters to ``self.gguf_writer``.

        Feed-forward length, head count and KV head count are passed as the
        per-layer lists built in ``__init__``.
        """
        hp = self.hparams
        embd_len = self._n_embd
        head_size = hp["head_dim"]
        rotary_fraction = 1.0  # rope dimension count is int(1.0 * head_dim)

        # Every per-layer list must have exactly one entry per block.
        n_blocks = self.block_count
        for per_layer in (self._num_kv_heads, self._num_query_heads, self._ffn_dims):
            assert n_blocks == len(per_layer)

        writer = self.gguf_writer
        writer.add_block_count(n_blocks)
        writer.add_context_length(hp["max_context_length"])
        writer.add_embedding_length(embd_len)
        writer.add_feed_forward_length(self._ffn_dims)
        writer.add_head_count(self._num_query_heads)
        writer.add_head_count_kv(self._num_kv_heads)
        writer.add_rope_freq_base(hp["rope_freq_constant"])
        # https://huggingface.co/apple/OpenELM-270M-Instruct/blob/c401df2/modeling_openelm.py#L30
        writer.add_layer_norm_rms_eps(1e-6)
        writer.add_rope_dimension_count(int(rotary_fraction * head_size))
        writer.add_key_length(head_size)
        writer.add_value_length(head_size)
        writer.add_file_type(self.ftype)

    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
        """Look up a hyperparameter, answering ``n_layers`` queries locally.

        When ``"n_layers"`` is among ``keys`` the value of
        ``num_transformer_layers`` is returned; any other query is delegated
        to the parent implementation.
        """
        if "n_layers" not in keys:
            return super().find_hparam(keys, optional)
        return self.hparams["num_transformer_layers"]

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Yield ``(gguf_name, tensor)`` pairs for one source tensor.

        ``transformer.layers.{bid}.ffn.proj_1.weight`` is emitted as two
        tensors: its first ``self._ffn_dims[bid]`` rows as ``FFN_GATE`` and the
        remaining rows as ``FFN_UP``. Every other tensor is yielded unchanged
        under its mapped name.
        """
        if bid is None or name != f"transformer.layers.{bid}.ffn.proj_1.weight":
            yield (self.map_tensor_name(name), data_torch)
            return

        # split ff
        split_at = self._ffn_dims[bid]
        yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:split_at])
        yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[split_at:])