tts can now set a length limit

This commit is contained in:
Concedo 2025-01-28 22:06:59 +08:00
commit 558bc5c901
8 changed files with 235 additions and 119 deletions

View file

@ -214,6 +214,7 @@ struct tts_load_model_inputs
const char * vulkan_info = nullptr; const char * vulkan_info = nullptr;
const int gpulayers = 0; const int gpulayers = 0;
const bool flash_attention = false; const bool flash_attention = false;
const int ttsmaxlen = 4096;
const bool quiet = false; const bool quiet = false;
const int debugmode = 0; const int debugmode = 0;
}; };

View file

@ -46,20 +46,20 @@
#define GGML_CUDA_CC_VOLTA 700 #define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 1000000 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
// GCN/CNDA, wave size is 64 // GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300 #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220 #define GGML_CUDA_CC_QY2 220

View file

@ -121,6 +121,55 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
#endif #endif
} }
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static int ggml_cuda_parse_id(char devName[]) {
// A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
// these values are not stable so this is susceptible to breakage
// https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
int archMajor = 0x0;
int archMinor = 0x0;
int archNum = GGML_CUDA_CC_OFFSET_AMD;
int archLen = strlen(devName);
char archName[archLen + 1];
// strip leading 'gfx' while copying into our buffer
if (archLen > 3) {
strcpy(archName, &devName[3]);
archLen -= 3;
}
// trim trailing :xnack- or :sramecc- statuses
archLen = strcspn(archName, ":");
archName[archLen] = '\0';
// tease out the version information
if (archLen > 8) {
// versions labeled generic use '-' as delimiter
// strip the trailing "-generic" then iterate through what remains
if ((strstr(archName, "-generic"))) {
archName[archLen - 8] = '\0';
char * pch;
if ((pch = strtok(archName, "-"))) {
archMajor = (int)strtoul(pch, 0, 16);
if ((pch = strtok(NULL, "-"))) {
archMinor = 0x10 * (int)strtoul(pch, 0, 16);
}
}
}
} else if (archLen >= 3) {
// last two digits should be the minor * 0x10 + stepping
archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
archName[archLen - 2] = '\0';
// only the major version remains
archMajor = (int)strtoul(archName, 0, 16);
}
archNum += archMajor * 0x100;
archNum += archMinor;
return archNum;
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static ggml_cuda_device_info ggml_cuda_init() { static ggml_cuda_device_info ggml_cuda_init() {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards: // Workaround for a rocBLAS bug when using multiple graphics cards:
@ -172,7 +221,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
cudaDeviceProp prop; cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
info.default_tensor_split[id] = total_vram; info.default_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem; total_vram += prop.totalGlobalMem;
@ -181,10 +229,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].smpb = prop.sharedMemPerBlock;
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpbo = prop.sharedMemPerBlock; info.devices[id].smpbo = prop.sharedMemPerBlock;
info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
if ((info.devices[id].cc & 0xff00) == 0x0) {
GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n",
id, prop.name, prop.gcnArchName, prop.major, prop.minor);
// Fallback to prop.major and prop.minor
if (prop.major > 0) {
info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
info.devices[id].cc += prop.minor * 0x10;
}
}
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no");
#else #else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor; info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
} }

View file

