smallthinker right

qiyuxinlin 2025-07-25 12:46:14 +00:00
parent f8719ee7b9
commit 712ad1fa3c
7 changed files with 48 additions and 108 deletions

View file

@@ -97,6 +97,7 @@ class SmallthinkerConfig(PretrainedConfig):
initializer_range=0.02,
**kwargs,
):
+ moe_layer_layout = [1]*num_hidden_layers
# Configuration sanitizers
assert num_attention_heads % num_key_value_heads == 0, "[Smallthinker config sanitizer] num_attention_heads must be divisible by num_key_value_heads"
assert len(rope_layout) == num_hidden_layers, "[Smallthinker config sanitizer] rope_layout must have the same length as num_hidden_layers"
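
For a quick illustration, a minimal sketch with hypothetical values (and assuming a 1 in `moe_layer_layout` marks a layer as MoE): the new default flags every hidden layer.

```python
# Hypothetical values, for illustration only.
num_hidden_layers = 4
moe_layer_layout = [1] * num_hidden_layers  # new default added in this hunk
assert moe_layer_layout == [1, 1, 1, 1]     # every layer flagged as MoE (assumed semantics)
```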

View file

@@ -83,7 +83,7 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
with torch.cuda.stream(current_stream):
residual = torch.zeros_like(hidden_states)
for i, decode_layer in enumerate(self.model.layers):
- router_input = hidden_states.clone()
+ router_input = hidden_states
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
freqs_cis if self.model.rope_layout[i] else None,

View file

@@ -839,7 +839,7 @@ def load_balancing_loss_func(
# @auto_docstring
- class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
+ class SmallThinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -897,9 +897,9 @@ class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
Example:
```python
- >>> from transformers import AutoTokenizer, SmallthinkerForCausalLM
+ >>> from transformers import AutoTokenizer, SmallThinkerForCausalLM
- >>> model = SmallthinkerForCausalLM.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")
+ >>> model = SmallThinkerForCausalLM.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
@@ -1212,7 +1212,7 @@ class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
__all__ = [
"SmallthinkerForCausalLM",
"SmallThinkerForCausalLM",
"SmallthinkerForQuestionAnswering",
"SmallthinkerModel",
"SmallthinkerPreTrainedModel",

View file

@@ -471,20 +471,17 @@ class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbeddi
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- # print(inv_freq_expanded.device)
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
- freqs_cis = freqs_cis * self.attention_scaling
- return freqs_cis
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
def __init__(

View file

@@ -473,17 +473,31 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
- def apply_rotary_pos_emb(
- self,
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
- xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
+ def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
def forward(self,
hidden_states: torch.Tensor,
@@ -514,7 +528,8 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
print(sin.shape)
"""
if freqs_cis:
- query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
+ cos, sin = freqs_cis
+ query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)
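
For reference, a self-contained sketch of the cos/sin path used above. `rotate_half` is not shown in this diff and is assumed to follow the standard Hugging Face convention; the base 10000 and the toy sizes are assumptions, and queries/keys are laid out as `[batch, seq_len, heads, head_dim]`, which is why the call above passes `unsqueeze_dim=2`.

```python
import torch

def rotate_half(x):
    # Assumed standard Hugging Face helper (not part of this diff):
    # split the last dimension in half and rotate the two halves.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Toy sizes: 1 sequence of 5 tokens, 4 query heads, 2 KV heads, head_dim 8.
bsz, q_len, n_heads, n_kv_heads, head_dim = 1, 5, 4, 2, 8
q = torch.randn(bsz, q_len, n_heads, head_dim)
k = torch.randn(bsz, q_len, n_kv_heads, head_dim)

# Mirror the shape of what KSmallthinkerRotaryEmbedding.forward returns above:
# cos/sin of shape [bsz, q_len, head_dim], frequencies repeated across both halves
# (base 10000 assumed; attention scaling omitted for simplicity).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(q_len).float()[None, :]
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# unsqueeze_dim=2 broadcasts cos/sin over the heads axis of q and k.
cos_b, sin_b = cos.unsqueeze(2), sin.unsqueeze(2)
q_embed = (q * cos_b) + (rotate_half(q) * sin_b)
k_embed = (k * cos_b) + (rotate_half(k) * sin_b)
assert q_embed.shape == q.shape and k_embed.shape == k.shape
```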
@@ -533,80 +548,7 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
return attn_output
- class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
- def __init__(self,
- key: str,
- gguf_loader : GGUFLoader,
- config: PretrainedConfig,
- orig_module: nn.Module,
- prefill_device: str = "cuda",
- generate_device: str = "cuda",
- chunck_size: int = 1000,
- **kwargs):
- BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
- self.orig_module.__init__(orig_module.config,
- orig_module.layer_idx)
- self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
- def apply_rotary_pos_emb(
- self,
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
- xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
- def forward(self,
- hidden_states: torch.Tensor,
- kv_cache: KGQACache,
- freqs_cis: torch.Tensor,
- wrapper: flashInferAttn,
- bsz_tensors: torch.Tensor,
- position_ids: torch.Tensor = None,
- ):
- if self.use_qk_norm:
- raise NotImplementedError("use_qk_norm is not implemented yet")
- q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states, bsz_tensors)
- key_states = self.k_proj(hidden_states, bsz_tensors)
- value_states = self.v_proj(hidden_states, bsz_tensors)
- query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
- key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
- value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
- # cos, sin = freqs_cis
- """
- print(query_states.shape)
- print(key_states.shape)
- print(cos.shape)
- print(sin.shape)
- """
- if freqs_cis is not None:
- query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
- query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
- key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
- value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
- k_cache = kv_cache.get_k_cache(self.layer_idx)
- v_cache = kv_cache.get_v_cache(self.layer_idx)
- attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
- attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
- return attn_output
class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):

View file

@@ -64,7 +64,7 @@ default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
"Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
"SmallthinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
"SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
"Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
}
@@ -135,7 +135,7 @@ class Engine:
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
elif args.model_name == "Glm4MoeForCausalLM":
config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
- elif args.model_name == "SmallthinkerForCausalLM":
+ elif args.model_name == "SmallThinkerForCausalLM":
config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
config._attn_implementation = "eager"
else:
@@ -162,7 +162,7 @@ class Engine:
self.model = KQwen2MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
- elif config.architectures[0] == "SmallthinkerForCausalLM":
+ elif config.architectures[0] == "SmallThinkerForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KSmallthinkerForCausalLM(config, self.cache)
elif config.architectures[0] == "Glm4MoeForCausalLM":
@@ -219,7 +219,7 @@ class Engine:
self.block_num = inference_context.k_cache[0].size(1)
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
- if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallthinkerForCausalLM":
+ if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM":
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
else:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
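
For readability, the architecture gate above amounts to a membership test over four names; a small self-contained sketch (the set and helper names are hypothetical, the strings are copied from the condition):

```python
# Hypothetical names; only the four architecture strings come from the code above.
CUDA_GRAPH_SIZED_ARCHS = {
    "Qwen2MoeForCausalLM",
    "Qwen3MoeForCausalLM",
    "Glm4MoeForCausalLM",
    "SmallThinkerForCausalLM",
}

def passes_cuda_graph_sizes(architecture: str) -> bool:
    # These architectures call init_wrapper with max(self.model_runner.cuda_graphs);
    # everything else uses the shorter init_wrapper signature.
    return architecture in CUDA_GRAPH_SIZED_ARCHS

assert passes_cuda_graph_sizes("SmallThinkerForCausalLM")
assert not passes_cuda_graph_sizes("DeepseekV3ForCausalLM")
```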

View file

@@ -215,7 +215,7 @@ if __name__ == '__main__':
settings = create_sched_settings_qwen3moe(main_args)
elif main_args.architectures == "Glm4MoeForCausalLM":
settings = create_sched_settings_glm4moe(main_args)
- elif main_args.architectures == "SmallthinkerForCausalLM":
+ elif main_args.architectures == "SmallThinkerForCausalLM":
settings = create_sched_settings_smallthinker(main_args)
else:
settings = create_sched_settings(main_args)