mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-11 15:54:37 +00:00
smallthinker right
This commit is contained in:
parent
f8719ee7b9
commit
712ad1fa3c
7 changed files with 48 additions and 108 deletions
@@ -97,6 +97,7 @@ class SmallthinkerConfig(PretrainedConfig):
         initializer_range=0.02,
         **kwargs,
     ):
+        moe_layer_layout = [1]*num_hidden_layers
         # Configuration sanitizers
         assert num_attention_heads % num_key_value_heads == 0, "[Smallthinker config sanitizer] num_attention_heads must be divisible by num_key_value_heads"
         assert len(rope_layout) == num_hidden_layers, "[Smallthinker config sanitizer] rope_layout must have the same length as num_hidden_layers"

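For orientation, here is a small, self-contained sketch of what the new `moe_layer_layout` default and the neighbouring sanitizers amount to. The concrete values (32 layers, 32/8 heads) and the final assert are illustrative assumptions; only the field names and the first two checks come from the hunk above.

```python
# Standalone sketch of the config defaults and sanitizers above (values are made up for illustration).
num_hidden_layers = 32
num_attention_heads = 32
num_key_value_heads = 8

rope_layout = [1] * num_hidden_layers       # per-layer flag; RoPE is applied only where this is truthy (see the forward hunk below)
moe_layer_layout = [1] * num_hidden_layers  # new default added by this commit: one entry per layer

assert num_attention_heads % num_key_value_heads == 0, \
    "[Smallthinker config sanitizer] num_attention_heads must be divisible by num_key_value_heads"
assert len(rope_layout) == num_hidden_layers, \
    "[Smallthinker config sanitizer] rope_layout must have the same length as num_hidden_layers"
# Presumably moe_layer_layout should satisfy the same length invariant (assumption, not part of the hunk):
assert len(moe_layer_layout) == num_hidden_layers
```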
@@ -83,7 +83,7 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
         with torch.cuda.stream(current_stream):
             residual = torch.zeros_like(hidden_states)
             for i, decode_layer in enumerate(self.model.layers):
-                router_input = hidden_states.clone()
+                router_input = hidden_states
                 hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
                 hidden_states = decode_layer.self_attn(hidden_states, self.cache,
                                                        freqs_cis if self.model.rope_layout[i] else None,

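The only functional change in this hunk is dropping `.clone()` on `router_input`. A tiny standalone sketch, unrelated to the model classes, of what that trade-off means in PyTorch:

```python
import torch

x = torch.ones(2, 3)

alias = x          # no copy: alias shares storage with x
copy = x.clone()   # independent copy with its own storage

x.add_(1)          # in-place update of x

print(alias[0, 0].item())  # 2.0 -- the alias reflects the in-place change
print(copy[0, 0].item())   # 1.0 -- the clone keeps the pre-update value
```

Skipping the clone saves one allocation and copy per layer per step; it is only safe if `hidden_states` is not mutated in place before the router consumes `router_input`, which this hunk alone cannot confirm.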
@@ -839,7 +839,7 @@ def load_balancing_loss_func(


 # @auto_docstring
-class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
+class SmallThinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     def __init__(self, config):
         super().__init__(config)

@@ -897,9 +897,9 @@ class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):
         Example:

         ```python
-        >>> from transformers import AutoTokenizer, SmallthinkerForCausalLM
+        >>> from transformers import AutoTokenizer, SmallThinkerForCausalLM

-        >>> model = SmallthinkerForCausalLM.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")
+        >>> model = SmallThinkerForCausalLM.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")
         >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Smallthinker-8x7B-v0.1")

         >>> prompt = "Hey, are you conscious? Can you talk to me?"

@@ -1212,7 +1212,7 @@ class SmallthinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):


 __all__ = [
-    "SmallthinkerForCausalLM",
+    "SmallThinkerForCausalLM",
     "SmallthinkerForQuestionAnswering",
     "SmallthinkerModel",
     "SmallthinkerPreTrainedModel",

@@ -471,20 +471,17 @@ class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):

     @torch.no_grad()
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        # print(inv_freq_expanded.device)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
-            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-            freqs_cis = freqs_cis * self.attention_scaling
-        return freqs_cis
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False): # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

 class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
     def __init__(

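The rewritten `forward` above drops the complex-valued `freqs_cis` (`torch.polar`) in favour of the cos/sin tables used by the standard Hugging Face rotary-embedding code path. Below is a self-contained sketch of that cos/sin construction under simple assumptions (a plain `inv_freq` buffer, `attention_scaling = 1.0`); only the tensor algebra mirrors the hunk, the surrounding setup is illustrative.

```python
import torch

head_dim = 64
base = 10000.0
attention_scaling = 1.0  # assumed; in the real module this comes from the rope scaling config

# Inverse frequencies for half the head dimension, as in standard RoPE.
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

x = torch.randn(1, 8, head_dim)              # stand-in for hidden states (only dtype/device matter here)
position_ids = torch.arange(8).unsqueeze(0)  # [batch, seq_len]

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()

with torch.autocast(device_type="cpu", enabled=False):  # force float32, as in the hunk
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # [batch, seq, head_dim//2]
    emb = torch.cat((freqs, freqs), dim=-1)                                              # [batch, seq, head_dim]
    cos = emb.cos() * attention_scaling
    sin = emb.sin() * attention_scaling

print(cos.shape, sin.shape)  # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])
```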
@@ -473,17 +473,31 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.

-    def apply_rotary_pos_emb(
-        self,
-        xq: torch.Tensor,
-        xk: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
-        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
-        return xq_out.type_as(xq), xk_out.type_as(xk)
+    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors.
+
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            position_ids (`torch.Tensor`, *optional*):
+                Deprecated and unused.
+            unsqueeze_dim (`int`, *optional*, defaults to 1):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed

     def forward(self,
                 hidden_states: torch.Tensor,

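The new `apply_rotary_pos_emb` follows the transformers-style cos/sin rotation built on a `rotate_half` helper (imported in the real module; reproduced below so the sketch is self-contained). The example shows the `[batch, seq_len, heads, head_dim]` layout for which `unsqueeze_dim=2` is used later in this diff; the shapes and values are made up.

```python
import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input, as in the transformers RoPE helpers.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

batch, seq_len, heads, kv_heads, head_dim = 1, 8, 4, 2, 64
q = torch.randn(batch, seq_len, heads, head_dim)
k = torch.randn(batch, seq_len, kv_heads, head_dim)
cos = torch.randn(batch, seq_len, head_dim)  # stand-ins; real values come from the rotary embedding above
sin = torch.randn(batch, seq_len, head_dim)

# unsqueeze_dim=2 broadcasts cos/sin over the heads axis of [batch, seq_len, heads, head_dim] tensors.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 8, 4, 64]) torch.Size([1, 8, 2, 64])
```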
@@ -514,7 +528,8 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
         print(sin.shape)
         """
         if freqs_cis:
-            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
+            cos, sin = freqs_cis
+            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)



@@ -533,80 +548,7 @@ class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):

         return attn_output

-class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
-    def __init__(self,
-                 key: str,
-                 gguf_loader : GGUFLoader,
-                 config: PretrainedConfig,
-                 orig_module: nn.Module,
-                 prefill_device: str = "cuda",
-                 generate_device: str = "cuda",
-                 chunck_size: int = 1000,
-                 **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-        self.orig_module.__init__(orig_module.config,
-                                  orig_module.layer_idx)
-        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
-
-    def apply_rotary_pos_emb(
-        self,
-        xq: torch.Tensor,
-        xk: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
-        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
-        return xq_out.type_as(xq), xk_out.type_as(xk)
-
-    def forward(self,
-                hidden_states: torch.Tensor,
-                kv_cache: KGQACache,
-                freqs_cis: torch.Tensor,
-                wrapper: flashInferAttn,
-                bsz_tensors: torch.Tensor,
-                position_ids: torch.Tensor = None,
-                ):
-
-        if self.use_qk_norm:
-            raise NotImplementedError("use_qk_norm is not implemented yet")
-
-        q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states, bsz_tensors)
-        key_states = self.k_proj(hidden_states, bsz_tensors)
-        value_states = self.v_proj(hidden_states, bsz_tensors)
-
-        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
-
-        # cos, sin = freqs_cis
-        """
-        print(query_states.shape)
-        print(key_states.shape)
-        print(cos.shape)
-        print(sin.shape)
-        """
-        if freqs_cis is not None:
-            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
-
-        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
-
-        k_cache = kv_cache.get_k_cache(self.layer_idx)
-        v_cache = kv_cache.get_v_cache(self.layer_idx)
-
-        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
-
-        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
-
-        return attn_output
-

 class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):

@@ -64,7 +64,7 @@ default_optimize_rules = {
     "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
     "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
-    "SmallthinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
+    "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
     "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
 }

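The dictionary key has to match `config.architectures[0]` from the model's `config.json` exactly, which is why the `SmallThinkerForCausalLM` rename must also land here. A hedged sketch of that lookup; `pick_rule_file`, its error handling, and the path value are hypothetical, only the dict layout mirrors the hunk above.

```python
# Illustrative sketch of the rule lookup; not the repo's actual helper.
ktransformer_rules_dir = "./optimize/optimize_rules/"  # placeholder path for this sketch

default_optimize_rules = {
    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
    "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
    "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
}

def pick_rule_file(architectures: list[str]) -> str:
    # architectures comes from the HF config, e.g. config.architectures == ["SmallThinkerForCausalLM"]
    arch = architectures[0]
    if arch not in default_optimize_rules:
        # A stale key (e.g. the old "SmallthinkerForCausalLM" casing) would fail here.
        raise ValueError(f"No optimize rule registered for architecture {arch!r}")
    return default_optimize_rules[arch]

print(pick_rule_file(["SmallThinkerForCausalLM"]))  # ./optimize/optimize_rules/Smallthinker-serve.yaml
```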
@@ -135,7 +135,7 @@ class Engine:
             config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         elif args.model_name == "Glm4MoeForCausalLM":
             config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        elif args.model_name == "SmallthinkerForCausalLM":
+        elif args.model_name == "SmallThinkerForCausalLM":
             config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
             config._attn_implementation = "eager"
         else:

@@ -162,7 +162,7 @@ class Engine:
                 self.model = KQwen2MoeForCausalLM(config, self.cache)
             else:
                 self.model = KQwen3MoeForCausalLM(config, self.cache)
-        elif config.architectures[0] == "SmallthinkerForCausalLM":
+        elif config.architectures[0] == "SmallThinkerForCausalLM":
             self.cache = KGQACache(config, self.args.page_size)
             self.model = KSmallthinkerForCausalLM(config, self.cache)
         elif config.architectures[0] == "Glm4MoeForCausalLM":

@@ -219,7 +219,7 @@ class Engine:
         self.block_num = inference_context.k_cache[0].size(1)
         self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
         #@TODO add config
-        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallthinkerForCausalLM":
+        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM":
             self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
         else:
             self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

@@ -215,7 +215,7 @@ if __name__ == '__main__':
         settings = create_sched_settings_qwen3moe(main_args)
     elif main_args.architectures == "Glm4MoeForCausalLM":
         settings = create_sched_settings_glm4moe(main_args)
-    elif main_args.architectures == "SmallthinkerForCausalLM":
+    elif main_args.architectures == "SmallThinkerForCausalLM":
         settings = create_sched_settings_smallthinker(main_args)
     else:
         settings = create_sched_settings(main_args)