fix local chat on npu

This commit is contained in:
danglinfei 2025-09-23 19:35:44 +08:00
parent 63ec4d4b4f
commit 361cbf6329
3 changed files with 117 additions and 3 deletions

View file

@ -91,7 +91,7 @@ class StaticCache(transformers.StaticCache):
self.page_table_list = []
for idx in range(config.num_hidden_layers):
if isinstance(device, dict):
target_device = device[f"model.layers.{idx}.self_attn"]["generate_device"]
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device
@ -121,7 +121,7 @@ class StaticCache(transformers.StaticCache):
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
if isinstance(device, dict):
target_device = device[f"model.layers.{idx}.self_attn"]["generate_device"]
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device