#!/usr/bin/env python # coding=utf-8 """ Description : Author : Azure-Tang Date : 2024-07-25 11:25:24 Version : 1.0.0 LastEditors : Azure LastEditTime : 2024-08-27 07:29:04 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import inspect import math from typing import List, Optional, Tuple, Union import time import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention from ktransformers.server.config.config import Config import os import yaml from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, ) from transformers.modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from ktransformers.models.modeling_qwen2_moe import ( Qwen2MoeSparseMoeBlock, Qwen2MoeMLP, Qwen2MoeDecoderLayer, ) from ktransformers.models.modeling_deepseek import ( BaseModelOutputWithPast, DeepseekV2DecoderLayer, DeepseekV2MoE, ) from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.utils import InferenceState, get_compute_capability from ktransformers.util.custom_gguf import GGUFLoader from transformers.configuration_utils import PretrainedConfig from ktransformers.models.modeling_llama import ( LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, ) if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa _flash_supports_window_size = "window_size" in list( inspect.signature(flash_attn_func).parameters ) logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B" _CONFIG_FOR_DOC = "Qwen2MoeConfig" QWEN2MOE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Qwen2MoeConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ QWEN2MOE_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", QWEN2MOE_START_DOCSTRING, ) class KQwen2MoeModel(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`] Args: config: Qwen2MoeConfig """ def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, device, **kwargs ) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.transfer_map = transfer_map self.stream_device_map = dict() @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, per_layer_prefill_intput_threshold: ( int | None ) = None, # if None or 0, close per-layer prefill ) -> Union[Tuple, MoeModelOutputWithPast]: # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}') if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold per_layer_prefill_flag = False seq_lenth = ( inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1) ) if ( per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth ): per_layer_prefill_flag = True for layer in self.layers: self.load_layer_to(layer, InferenceState.UNLOAD) else: pass output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False use_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) if inputs_embeds is None: input_ids = input_ids.to("cpu") inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = inputs_embeds.to("cuda") if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, ) hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None next_decoder_cache = None for i, decoder_layer in enumerate(self.layers): if self.transfer_map is not None and i in self.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.transfer_map[i] if cur_device not in self.stream_device_map: self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) torch.cuda.set_device(cur_device) self.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.transfer_map[i], non_blocking=True ) causal_mask = ( causal_mask.to(self.transfer_map[i], non_blocking=True) if causal_mask is not None else None ) position_ids = ( position_ids.to(self.transfer_map[i], non_blocking=True) if position_ids is not None else None ) cache_position = ( cache_position.to(self.transfer_map[i], non_blocking=True) if cache_position is not None else None ) if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, cache_position, ) else: if per_layer_prefill_flag: # print(f"to gpu") self.load_layer_to(decoder_layer, InferenceState.PREFILL) torch.cuda.empty_cache() layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, ) if per_layer_prefill_flag: # print(f"to cpu") self.load_layer_to(decoder_layer, InferenceState.UNLOAD) torch.cuda.empty_cache() hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) if output_router_logits and layer_outputs[-1] is not None: all_router_logits += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) if per_layer_prefill_flag: per_layer_prefill_flag = False for layer in self.layers: self.load_layer_to(layer, InferenceState.GENERATE) if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = ( next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache ) if not return_dict: return tuple( v for v in [ hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits, ] if v is not None ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState): assert isinstance( layer, Qwen2MoeDecoderLayer ), "module should be nn.ModuleList of decoder layers" # TODO Support restore to original device, not only cuda device = "cpu" if target == InferenceState.UNLOAD else "cuda" # attn layer.self_attn.q_proj.set_inference_mode(target) layer.self_attn.k_proj.set_inference_mode(target) layer.self_attn.v_proj.set_inference_mode(target) layer.self_attn.o_proj.set_inference_mode(target) layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device) # mlp if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock): layer.mlp.gate.set_inference_mode(target) layer.mlp.experts.set_inference_mode(target) layer.mlp.shared_expert.gate_proj.set_inference_mode(target) layer.mlp.shared_expert.up_proj.set_inference_mode(target) layer.mlp.shared_expert.down_proj.set_inference_mode(target) layer.mlp.shared_expert.act_fn.to(device) layer.mlp.shared_expert_gate.to(device) else: layer.mlp.gate_proj.set_inference_mode(target) layer.mlp.up_proj.set_inference_mode(target) layer.mlp.down_proj.set_inference_mode(target) layer.mlp.act_fn.to(device) # layer norm layer.input_layernorm.to(device) layer.post_attention_layernorm.to(device) DeepseekV2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ class KDeepseekV2Model(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] Args: config: DeepseekV2Config """ def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, device, **kwargs ) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.transfer_map = transfer_map self.stream_device_map = dict() @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, per_layer_prefill_intput_threshold: ( int | None ) = None, # if None, no per-layer prefill ) -> Union[Tuple, BaseModelOutputWithPast]: if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold per_layer_prefill_flag = False seq_lenth = ( inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1) ) if ( per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth ): per_layer_prefill_flag = True for layer in self.layers: self.load_layer_to(layer, InferenceState.UNLOAD) torch.cuda.empty_cache() else: pass output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." ) use_cache = False past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) if inputs_embeds is None: org_device = input_ids.device # TODO move to embed_tokens's device, not hard code to cpu input_ids = input_ids.to("cpu") inputs_embeds = self.embed_tokens(input_ids).to(org_device) input_ids = input_ids.to(org_device) if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) if per_layer_prefill_flag: causal_mask = None else: if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: # print("for Windows or GPU before ampere, use forward_windows") # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) else: causal_mask = None # embed positions hidden_states = inputs_embeds if per_layer_prefill_flag: print(f"Total length of input_ids: {hidden_states.size(1)}") # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None t_gpu = 0 t_cpu = 0 t_f = 0 for i, decoder_layer in enumerate(self.layers): # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n") if self.transfer_map is not None and i in self.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.transfer_map[i] if cur_device not in self.stream_device_map and cur_device.lower() != "cpu": self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) if cur_device.lower() != "cpu": torch.cuda.set_device(cur_device) self.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.transfer_map[i], non_blocking=True ) causal_mask = ( causal_mask.to(self.transfer_map[i], non_blocking=True) if causal_mask is not None else None ) position_ids = ( position_ids.to(self.transfer_map[i], non_blocking=True) if position_ids is not None else None ) cache_position = ( cache_position.to(self.transfer_map[i], non_blocking=True) if cache_position is not None else None ) if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, ) else: t3 = time.time() if per_layer_prefill_flag: # print(f"to gpu") self.load_layer_to(decoder_layer, InferenceState.PREFILL) torch.cuda.empty_cache() t4 = time.time() # with open("log.txt", "a") as f: # f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n") layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) t5 = time.time() if per_layer_prefill_flag: # print(f"to cpu") self.load_layer_to(decoder_layer, InferenceState.UNLOAD) torch.cuda.empty_cache() t6 = time.time() t_gpu += t4 - t3 t_cpu += t6 - t5 t_f += t5 - t4 hidden_states = layer_outputs[0] # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3 if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # with open("log.txt", "a") as f: # f.write(f"@@@After layers\n") # f.write(f"hidden_states={hidden_states}\n") # f.write(f"hidden_states.shape={hidden_states.shape}\n") if per_layer_prefill_flag: t6 = time.time() # print(f"restore") per_layer_prefill_flag = False for layer in self.layers: self.load_layer_to(layer, InferenceState.GENERATE) torch.cuda.empty_cache() t7 = time.time() print( f"total time: {t7-t3}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}" ) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = ( next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache ) if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState): assert isinstance( layer, DeepseekV2DecoderLayer ), "module should be nn.ModuleList of decoder layers" # TODO Support restore to original device, not only cuda device = "cpu" if target == InferenceState.UNLOAD else "cuda" # TODO Support DFS to auto use {to, set_inference_mode} according to the module type # attn layer.self_attn.to(device) # # mlp if isinstance(layer.mlp, DeepseekV2MoE): layer.mlp.gate.to(device) layer.mlp.experts.set_inference_mode(target) layer.mlp.shared_experts.gate_proj.set_inference_mode(target) layer.mlp.shared_experts.up_proj.set_inference_mode(target) layer.mlp.shared_experts.down_proj.set_inference_mode(target) layer.mlp.shared_experts.act_fn.to(device) # layer.mlp.shared_expert_gate.to(device) else: layer.mlp.gate_proj.set_inference_mode(target) layer.mlp.up_proj.set_inference_mode(target) layer.mlp.down_proj.set_inference_mode(target) layer.mlp.act_fn.to(device) # layer norm layer.input_layernorm.to(device) layer.post_attention_layernorm.to(device) LLAMA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ LLAMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class KLlamaModel(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] Args: config: LlamaConfig """ dynamic_sdpa = None def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, device, **kwargs ) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.transfer_map = transfer_map self.stream_device_map = dict() user_path: str = os.path.expanduser('~') localstore_path: str = os.path.join(user_path,'.ktransformers') config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME) with open(config_path,"r") as file: config_yaml = yaml.safe_load(file.read()) self.long_context_config = config_yaml.get("long_context") self.ext_config = config_yaml.get("ext") KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention( max_seq_len=self.long_context_config["max_seq_len"], block_size=self.long_context_config["block_size"], config=config, device=torch.device("cuda"), local_windows_len=self.long_context_config["local_windows_len"], topk=self.long_context_config["second_select_num"], threads_num=self.ext_config["cpu_infer"], anchor_type=self.long_context_config["anchor_type"], kv_type=self.long_context_config["kv_type"], dense_layer_num=self.long_context_config["dense_layer_num"], anchor_num=self.long_context_config["anchor_num"], preselect_block=self.long_context_config["preselect_block"], block_selection_mode=self.long_context_config["head_select_mode"], preselect_block_count=self.long_context_config["preselect_block_count"], layer_step=self.long_context_config["layer_step"], token_step=self.long_context_config["token_step"], prefill_chunk_size=self.long_context_config["chunk_size"], use_attn_sparsity=False, ) def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False return_legacy_cache = False if ( use_cache and not isinstance(past_key_values, Cache) and not self.training ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device="cuda", ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = None chunck_size = self.long_context_config["chunk_size"] cur_idx = 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.to("cpu")) q_len = cache_position.size(0) # generate if q_len == 1: x = inputs_embeds[:, -1:, :] position_ids = position_ids[:, -1:] return self.forward_chunk( x, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, output_hidden_states, return_dict, ) elif q_len <= chunck_size: inputs_embeds = inputs_embeds.to('cuda') output = self.forward_chunk( inputs_embeds, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, output_hidden_states, return_dict, ) KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1) KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1) return output cur_idx = 0 assert ( output_attentions == False ), "output_attentions is not supported when using chunked attention" attn_output = None # prefill KLlamaModel.dynamic_sdpa.remaining_length = q_len while cur_idx < q_len: print(f'current prefill length: {cur_idx}') chunk_mask = None if inputs_embeds.device.type == 'cpu': tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to("cuda") else: tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)] output_with_past = self.forward_chunk( tmp_inputs_embeds, chunk_mask, position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)], past_key_values, output_attentions, use_cache, cache_position[cur_idx : min(cur_idx + chunck_size, q_len)], ) cur_output = output_with_past.last_hidden_state KLlamaModel.dynamic_sdpa.remaining_length -= ( min(cur_idx + chunck_size, q_len) - cur_idx ) cur_idx += chunck_size # if attn_output is None: attn_output = cur_output # else: # attn_output = torch.cat((attn_output, cur_output), dim=-2) KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1) KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1) return BaseModelOutputWithPast(last_hidden_state=attn_output) def forward_chunk( self, inputs_embeds, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_legacy_cache = False if use_cache and not isinstance( past_key_values, Cache ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError( "Custom 4D attention mask should be passed in inverted form with max==0`" ) causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device, ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange( target_length, device=device ) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand( input_tensor.shape[0], 1, -1, -1 ) if attention_mask is not None: causal_mask = ( causal_mask.clone() ) # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = ( causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[ :, :, :, :mask_length ].masked_fill(padding_mask, min_dtype) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended( causal_mask, min_dtype ) return causal_mask