Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build-linux-cross.yml
This commit is contained in:
Concedo 2025-07-17 18:23:26 +08:00
commit f57018f722
6 changed files with 69 additions and 60 deletions

View file

@ -7,7 +7,6 @@ import pathlib
import re import re
import requests import requests
import sys
import json import json
import shutil import shutil
import argparse import argparse
@ -69,8 +68,7 @@ args = parser.parse_args()
hf_token = args.hf_token if args.hf_token is not None else hf_token hf_token = args.hf_token if args.hf_token is not None else hf_token
if hf_token is None: if hf_token is None:
logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token") logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token")
sys.exit(1)
# TODO: this string has to exercise as much pre-tokenizer functionality as possible # TODO: this string has to exercise as much pre-tokenizer functionality as possible
# will be updated with time - contributions welcome # will be updated with time - contributions welcome
@ -151,7 +149,7 @@ pre_computed_hashes = [
def download_file_with_auth(url, token, save_path): def download_file_with_auth(url, token, save_path):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"} if token else None
response = sess.get(url, headers=headers) response = sess.get(url, headers=headers)
response.raise_for_status() response.raise_for_status()
os.makedirs(os.path.dirname(save_path), exist_ok=True) os.makedirs(os.path.dirname(save_path), exist_ok=True)
@ -250,10 +248,9 @@ for model in [*pre_computed_hashes, *all_models]:
else: else:
# otherwise, compute the hash of the tokenizer # otherwise, compute the hash of the tokenizer
# Skip if the tokenizer folder does not exist or there are other download issues previously # Fail if the tokenizer folder with config does not exist or there are other download issues previously
if not os.path.exists(f"models/tokenizers/{name}"): if not os.path.isfile(f"models/tokenizers/{name}/tokenizer_config.json"):
logger.warning(f"Directory for tokenizer {name} not found. Skipping...") raise OSError(f"Config for tokenizer {name} not found. The model may not exist or is not accessible with the provided token.")
continue
try: try:
logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...") logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...")
@ -261,9 +258,8 @@ for model in [*pre_computed_hashes, *all_models]:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
else: else:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
except OSError as e: except Exception as e:
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") raise OSError(f"Error loading tokenizer for model {name}.") from e
continue # Skip to the next model if the tokenizer can't be loaded
chktok = tokenizer.encode(CHK_TXT) chktok = tokenizer.encode(CHK_TXT)
chkhsh = sha256(str(chktok).encode()).hexdigest() chkhsh = sha256(str(chktok).encode()).hexdigest()

View file

@ -157,6 +157,8 @@ bool llama_batch_allocr::init(
n_outputs += batch.logits[i] != 0; n_outputs += batch.logits[i] != 0;
} }
has_cpl = false;
// determine coupled sequences // determine coupled sequences
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
for (int32_t i = 0; i < batch.n_tokens; ++i) { for (int32_t i = 0; i < batch.n_tokens; ++i) {

View file

@ -117,7 +117,7 @@ private:
using seq_cpl_t = std::vector<bool>; using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch // helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl; bool has_cpl = false;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

View file

@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
const int64_t n_tps = n_tokens/n_stream; const int64_t n_tps = n_tokens/n_stream;
const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD); const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
std::fill(data, data + ggml_nelements(dst), -INFINITY);
// Use only the previous KV cells of the correct sequence for each token of the ubatch. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@ -1306,44 +1308,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
const llama_pos p1 = ubatch->pos[i]; const llama_pos p1 = ubatch->pos[i];
const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
for (uint32_t j = 0; j < n_kv; ++j) { for (uint32_t j = 0; j < n_kv; ++j) {
float f = 0.0f;
bool masked = false;
if (cells.is_empty(j)) { if (cells.is_empty(j)) {
masked = true; continue;
} else {
const llama_pos p0 = cells.pos_get(j);
// mask the token if not the same sequence
masked = masked || (!cells.seq_has(j, seq_id));
// mask future tokens
masked = masked || (causal_attn && p0 > p1);
// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));
if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
} }
if (masked) { // mask the token if not the same sequence
f = -INFINITY; if (!cells.seq_has(j, seq_id)) {
continue;
} }
data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f; const llama_pos p0 = cells.pos_get(j);
}
// mask padded tokens // mask future tokens
if (data) { if (causal_attn && p0 > p1) {
for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) { continue;
for (uint32_t j = 0; j < n_kv; ++j) {
data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
}
} }
// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
}
data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
} }
} }
} }

View file

@ -38,9 +38,9 @@ llama_memory_hybrid::llama_memory_hybrid(
type_v, type_v,
v_trans, v_trans,
offload, offload,
1,
kv_size, kv_size,
n_seq_max, n_seq_max,
1,
n_pad, n_pad,
n_swa, n_swa,
swa_type swa_type

View file

@ -16654,7 +16654,19 @@ struct llm_build_lfm2 : public llm_graph_context {
ggml_tensor * cur, ggml_tensor * cur,
llm_graph_input_rs * inp_recr, llm_graph_input_rs * inp_recr,
int il) { int il) {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr(); const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
const uint32_t kv_head = mctx_cur->get_head();
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
cb(bcx, "model.layers.{}.conv.in_proj", il); cb(bcx, "model.layers.{}.conv.in_proj", il);
@ -16662,38 +16674,48 @@ struct llm_build_lfm2 : public llm_graph_context {
constexpr auto n_chunks = 3; constexpr auto n_chunks = 3;
GGML_ASSERT(bcx->ne[0] % n_chunks == 0); GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
auto const chunk_size = bcx->ne[0] / n_chunks; auto const chunk_size = bcx->ne[0] / n_chunks;
auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx)); auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx)); auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx)); auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));
auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
// read conv state directly, with build_rs generation is slower // read conv state
ggml_tensor * conv_state = mctx_cur->get_r_l(il); auto * conv_state = mctx_cur->get_r_l(il);
const int64_t n_seqs = ubatch.n_seqs; auto * conv_rs = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs); auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
bx = ggml_concat(ctx0, conv, bx, 0); bx = ggml_concat(ctx0, conv, bx, 0);
GGML_ASSERT(bx->ne[0] > conv->ne[0]); GGML_ASSERT(bx->ne[0] > conv->ne[0]);
auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); // last d_conv columns is a new conv state
auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
// write conv state // write new conv conv state
ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state)); ggml_build_forward_expand(
gf,
ggml_cpy(
ctx0,
new_conv,
ggml_view_1d(
ctx0,
conv_state,
ggml_nelements(new_conv),
kv_head*d_conv*n_embd*ggml_element_size(new_conv)
)
)
);
auto * conv_kernel = model.layers[il].shortconv.conv; auto * conv_kernel = model.layers[il].shortconv.conv;
GGML_ASSERT(hparams.n_shortconv_l_cache > 0); auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
// construct ssm_conv op
ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
cb(conv_out, "model.layers.{}.conv.conv", il); cb(conv_out, "model.layers.{}.conv.conv", il);
auto * y = ggml_mul(ctx0, c, conv_out); auto * y = ggml_mul(ctx0, c, conv_out);
y = build_lora_mm(model.layers[il].shortconv.out_proj, y); y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
cb(y, "model.layers.{}.conv.out_proj", il); cb(y, "model.layers.{}.conv.out_proj", il);
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
return y; return y;
} }