[feature] release 0.1.3

This commit is contained in:
chenxl 2024-08-28 16:11:43 +00:00
parent 67f8b370c3
commit 4d1d561d28
58 changed files with 11709 additions and 374 deletions

View file

@ -1,14 +1,14 @@
#!/usr/bin/env python
# coding=utf-8
'''
"""
Description :
Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
LastEditors : Azure
LastEditTime : 2024-08-14 14:53:05
LastEditTime : 2024-08-27 07:29:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
"""
import inspect
import math
@ -19,7 +19,10 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
from ktransformers.server.config.config import Config
import os
import yaml
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
@ -40,19 +43,35 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock, Qwen2MoeMLP, Qwen2MoeDecoderLayer
from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, DeepseekV2DecoderLayer, DeepseekV2MoE
from ktransformers.models.modeling_qwen2_moe import (
Qwen2MoeSparseMoeBlock,
Qwen2MoeMLP,
Qwen2MoeDecoderLayer,
)
from ktransformers.models.modeling_deepseek import (
BaseModelOutputWithPast,
DeepseekV2DecoderLayer,
DeepseekV2MoE,
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
LlamaDecoderLayer,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
_flash_supports_window_size = "window_size" in list(
inspect.signature(flash_attn_func).parameters
)
logger = logging.get_logger(__name__)
@ -151,6 +170,7 @@ QWEN2MOE_INPUTS_DOCSTRING = r"""
the complete sequence length.
"""
@add_start_docstrings(
"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
QWEN2MOE_START_DOCSTRING,
@ -162,18 +182,21 @@ class KQwen2MoeModel(BaseInjectedModule):
Args:
config: Qwen2MoeConfig
"""
def __init__(
self,
key: str,
gguf_loader : GGUFLoader,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
@ -192,29 +215,47 @@ class KQwen2MoeModel(BaseInjectedModule):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
per_layer_prefill_intput_threshold: int | None = None, # if None or 0, close per-layer prefill
per_layer_prefill_intput_threshold: (
int | None
) = None, # if None or 0, close per-layer prefill
) -> Union[Tuple, MoeModelOutputWithPast]:
# print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')
if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
if per_layer_prefill_intput_threshold is None:
per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
per_layer_prefill_flag = False
seq_lenth = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
if per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth:
seq_lenth = (
inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
)
if (
per_layer_prefill_intput_threshold
and per_layer_prefill_intput_threshold < seq_lenth
):
per_layer_prefill_flag = True
for layer in self.layers:
self.load_layer_to(layer, InferenceState.UNLOAD)
else:
pass
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
@ -243,15 +284,23 @@ class KQwen2MoeModel(BaseInjectedModule):
inputs_embeds = inputs_embeds.to("cuda")
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
@ -263,7 +312,7 @@ class KQwen2MoeModel(BaseInjectedModule):
next_decoder_cache = None
for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
@ -271,11 +320,25 @@ class KQwen2MoeModel(BaseInjectedModule):
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
hidden_states = hidden_states.to(
self.transfer_map[i], non_blocking=True
)
causal_mask = (
causal_mask.to(self.transfer_map[i], non_blocking=True)
if causal_mask is not None
else None
)
position_ids = (
position_ids.to(self.transfer_map[i], non_blocking=True)
if position_ids is not None
else None
)
cache_position = (
cache_position.to(self.transfer_map[i], non_blocking=True)
if cache_position is not None
else None
)
if output_hidden_states:
all_hidden_states += (hidden_states,)
@ -323,7 +386,6 @@ class KQwen2MoeModel(BaseInjectedModule):
hidden_states = self.norm(hidden_states)
if per_layer_prefill_flag:
per_layer_prefill_flag = False
for layer in self.layers:
@ -333,12 +395,22 @@ class KQwen2MoeModel(BaseInjectedModule):
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_router_logits,
]
if v is not None
)
return MoeModelOutputWithPast(
@ -349,11 +421,13 @@ class KQwen2MoeModel(BaseInjectedModule):
router_logits=all_router_logits,
)
def load_layer_to(self, layer:Qwen2MoeDecoderLayer, target: InferenceState):
assert isinstance(layer, Qwen2MoeDecoderLayer), "module should be nn.ModuleList of decoder layers"
def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState):
assert isinstance(
layer, Qwen2MoeDecoderLayer
), "module should be nn.ModuleList of decoder layers"
# TODO Support restore to original device, not only cuda
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
# attn
layer.self_attn.q_proj.set_inference_mode(target)
@ -458,18 +532,21 @@ class KDeepseekV2Model(BaseInjectedModule):
Args:
config: DeepseekV2Config
"""
def __init__(
self,
key: str,
gguf_loader : GGUFLoader,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
@ -487,15 +564,23 @@ class KDeepseekV2Model(BaseInjectedModule):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
per_layer_prefill_intput_threshold: int | None = None, # if None, no per-layer prefill
per_layer_prefill_intput_threshold: (
int | None
) = None, # if None, no per-layer prefill
) -> Union[Tuple, BaseModelOutputWithPast]:
if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
if per_layer_prefill_intput_threshold is None:
per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
per_layer_prefill_flag = False
seq_lenth = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
if per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth:
seq_lenth = (
inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
)
if (
per_layer_prefill_intput_threshold
and per_layer_prefill_intput_threshold < seq_lenth
):
per_layer_prefill_flag = True
for layer in self.layers:
self.load_layer_to(layer, InferenceState.UNLOAD)
self.load_layer_to(layer, InferenceState.UNLOAD)
torch.cuda.empty_cache()
else:
pass
@ -542,9 +627,13 @@ class KDeepseekV2Model(BaseInjectedModule):
past_key_values_length = past_key_values.get_usable_length(seq_length)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
@ -556,15 +645,17 @@ class KDeepseekV2Model(BaseInjectedModule):
inputs_embeds = self.embed_tokens(input_ids)
input_ids = input_ids.to(org_device)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
if per_layer_prefill_flag:
causal_mask = None
else:
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
if per_layer_prefill_flag:
print(f'Total length of input_ids: {hidden_states.size(1)}')
print(f"Total length of input_ids: {hidden_states.size(1)}")
# decoder layers
all_hidden_states = () if output_hidden_states else None
@ -576,7 +667,7 @@ class KDeepseekV2Model(BaseInjectedModule):
t_f = 0
for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
@ -584,10 +675,24 @@ class KDeepseekV2Model(BaseInjectedModule):
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
hidden_states = hidden_states.to(
self.transfer_map[i], non_blocking=True
)
causal_mask = (
causal_mask.to(self.transfer_map[i], non_blocking=True)
if causal_mask is not None
else None
)
position_ids = (
position_ids.to(self.transfer_map[i], non_blocking=True)
if position_ids is not None
else None
)
cache_position = (
cache_position.to(self.transfer_map[i], non_blocking=True)
if cache_position is not None
else None
)
if output_hidden_states:
all_hidden_states += (hidden_states,)
@ -622,12 +727,12 @@ class KDeepseekV2Model(BaseInjectedModule):
t5 = time.time()
if per_layer_prefill_flag:
# print(f"to cpu")
self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
torch.cuda.empty_cache()
t6 = time.time()
t_gpu += t4-t3
t_cpu += t6-t5
t_f += t5-t4
t_gpu += t4 - t3
t_cpu += t6 - t5
t_f += t5 - t4
hidden_states = layer_outputs[0]
@ -648,7 +753,9 @@ class KDeepseekV2Model(BaseInjectedModule):
torch.cuda.empty_cache()
t7 = time.time()
print(f"total time: {t7-t3}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}")
print(
f"total time: {t7-t3}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}"
)
# add hidden states from the last decoder layer
if output_hidden_states:
@ -674,16 +781,18 @@ class KDeepseekV2Model(BaseInjectedModule):
attentions=all_self_attns,
)
def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
assert isinstance(layer, DeepseekV2DecoderLayer), "module should be nn.ModuleList of decoder layers"
def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
assert isinstance(
layer, DeepseekV2DecoderLayer
), "module should be nn.ModuleList of decoder layers"
# TODO Support restore to original device, not only cuda
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
# TODO Support DFS to auto use {to, set_inference_mode} according to the module type
# attn
layer.self_attn.to(device) #
layer.self_attn.to(device) #
# mlp
if isinstance(layer.mlp, DeepseekV2MoE):
@ -702,3 +811,526 @@ class KDeepseekV2Model(BaseInjectedModule):
# layer norm
layer.input_layernorm.to(device)
layer.post_attention_layernorm.to(device)
LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class KLlamaModel(BaseInjectedModule):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
dynamic_sdpa = None
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
user_path: str = os.path.expanduser('~')
localstore_path: str = os.path.join(user_path,'.ktransformers')
config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)
with open(config_path,"r") as file:
config_yaml = yaml.safe_load(file.read())
self.long_context_config = config_yaml.get("long_context")
self.ext_config = config_yaml.get("ext")
KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention(
max_seq_len=self.long_context_config["max_seq_len"],
block_size=self.long_context_config["block_size"],
config=config,
device=torch.device("cuda"),
local_windows_len=self.long_context_config["local_windows_len"],
topk=self.long_context_config["second_select_num"],
threads_num=self.ext_config["cpu_infer"],
anchor_type=self.long_context_config["anchor_type"],
kv_type=self.long_context_config["kv_type"],
dense_layer_num=self.long_context_config["dense_layer_num"],
anchor_num=self.long_context_config["anchor_num"],
preselect_block=self.long_context_config["preselect_block"],
block_selection_mode=self.long_context_config["head_select_mode"],
preselect_block_count=self.long_context_config["preselect_block_count"],
layer_step=self.long_context_config["layer_step"],
token_step=self.long_context_config["token_step"],
prefill_chunk_size=self.long_context_config["chunk_size"],
use_attn_sparsity=False,
)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device="cuda",
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = None
chunck_size = self.long_context_config["chunk_size"]
cur_idx = 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids.to("cpu"))
q_len = cache_position.size(0)
# generate
if q_len == 1:
x = inputs_embeds[:, -1:, :]
position_ids = position_ids[:, -1:]
return self.forward_chunk(
x,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_hidden_states,
return_dict,
)
elif q_len <= chunck_size:
inputs_embeds = inputs_embeds.to('cuda')
output = self.forward_chunk(
inputs_embeds,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_hidden_states,
return_dict,
)
KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
return output
cur_idx = 0
assert (
output_attentions == False
), "output_attentions is not supported when using chunked attention"
attn_output = None
# prefill
KLlamaModel.dynamic_sdpa.remaining_length = q_len
while cur_idx < q_len:
print(f'current prefill length: {cur_idx}')
chunk_mask = None
if inputs_embeds.device.type == 'cpu':
tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to("cuda")
else:
tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)]
output_with_past = self.forward_chunk(
tmp_inputs_embeds,
chunk_mask,
position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)],
past_key_values,
output_attentions,
use_cache,
cache_position[cur_idx : min(cur_idx + chunck_size, q_len)],
)
cur_output = output_with_past.last_hidden_state
KLlamaModel.dynamic_sdpa.remaining_length -= (
min(cur_idx + chunck_size, q_len) - cur_idx
)
cur_idx += chunck_size
# if attn_output is None:
attn_output = cur_output
# else:
# attn_output = torch.cat((attn_output, cur_output), dim=-2)
KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
return BaseModelOutputWithPast(last_hidden_state=attn_output)
def forward_chunk(
self,
inputs_embeds,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_legacy_cache = False
if use_cache and not isinstance(
past_key_values, Cache
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError(
"Custom 4D attention mask should be passed in inverted form with max==0`"
)
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(
input_tensor.shape[0], 1, -1, -1
)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length]
+ attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask