Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-11 07:44:35 +00:00)
commit 712ad1fa3c (parent f8719ee7b9)

    smallthinker right

7 changed files with 48 additions and 108 deletions
@@ -97,6 +97,7 @@ class SmallthinkerConfig(PretrainedConfig):
         initializer_range=0.02,
         **kwargs,
     ):
+        moe_layer_layout = [1]*num_hidden_layers
         # Configuration sanitizers
         assert num_attention_heads % num_key_value_heads == 0, "[Smallthinker config sanitizer] num_attention_heads must be divisible by num_key_value_heads"
         assert len(rope_layout) == num_hidden_layers, "[Smallthinker config sanitizer] rope_layout must have the same length as num_hidden_layers"
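For context, a minimal standalone sketch of what the added default and the existing sanitizers enforce; the concrete numbers are made up and this is not the actual SmallthinkerConfig constructor:

```python
# Illustrative values only (not taken from the real config).
num_hidden_layers = 4
num_attention_heads, num_key_value_heads = 32, 8
rope_layout = [1, 1, 1, 1]

# Added default: treat every layer as an MoE layer unless told otherwise.
moe_layer_layout = [1] * num_hidden_layers

# The sanitizers reject configs where GQA grouping or per-layer layouts don't line up.
assert num_attention_heads % num_key_value_heads == 0
assert len(rope_layout) == num_hidden_layers
```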
@@ -83,7 +83,7 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
         with torch.cuda.stream(current_stream):
             residual = torch.zeros_like(hidden_states)
             for i, decode_layer in enumerate(self.model.layers):
-                router_input = hidden_states.clone()
+                router_input = hidden_states
                 hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
                 hidden_states = decode_layer.self_attn(hidden_states, self.cache,
                                                        freqs_cis if self.model.rope_layout[i] else None,
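The only change in this hunk drops the `.clone()`, so `router_input` now aliases `hidden_states` instead of copying it; that saves a copy per layer and is safe as long as nothing downstream mutates `hidden_states` in place. A standalone sketch of the aliasing semantics involved (general PyTorch behavior, not the model code):

```python
import torch

hidden_states = torch.ones(2, 4)

router_input = hidden_states.clone()  # independent copy: costs memory, ignores later in-place edits
router_alias = hidden_states          # same storage: free, but sees later in-place edits

hidden_states.add_(1.0)
assert router_input.max().item() == 1.0  # the clone kept the old values
assert router_alias.max().item() == 2.0  # the alias reflects the in-place update
```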
@@ -839,7 +839,7 @@ def load_balancing_loss_func(


 # @auto_docstring
-class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
+class SmallThinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     def __init__(self, config):
         super().__init__(config)
@@ -897,9 +897,9 @@ class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
         Example:

         ```python
-        >>> from transformers import AutoTokenizer, SmallthinkerForCausalLM
+        >>> from transformers import AutoTokenizer, SmallThinkerForCausalLM

-        >>> model = SmallthinkerForCausalLM.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")
+        >>> model = SmallThinkerForCausalLM.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")
         >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")

         >>> prompt = "Hey, are you conscious? Can you talk to me?"
@@ -1212,7 +1212,7 @@ class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):


 __all__ = [
-    "SmallthinkerForCausalLM",
+    "SmallThinkerForCausalLM",
     "SmallthinkerForQuestionAnswering",
     "SmallthinkerModel",
     "SmallthinkerPreTrainedModel",
@@ -471,20 +471,17 @@ class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):

     @torch.no_grad()
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        # print(inv_freq_expanded.device)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
-            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-            freqs_cis = freqs_cis * self.attention_scaling
-        return freqs_cis
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False): # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

 class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
     def __init__(
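The rewritten forward() returns a (cos, sin) pair instead of a complex `freqs_cis` tensor. A minimal standalone sketch of how that pair is derived, assuming a plain RoPE base of 10000 and attention_scaling of 1.0 (it mirrors the shape flow above but is not the injected module itself):

```python
import torch

def rope_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor,
                 attention_scaling: float = 1.0):
    # inv_freq: [head_dim // 2], position_ids: [batch, seq_len]
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # [batch, seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                              # [batch, seq_len, head_dim]
    return emb.cos() * attention_scaling, emb.sin() * attention_scaling

head_dim = 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
cos, sin = rope_cos_sin(inv_freq, torch.arange(8)[None, :])
assert cos.shape == (1, 8, head_dim) and sin.shape == (1, 8, head_dim)
```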
@@ -473,17 +473,31 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.

-    def apply_rotary_pos_emb(
-        self,
-        xq: torch.Tensor,
-        xk: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
-        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
-        return xq_out.type_as(xq), xk_out.type_as(xk)
+    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors.
+
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            position_ids (`torch.Tensor`, *optional*):
+                Deprecated and unused.
+            unsqueeze_dim (`int`, *optional*, defaults to 1):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed

     def forward(self,
                 hidden_states: torch.Tensor,
@@ -514,7 +528,8 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
         print(sin.shape)
         """
         if freqs_cis:
-            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
+            cos, sin = freqs_cis
+            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)


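With the (cos, sin) pair, the rotation is applied via `rotate_half` rather than complex multiplication, and `unsqueeze_dim=2` is what makes the broadcast line up when the tensors are laid out as [batch, seq_len, heads, head_dim]. A self-contained sketch of that call pattern (shapes are illustrative; `rotate_half` is reproduced from the standard Hugging Face definition, not from this repo):

```python
import torch

def rotate_half(x):
    # Rotate half the hidden dims: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

batch, seq_len, n_heads, n_kv_heads, head_dim = 1, 5, 8, 2, 64
q = torch.randn(batch, seq_len, n_heads, head_dim)     # [batch, seq_len, heads, head_dim]
k = torch.randn(batch, seq_len, n_kv_heads, head_dim)
cos = torch.randn(batch, seq_len, head_dim)            # as returned by the rotary embedding forward
sin = torch.randn(batch, seq_len, head_dim)

# unsqueeze_dim=2 inserts the broadcast axis at the heads dimension.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```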
@@ -533,80 +548,7 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):

         return attn_output

-class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
-    def __init__(self,
-                 key: str,
-                 gguf_loader : GGUFLoader,
-                 config: PretrainedConfig,
-                 orig_module: nn.Module,
-                 prefill_device: str = "cuda",
-                 generate_device: str = "cuda",
-                 chunck_size: int = 1000,
-                 **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-        self.orig_module.__init__(orig_module.config,
-                                  orig_module.layer_idx)
-        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
-
-    def apply_rotary_pos_emb(
-        self,
-        xq: torch.Tensor,
-        xk: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
-        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
-        return xq_out.type_as(xq), xk_out.type_as(xk)
-
-    def forward(self,
-                hidden_states: torch.Tensor,
-                kv_cache: KGQACache,
-                freqs_cis: torch.Tensor,
-                wrapper: flashInferAttn,
-                bsz_tensors: torch.Tensor,
-                position_ids: torch.Tensor = None,
-                ):
-
-        if self.use_qk_norm:
-            raise NotImplementedError("use_qk_norm is not implemented yet")
-
-        q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states, bsz_tensors)
-        key_states = self.k_proj(hidden_states, bsz_tensors)
-        value_states = self.v_proj(hidden_states, bsz_tensors)
-
-        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
-
-        # cos, sin = freqs_cis
-        """
-        print(query_states.shape)
-        print(key_states.shape)
-        print(cos.shape)
-        print(sin.shape)
-        """
-        if freqs_cis is not None:
-            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
-
-
-
-        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
-
-        k_cache = kv_cache.get_k_cache(self.layer_idx)
-        v_cache = kv_cache.get_v_cache(self.layer_idx)
-
-
-        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
-
-
-        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
-
-        return attn_output
-

 class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
@@ -64,7 +64,7 @@ default_optimize_rules = {
     "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
     "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
-    "SmallthinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
+    "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
     "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
 }

@@ -135,7 +135,7 @@ class Engine:
             config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         elif args.model_name == "Glm4MoeForCausalLM":
             config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        elif args.model_name == "SmallthinkerForCausalLM":
+        elif args.model_name == "SmallThinkerForCausalLM":
             config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
             config._attn_implementation = "eager"
         else:
@@ -162,7 +162,7 @@ class Engine:
                 self.model = KQwen2MoeForCausalLM(config, self.cache)
             else:
                 self.model = KQwen3MoeForCausalLM(config, self.cache)
-        elif config.architectures[0] == "SmallthinkerForCausalLM":
+        elif config.architectures[0] == "SmallThinkerForCausalLM":
             self.cache = KGQACache(config, self.args.page_size)
             self.model = KSmallthinkerForCausalLM(config, self.cache)
         elif config.architectures[0] == "Glm4MoeForCausalLM":
@@ -219,7 +219,7 @@ class Engine:
         self.block_num = inference_context.k_cache[0].size(1)
         self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
         #@TODO add config
-        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallthinkerForCausalLM":
+        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM":
             self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
         else:
             self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
@@ -215,7 +215,7 @@ if __name__ == '__main__':
         settings = create_sched_settings_qwen3moe(main_args)
     elif main_args.architectures == "Glm4MoeForCausalLM":
         settings = create_sched_settings_glm4moe(main_args)
-    elif main_args.architectures == "SmallthinkerForCausalLM":
+    elif main_args.architectures == "SmallThinkerForCausalLM":
         settings = create_sched_settings_smallthinker(main_args)
     else:
         settings = create_sched_settings(main_args)
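The serving-side hunks above all update string lookups keyed on the architecture name reported by the checkpoint config (config.architectures[0] or the equivalent CLI argument), so the rename to "SmallThinkerForCausalLM" has to be applied consistently in every table and comparison. A hedged sketch of that pattern; the dict contents and path below are illustrative, not the repo's full rules table:

```python
# Illustrative only: how an architecture string typically selects an optimize-rules file.
ktransformer_rules_dir = "./optimize/optimize_rules/"
default_optimize_rules = {
    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
    "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
}

architecture = "SmallThinkerForCausalLM"          # would come from config.architectures[0]
rule_file = default_optimize_rules[architecture]  # a stale "SmallthinkerForCausalLM" key would raise KeyError
print(rule_file)
```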