@ -291,6 +291,7 @@ class tts_load_model_inputs(ctypes.Structure):
("vulkan_info", ctypes.c_char_p), ("vulkan_info", ctypes.c_char_p),
("gpulayers", ctypes.c_int), ("gpulayers", ctypes.c_int),
("flash_attention", ctypes.c_bool), ("flash_attention", ctypes.c_bool),
("ttsmaxlen", ctypes.c_int),
("quiet", ctypes.c_bool), ("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)] ("debugmode", ctypes.c_int)]
@ -1451,6 +1452,7 @@ def tts_load_model(ttc_model_filename,cts_model_filename):
if ttst > 0: if ttst > 0:
thds = ttst thds = ttst
inputs.threads = thds inputs.threads = thds
inputs.ttsmaxlen = args.ttsmaxlen if args.ttsmaxlen < 4096 else 4096
inputs = set_backend_props(inputs) inputs = set_backend_props(inputs)
ret = handle.tts_load_model(inputs) ret = handle.tts_load_model(inputs)
return ret return ret
@ -3279,6 +3281,7 @@ def show_gui():
wavtokenizer_var = ctk.StringVar() wavtokenizer_var = ctk.StringVar()
ttsgpu_var = ctk.IntVar(value=0) ttsgpu_var = ctk.IntVar(value=0)
tts_threads_var = ctk.StringVar(value=str(default_threads)) tts_threads_var = ctk.StringVar(value=str(default_threads))
ttsmaxlen_var = ctk.StringVar(value=str(4096))
def tabbuttonaction(name): def tabbuttonaction(name):
for t in tabcontent: for t in tabcontent:
@ -3855,6 +3858,7 @@ def show_gui():
makefileentry(audio_tab, "WavTokenizer Model (Text-To-Speech):", "Select WavTokenizer GGUF Model File", wavtokenizer_var, 7, width=280, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select a WavTokenizer GGUF model file on disk to be loaded for Narration.") makefileentry(audio_tab, "WavTokenizer Model (Text-To-Speech):", "Select WavTokenizer GGUF Model File", wavtokenizer_var, 7, width=280, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select a WavTokenizer GGUF model file on disk to be loaded for Narration.")
wavtokenizer_var.trace("w", gui_changed_modelfile) wavtokenizer_var.trace("w", gui_changed_modelfile)
makecheckbox(audio_tab, "TTS Use GPU", ttsgpu_var, 9, 0,tooltiptxt="Uses the GPU for TTS.") makecheckbox(audio_tab, "TTS Use GPU", ttsgpu_var, 9, 0,tooltiptxt="Uses the GPU for TTS.")
makelabelentry(audio_tab, "OuteTTS Max Tokens:" , ttsmaxlen_var, 11, 50,padx=290,singleline=True,tooltip="Max allowed audiotokens to generate per TTS request.")
ttsgpu_var.trace("w", gui_changed_modelfile) ttsgpu_var.trace("w", gui_changed_modelfile)
def kcpp_export_template(): def kcpp_export_template():
@ -4077,6 +4081,7 @@ def show_gui():
args.ttsmodel = tts_model_var.get() args.ttsmodel = tts_model_var.get()
args.ttswavtokenizer = wavtokenizer_var.get() args.ttswavtokenizer = wavtokenizer_var.get()
args.ttsgpu = (ttsgpu_var.get()==1) args.ttsgpu = (ttsgpu_var.get()==1)
args.ttsmaxlen = int(ttsmaxlen_var.get())
def import_vars(dict): def import_vars(dict):
global importvars_in_progress global importvars_in_progress
@ -4242,6 +4247,7 @@ def show_gui():
tts_model_var.set(dict["ttsmodel"] if ("ttsmodel" in dict and dict["ttsmodel"]) else "") tts_model_var.set(dict["ttsmodel"] if ("ttsmodel" in dict and dict["ttsmodel"]) else "")
wavtokenizer_var.set(dict["ttswavtokenizer"] if ("ttswavtokenizer" in dict and dict["ttswavtokenizer"]) else "") wavtokenizer_var.set(dict["ttswavtokenizer"] if ("ttswavtokenizer" in dict and dict["ttswavtokenizer"]) else "")
ttsgpu_var.set(dict["ttsgpu"] if ("ttsgpu" in dict) else 0) ttsgpu_var.set(dict["ttsgpu"] if ("ttsgpu" in dict) else 0)
ttsmaxlen_var.set(str(dict["ttsmaxlen"]) if ("ttsmaxlen" in dict and dict["ttsmaxlen"]) else str(4096))
importvars_in_progress = False importvars_in_progress = False
gui_changed_modelfile() gui_changed_modelfile()
@ -5646,6 +5652,7 @@ if __name__ == '__main__':
ttsparsergroup.add_argument("--ttsmodel", metavar=('[filename]'), help="Specify the OuteTTS Text-To-Speech GGUF model.", default="") ttsparsergroup.add_argument("--ttsmodel", metavar=('[filename]'), help="Specify the OuteTTS Text-To-Speech GGUF model.", default="")
ttsparsergroup.add_argument("--ttswavtokenizer", metavar=('[filename]'), help="Specify the WavTokenizer GGUF model.", default="") ttsparsergroup.add_argument("--ttswavtokenizer", metavar=('[filename]'), help="Specify the WavTokenizer GGUF model.", default="")
ttsparsergroup.add_argument("--ttsgpu", help="Use the GPU for TTS.", action='store_true') ttsparsergroup.add_argument("--ttsgpu", help="Use the GPU for TTS.", action='store_true')
ttsparsergroup.add_argument("--ttsmaxlen", help="Limit number of audio tokens generated with TTS.", type=int, default=4096)
ttsparsergroup.add_argument("--ttsthreads", metavar=('[threads]'), help="Use a different number of threads for TTS if specified. Otherwise, has the same value as --threads.", type=int, default=0) ttsparsergroup.add_argument("--ttsthreads", metavar=('[threads]'), help="Use a different number of threads for TTS if specified. Otherwise, has the same value as --threads.", type=int, default=0)
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!') deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')

View file

