mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
update rope calculation; update modeling.py; update gate for moe
This commit is contained in:
parent
5a50b34627
commit
f873558a89
11 changed files with 402 additions and 412 deletions
|
@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
|
|||
elif config is not None:
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
|
||||
attention_factor = 1.0 # Unused in this type of RoPE
|
||||
|
||||
|
@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
|
|||
elif config is not None:
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
factor = config.rope_scaling["factor"]
|
||||
|
||||
attention_factor = 1.0 # Unused in this type of RoPE
|
||||
|
||||
# seq_len: default to max_position_embeddings, e.g. at init time
|
||||
seq_len = seq_len if seq_len is not None else max_position_embeddings
|
||||
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
||||
|
||||
# Compute the inverse frequencies
|
||||
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
||||
|
@ -185,15 +187,33 @@ def _compute_yarn_parameters(
|
|||
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = config.qk_rope_head_dim
|
||||
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
factor = config.rope_scaling["factor"]
|
||||
attention_factor = config.rope_scaling.get("attention_factor")
|
||||
mscale = config.rope_scaling.get("mscale")
|
||||
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
|
||||
|
||||
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
|
||||
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
||||
# values to compute the default attention scaling factor, instead of using `factor`.
|
||||
if "original_max_position_embeddings" in config.rope_scaling:
|
||||
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
|
||||
factor = config.max_position_embeddings / original_max_position_embeddings
|
||||
else:
|
||||
original_max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
def get_mscale(scale, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
# Sets the attention factor as suggested in the paper
|
||||
attention_factor = config.rope_scaling.get("attention_factor")
|
||||
if attention_factor is None:
|
||||
attention_factor = 0.1 * math.log(factor) + 1.0
|
||||
if mscale and mscale_all_dim:
|
||||
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
|
||||
else:
|
||||
attention_factor = get_mscale(factor)
|
||||
|
||||
# Optional config options
|
||||
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
||||
|
@ -211,7 +231,7 @@ def _compute_yarn_parameters(
|
|||
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_mask(min, max, dim):
|
||||
def linear_ramp_factor(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
|
@ -219,16 +239,20 @@ def _compute_yarn_parameters(
|
|||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
||||
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
|
||||
|
||||
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
|
||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
|
||||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
||||
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
||||
)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
|
||||
|
@ -244,7 +268,7 @@ def _compute_longrope_parameters(
|
|||
device (`torch.device`):
|
||||
The device to use for initialization of the inverse frequencies.
|
||||
seq_len (`int`, *optional*):
|
||||
The current sequence length. Unused for this type of RoPE.
|
||||
The current sequence length.
|
||||
rope_kwargs (`Dict`, *optional*):
|
||||
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
||||
Returns:
|
||||
|
@ -261,7 +285,8 @@ def _compute_longrope_parameters(
|
|||
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
long_factor = config.rope_scaling["long_factor"]
|
||||
short_factor = config.rope_scaling["short_factor"]
|
||||
factor = config.rope_scaling.get("factor")
|
||||
|
@ -271,22 +296,20 @@ def _compute_longrope_parameters(
|
|||
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
||||
# values to compute the default attention scaling factor, instead of using `factor`.
|
||||
if hasattr(config, "original_max_position_embeddings"):
|
||||
max_position_embeddings = config.original_max_position_embeddings
|
||||
expanded_max_position_embeddings = config.max_position_embeddings
|
||||
factor = expanded_max_position_embeddings / max_position_embeddings
|
||||
original_max_position_embeddings = config.original_max_position_embeddings
|
||||
factor = config.max_position_embeddings / config.original_max_position_embeddings
|
||||
else:
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
expanded_max_position_embeddings = max_position_embeddings * factor
|
||||
original_max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
# Sets the attention factor as suggested in the paper
|
||||
if attention_factor is None:
|
||||
if factor <= 1.0:
|
||||
attention_factor = 1.0
|
||||
else:
|
||||
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
|
||||
attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
|
||||
|
||||
# Compute the inverse frequencies -- scaled based on the target sequence length
|
||||
if expanded_max_position_embeddings > max_position_embeddings:
|
||||
if seq_len and seq_len > original_max_position_embeddings:
|
||||
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
||||
else:
|
||||
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
||||
|
@ -325,19 +348,18 @@ def _compute_llama3_parameters(
|
|||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
new_freqs = []
|
||||
for freq in inv_freq:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < high_freq_wavelen:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / factor)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
|
||||
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
wavelen = 2 * math.pi / inv_freq
|
||||
# wavelen < high_freq_wavelen: do nothing
|
||||
# wavelen > low_freq_wavelen: divide by factor
|
||||
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
||||
# otherwise: interpolate between the two, using a smooth factor
|
||||
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
||||
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
||||
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
||||
|
||||
return inv_freq_llama, attention_factor
|
||||
|
||||
|
||||
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
||||
|
@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
|
|||
}
|
||||
|
||||
|
||||
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
||||
def _check_received_keys(
|
||||
rope_type: str,
|
||||
received_keys: set,
|
||||
required_keys: set,
|
||||
optional_keys: Optional[set] = None,
|
||||
ignore_keys: Optional[set] = None,
|
||||
):
|
||||
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
||||
# BC: "rope_type" was originally "type" -- let's gracefully handle it
|
||||
if "rope_type" not in received_keys and "type" in received_keys:
|
||||
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
||||
if "type" in received_keys:
|
||||
received_keys -= {"type"}
|
||||
received_keys.add("rope_type")
|
||||
required_keys.add("rope_type")
|
||||
|
||||
# Some models need to store model-specific keys, and we don't want to throw warning at them
|
||||
if ignore_keys is not None:
|
||||
received_keys -= ignore_keys
|
||||
|
||||
missing_keys = required_keys - received_keys
|
||||
if missing_keys:
|
||||
|
@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
|
|||
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
||||
|
||||
|
||||
def _validate_default_rope_parameters(config: PretrainedConfig):
|
||||
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||
|
||||
|
||||
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
||||
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"original_max_position_embeddings"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_yarn_parameters(config: PretrainedConfig):
|
||||
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
||||
optional_keys = {
|
||||
"attention_factor",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"original_max_position_embeddings",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
|
@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
|
|||
)
|
||||
|
||||
|
||||
def _validate_longrope_parameters(config: PretrainedConfig):
|
||||
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "short_factor", "long_factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
|
||||
short_factor = rope_scaling.get("short_factor")
|
||||
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
||||
|
@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
|||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
attention_factor = rope_scaling.get("attention_factor")
|
||||
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
|
||||
logger.warning(
|
||||
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||
)
|
||||
if attention_factor is not None:
|
||||
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
||||
logger.warning(
|
||||
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_llama3_parameters(config: PretrainedConfig):
|
||||
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
|
@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
|
|||
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
||||
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
||||
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
||||
if high_freq_factor < low_freq_factor:
|
||||
if high_freq_factor <= low_freq_factor:
|
||||
logger.warning(
|
||||
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
||||
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
||||
|
@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
|
|||
}
|
||||
|
||||
|
||||
def rope_config_validation(config: PretrainedConfig):
|
||||
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
"""
|
||||
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
||||
"""
|
||||
|
@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig):
|
|||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
||||
if validation_fn is not None:
|
||||
validation_fn(config)
|
||||
validation_fn(config, ignore_keys=ignore_keys)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
||||
)
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue