Fixed some GGUFv1 loading bugs, long overdue cleanup for compiling, integrated TTS

tts is functional (+6 squashed commit)

Squashed commit:

[22396311] wip tts

[3a883027] tts not yet working

[0dcfab0e] fix silly bug

[a378d9ef] some long overdue cleanup

[fc5a6fb5] Wip tts

[39f50497] wip TTS integration
This commit is contained in:
Concedo 2025-01-12 16:33:02 +08:00
parent 12cdcf0abe
commit b3de1598e7
17 changed files with 1175 additions and 271 deletions

View file

@ -1,5 +1,6 @@
#include "utils.h"
#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstring>
@ -303,6 +304,47 @@ std::vector<uint8_t> kcpp_base64_decode(const std::string & encoded_string)
return ret;
}
std::string kcpp_base64_encode(const unsigned char* data, unsigned int data_length) {
const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
std::string encoded;
encoded.reserve(((data_length + 2) / 3) * 4);
for (unsigned int i = 0; i < data_length; i += 3) {
unsigned int triple = (data[i] << 16) + (i + 1 < data_length ? data[i + 1] << 8 : 0) + (i + 2 < data_length ? data[i + 2] : 0);
encoded.push_back(base64_chars[(triple >> 18) & 0x3F]);
encoded.push_back(base64_chars[(triple >> 12) & 0x3F]);
if (i + 1 < data_length) {
encoded.push_back(base64_chars[(triple >> 6) & 0x3F]);
} else {
encoded.push_back('=');
}
if (i + 2 < data_length) {
encoded.push_back(base64_chars[triple & 0x3F]);
} else {
encoded.push_back('=');
}
}
return encoded;
}
std::string kcpp_base64_encode(const std::string &data) {
static const char lookup[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
std::string encoded;
int val = 0, valb = -6;
for (unsigned char c : data) {
val = (val << 8) + c;
valb += 8;
while (valb >= 0) {
encoded.push_back(lookup[(val >> valb) & 0x3F]);
valb -= 6;
}
}
if (valb > -6) {
encoded.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]);
}
while (encoded.size() % 4) {
encoded.push_back('=');
}
return encoded;
}
std::string get_timestamp_str()
{
@ -314,3 +356,150 @@ std::string get_timestamp_str()
std::string timestamp(buffer);
return timestamp;
}
//a very rudimentary all in one sampling function which has no dependencies
int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng)
{
if (temp <= 0 || top_k==1) {
// select the token with the highest logit directly
float max_logit = logits[0];
int32_t max_id = 0;
for (int i = 1; i < n_logits; ++i) {
if (logits[i] > max_logit) {
max_logit = logits[i];
max_id = i;
}
}
return max_id;
}
top_k = (top_k<=0 || top_k>300)?300:top_k;
top_k = std::min(top_k, n_logits);
std::vector<std::pair<float, int32_t>> logits_id;
logits_id.reserve(n_logits);
//temperature sample
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
}
//sample top_k
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
// compute probs for the top k tokens
std::vector<float> probs;
probs.reserve(logits_id.size());
float maxl = logits_id[0].first;
double sum = 0.0;
for (const auto & kv : logits_id) {
const float p = expf(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
kcpp_embd_batch::kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast, bool use_mrope)
{
int32_t seq_id = 0;
pos.resize(n_tokens * (use_mrope?4:1));
std::fill(pos.begin(), pos.end(), 0);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
if(!use_mrope)
{
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
else
{
for (int i = 0; i < n_tokens; i++) {
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
for (int j = 0; j < batch.n_tokens * 3; j++) {
batch.pos[j] = npast + (j % batch.n_tokens);
}
}
}
kcpp_embd_batch::kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast, bool use_mrope, bool return_all_logits)
{
int32_t seq_id = 0;
int32_t n_tokens = tokens.size();
pos.resize(n_tokens * (use_mrope?4:1));
std::fill(pos.begin(), pos.end(), 0);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids[n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens.data(),
/*embd =*/ nullptr,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
if(!use_mrope)
{
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = (return_all_logits?true:false);
}
}
else
{
for (int i = 0; i < n_tokens; i++) {
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = (return_all_logits?true:false);
}
for (int j = 0; j < batch.n_tokens * 3; j++) {
batch.pos[j] = npast + (j % batch.n_tokens);
}
}
batch.logits[n_tokens - 1] = true;
}