@ -478,6 +478,7 @@ static int cts_offset = 151672;
static int space_id = 151670; static int space_id = 151670;
static int code_terminate_id = 151670; static int code_terminate_id = 151670;
static int nthreads = 4; static int nthreads = 4;
static int tts_max_len = 4096;
bool ttstype_load_model(const tts_load_model_inputs inputs) bool ttstype_load_model(const tts_load_model_inputs inputs)
{ {
@ -522,6 +523,8 @@ bool ttstype_load_model(const tts_load_model_inputs inputs)
nthreads = inputs.threads; nthreads = inputs.threads;
tts_max_len = inputs.ttsmaxlen;
tts_model_params.use_mmap = false; tts_model_params.use_mmap = false;
tts_model_params.use_mlock = false; tts_model_params.use_mlock = false;
tts_model_params.n_gpu_layers = inputs.gpulayers; //offload if possible tts_model_params.n_gpu_layers = inputs.gpulayers; //offload if possible
@ -871,7 +874,7 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
// main loop // main loop
n_decode = 0; n_decode = 0;
n_predict = 4096; //max 4096 tokens n_predict = tts_max_len; //max 4096 tokens
while (n_decode <= n_predict) while (n_decode <= n_predict)
{ {

View file

@ -824,7 +824,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
for (const auto & file : files) { for (const auto & file : files) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn())); std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
mmaps_used.emplace_back(mapping->size(), 0); mmaps_used.emplace_back(mapping->size(), 0);
if (mlock_mmaps) { if (mlock_mmaps) {
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock()); std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());

View file

@ -1470,8 +1470,13 @@ struct llama_vocab::impl {
std::vector<llama_token> cache_special_tokens; std::vector<llama_token> cache_special_tokens;
std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true); std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
struct pair_hash {
std::map<std::pair<std::string, std::string>, int> bpe_ranks; size_t operator()(const std::pair<std::string, std::string> & p) const {
return std::hash<std::string>{}(p.first) ^ //create some hash for pair
(std::hash<std::string>{}(p.second) << 1);
}
};
std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;
// set of all tokens that cause "end of generation" // set of all tokens that cause "end of generation"
std::set<llama_token> special_eog_ids; std::set<llama_token> special_eog_ids;

View file

@ -8468,13 +8468,141 @@ static enum ggml_status llama_graph_compute(
return status; return status;
} }
static int llama_prepare_sbatch(
llama_context & lctx,
const llama_batch & batch,
uint32_t & n_outputs) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const uint32_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
}
}
GGML_ASSERT_CONTINUE(n_tokens_all <= cparams.n_batch);
//GGML_ASSERT_CONTINUE((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
lctx.n_queued_tokens += n_tokens_all;
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs += batch.logits[i] != 0;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs = n_tokens_all;
} else {
// keep last output only
n_outputs = 1;
}
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !lctx.kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
};
return 0;
}
static int llama_prepare_ubatch(
llama_context & lctx,
llama_kv_slot_restorer & kv_slot_restorer,
llama_ubatch & ubatch,
const uint32_t n_outputs,
const uint32_t n_tokens_all) {
GGML_ASSERT(lctx.sbatch.n_tokens > 0);
auto & kv_self = lctx.kv_self;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (lctx.kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
}
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += int32_t(ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
return 0;
}
// decode a batch of tokens by evaluating the transformer // decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning), // in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state // the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models) // (for non-recurrent models) or cleaned (for recurrent models)
// //
// - lctx: llama context // - lctx: llama context
// - batch: batch to evaluate // - inp_batch: batch to evaluate
// //
// return 0 on success // return 0 on success
// return positive int on warning // return positive int on warning
@ -8491,37 +8619,18 @@ static int llama_decode_impl(
return -1; return -1;
} }
// temporary allocate memory for the input batch if needed // temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;
const auto & model = lctx.model; const auto & model = lctx.model;
const auto & vocab = model.vocab; const auto & vocab = model.vocab;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams; const auto & cparams = lctx.cparams;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
}
}
GGML_ASSERT_CONTINUE(n_tokens_all <= cparams.n_batch);
//GGML_ASSERT_CONTINUE((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) { if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us(); lctx.t_compute_start_us = ggml_time_us();
} }
lctx.n_queued_tokens += n_tokens_all;
auto & kv_self = lctx.kv_self; auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self); llama_kv_slot_restorer kv_slot_restorer(kv_self);
@ -8531,99 +8640,27 @@ static int llama_decode_impl(
uint32_t n_outputs = 0; uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0; uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch; {
const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens if (ret != 0) {
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; return ret;
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs += batch.logits[i] != 0;
} }
} else if (lctx.logits_all || embd_pooled) {
n_outputs = n_tokens_all;
} else {
// keep last output only
n_outputs = 1;
} }
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
};
while (lctx.sbatch.n_tokens > 0) { while (lctx.sbatch.n_tokens > 0) {
llama_ubatch ubatch; llama_ubatch ubatch;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}
const uint32_t n_tokens = ubatch.n_tokens;
// count the outputs in this u_batch
{ {
int32_t n_outputs_new = 0; const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
if (ret != 0) {
if (n_outputs == n_tokens_all) { return ret;
n_outputs_new = n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
} }
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
} }
int n_threads = (n_tokens < 32) ? cparams.n_threads : cparams.n_threads_batch; const int n_threads = ubatch.n_tokens < 32 ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch; ggml_threadpool_t threadpool = ubatch.n_tokens < 32 ? lctx.threadpool : lctx.threadpool_batch;
GGML_ASSERT(n_threads > 0); GGML_ASSERT(n_threads > 0);
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched.get()); ggml_backend_sched_reset(lctx.sched.get());
@ -8676,7 +8713,7 @@ static int llama_decode_impl(
// update the kv ring buffer // update the kv ring buffer
{ {
kv_self.head += n_tokens; kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index. // Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) { if (kv_self.head >= kv_self.size) {