add XPU support for qwen3moe local chat

rnwang04 2025-05-21 18:33:41 +08:00
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions


@@ -306,6 +306,12 @@ class KQwen2MoeModel(BaseInjectedModule):
        hidden_states = inputs_embeds

+       # create position embeddings to be shared across the decoder layers
+       if torch.xpu.is_available() and inputs_embeds.device.type == "xpu":
+           position_embeddings = self.rotary_emb(hidden_states, position_ids)
+       else:
+           position_embeddings = None
+
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
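The new branch computes the rotary position embeddings once at the model level when the inputs live on an Intel XPU, so the decoder layers can share them. A minimal, self-contained sketch of the same device-detection pattern (the pick_device helper is illustrative, not part of this diff):

import torch

def pick_device() -> torch.device:
    # torch.xpu is only present in recent PyTorch / IPEX builds,
    # hence the hasattr guard before calling is_available().
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")

device = pick_device()
inputs_embeds = torch.randn(1, 4, 16, device=device)
print(inputs_embeds.device.type)  # "xpu" on an Intel GPU build, else "cpu"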
@@ -369,6 +375,7 @@ class KQwen2MoeModel(BaseInjectedModule):
                output_router_logits=output_router_logits,
                use_cache=use_cache,
                cache_position=cache_position,
+               position_embeddings=position_embeddings,
            )
            if per_layer_prefill_flag:
                # print(f"to cpu")
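With the embeddings computed once up front, the per-layer call just forwards them. A toy sketch of that contract, following the Hugging Face convention where the rotary module returns a (cos, sin) pair; toy_layer is a stand-in, not the actual decoder layer from this repo:

from typing import Optional, Tuple
import torch

def toy_layer(
    hidden_states: torch.Tensor,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    # When the model precomputed (cos, sin) once, every layer reuses it
    # instead of rebuilding rotary embeddings; None means "compute locally".
    if position_embeddings is not None:
        cos, sin = position_embeddings
        # ... apply the shared rotary embedding here ...
    return hidden_states

h = torch.zeros(1, 4, 16)
cos, sin = torch.ones(1, 4, 16), torch.zeros(1, 4, 16)
out = toy_layer(h, position_embeddings=(cos, sin))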
@@ -376,8 +383,10 @@ class KQwen2MoeModel(BaseInjectedModule):
                torch.cuda.empty_cache()

            hidden_states = layer_outputs[0]
-           if use_cache:
+           if use_cache and len(layer_outputs) > 1:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+           else:
+               next_decoder_cache = None

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
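The relaxed condition matters because a layer on the XPU path may return only the hidden states; indexing layer_outputs[1] unconditionally would raise IndexError. The guard in isolation, with stand-in values:

use_cache, output_attentions = True, False
hidden_states = object()
layer_outputs = (hidden_states,)  # a layer that returned no cache entry

if use_cache and len(layer_outputs) > 1:
    next_decoder_cache = layer_outputs[2 if output_attentions else 1]
else:
    next_decoder_cache = None  # nothing to collect from this layer

print(next_decoder_cache)  # None, instead of an IndexError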
@ -396,11 +405,14 @@ class KQwen2MoeModel(BaseInjectedModule):
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if next_decoder_cache is not None:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
else:
next_cache = past_key_values
if not return_dict:
return tuple(
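The final selection mirrors that guard: when no per-layer cache was collected, the model hands back the incoming past_key_values so callers still receive a usable cache object. A reduced sketch with stand-in values, assuming past_key_values is simply whatever cache the caller passed in:

use_cache = True
use_legacy_cache = False
next_decoder_cache = None           # no layer produced a cache above
past_key_values = "caller's cache"  # stand-in for the incoming cache object

next_cache = None
if use_cache:
    if next_decoder_cache is not None:
        next_cache = (
            next_decoder_cache.to_legacy_cache()
            if use_legacy_cache
            else next_decoder_cache
        )
    else:
        next_cache = past_key_values  # fall back to the caller's cache

print(next_cache)  # "caller's cache"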