tts can now set a length limit

This commit is contained in:
Concedo 2025-01-28 22:06:59 +08:00
commit 558bc5c901
8 changed files with 235 additions and 119 deletions

View file

@ -214,6 +214,7 @@ struct tts_load_model_inputs
const char * vulkan_info = nullptr; const char * vulkan_info = nullptr;
const int gpulayers = 0; const int gpulayers = 0;
const bool flash_attention = false; const bool flash_attention = false;
const int ttsmaxlen = 4096;
const bool quiet = false; const bool quiet = false;
const int debugmode = 0; const int debugmode = 0;
}; };

View file

@ -46,20 +46,20 @@
#define GGML_CUDA_CC_VOLTA 700 #define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 1000000 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
// GCN/CNDA, wave size is 64 // GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300 #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220 #define GGML_CUDA_CC_QY2 220

View file

@ -121,6 +121,55 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
#endif #endif
} }
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static int ggml_cuda_parse_id(char devName[]) {
// A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
// these values are not stable so this is susceptible to breakage
// https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
int archMajor = 0x0;
int archMinor = 0x0;
int archNum = GGML_CUDA_CC_OFFSET_AMD;
int archLen = strlen(devName);
char archName[archLen + 1];
// strip leading 'gfx' while copying into our buffer
if (archLen > 3) {
strcpy(archName, &devName[3]);
archLen -= 3;
}
// trim trailing :xnack- or :sramecc- statuses
archLen = strcspn(archName, ":");
archName[archLen] = '\0';
// tease out the version information
if (archLen > 8) {
// versions labeled generic use '-' as delimiter
// strip the trailing "-generic" then iterate through what remains
if ((strstr(archName, "-generic"))) {
archName[archLen - 8] = '\0';
char * pch;
if ((pch = strtok(archName, "-"))) {
archMajor = (int)strtoul(pch, 0, 16);
if ((pch = strtok(NULL, "-"))) {
archMinor = 0x10 * (int)strtoul(pch, 0, 16);
}
}
}
} else if (archLen >= 3) {
// last two digits should be the minor * 0x10 + stepping
archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
archName[archLen - 2] = '\0';
// only the major version remains
archMajor = (int)strtoul(archName, 0, 16);
}
archNum += archMajor * 0x100;
archNum += archMinor;
return archNum;
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static ggml_cuda_device_info ggml_cuda_init() { static ggml_cuda_device_info ggml_cuda_init() {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards: // Workaround for a rocBLAS bug when using multiple graphics cards:
@ -172,7 +221,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
cudaDeviceProp prop; cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
info.default_tensor_split[id] = total_vram; info.default_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem; total_vram += prop.totalGlobalMem;
@ -181,10 +229,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].smpb = prop.sharedMemPerBlock;
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpbo = prop.sharedMemPerBlock; info.devices[id].smpbo = prop.sharedMemPerBlock;
info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
if ((info.devices[id].cc & 0xff00) == 0x0) {
GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n",
id, prop.name, prop.gcnArchName, prop.major, prop.minor);
// Fallback to prop.major and prop.minor
if (prop.major > 0) {
info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
info.devices[id].cc += prop.minor * 0x10;
}
}
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no");
#else #else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor; info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
} }

View file

@ -291,6 +291,7 @@ class tts_load_model_inputs(ctypes.Structure):
("vulkan_info", ctypes.c_char_p), ("vulkan_info", ctypes.c_char_p),
("gpulayers", ctypes.c_int), ("gpulayers", ctypes.c_int),
("flash_attention", ctypes.c_bool), ("flash_attention", ctypes.c_bool),
("ttsmaxlen", ctypes.c_int),
("quiet", ctypes.c_bool), ("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)] ("debugmode", ctypes.c_int)]
@ -1451,6 +1452,7 @@ def tts_load_model(ttc_model_filename,cts_model_filename):
if ttst > 0: if ttst > 0:
thds = ttst thds = ttst
inputs.threads = thds inputs.threads = thds
inputs.ttsmaxlen = args.ttsmaxlen if args.ttsmaxlen < 4096 else 4096
inputs = set_backend_props(inputs) inputs = set_backend_props(inputs)
ret = handle.tts_load_model(inputs) ret = handle.tts_load_model(inputs)
return ret return ret
@ -3279,6 +3281,7 @@ def show_gui():
wavtokenizer_var = ctk.StringVar() wavtokenizer_var = ctk.StringVar()
ttsgpu_var = ctk.IntVar(value=0) ttsgpu_var = ctk.IntVar(value=0)
tts_threads_var = ctk.StringVar(value=str(default_threads)) tts_threads_var = ctk.StringVar(value=str(default_threads))
ttsmaxlen_var = ctk.StringVar(value=str(4096))
def tabbuttonaction(name): def tabbuttonaction(name):
for t in tabcontent: for t in tabcontent:
@ -3855,6 +3858,7 @@ def show_gui():
makefileentry(audio_tab, "WavTokenizer Model (Text-To-Speech):", "Select WavTokenizer GGUF Model File", wavtokenizer_var, 7, width=280, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select a WavTokenizer GGUF model file on disk to be loaded for Narration.") makefileentry(audio_tab, "WavTokenizer Model (Text-To-Speech):", "Select WavTokenizer GGUF Model File", wavtokenizer_var, 7, width=280, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select a WavTokenizer GGUF model file on disk to be loaded for Narration.")
wavtokenizer_var.trace("w", gui_changed_modelfile) wavtokenizer_var.trace("w", gui_changed_modelfile)
makecheckbox(audio_tab, "TTS Use GPU", ttsgpu_var, 9, 0,tooltiptxt="Uses the GPU for TTS.") makecheckbox(audio_tab, "TTS Use GPU", ttsgpu_var, 9, 0,tooltiptxt="Uses the GPU for TTS.")
makelabelentry(audio_tab, "OuteTTS Max Tokens:" , ttsmaxlen_var, 11, 50,padx=290,singleline=True,tooltip="Max allowed audiotokens to generate per TTS request.")
ttsgpu_var.trace("w", gui_changed_modelfile) ttsgpu_var.trace("w", gui_changed_modelfile)
def kcpp_export_template(): def kcpp_export_template():
@ -4077,6 +4081,7 @@ def show_gui():
args.ttsmodel = tts_model_var.get() args.ttsmodel = tts_model_var.get()
args.ttswavtokenizer = wavtokenizer_var.get() args.ttswavtokenizer = wavtokenizer_var.get()
args.ttsgpu = (ttsgpu_var.get()==1) args.ttsgpu = (ttsgpu_var.get()==1)
args.ttsmaxlen = int(ttsmaxlen_var.get())
def import_vars(dict): def import_vars(dict):
global importvars_in_progress global importvars_in_progress
@ -4242,6 +4247,7 @@ def show_gui():
tts_model_var.set(dict["ttsmodel"] if ("ttsmodel" in dict and dict["ttsmodel"]) else "") tts_model_var.set(dict["ttsmodel"] if ("ttsmodel" in dict and dict["ttsmodel"]) else "")
wavtokenizer_var.set(dict["ttswavtokenizer"] if ("ttswavtokenizer" in dict and dict["ttswavtokenizer"]) else "") wavtokenizer_var.set(dict["ttswavtokenizer"] if ("ttswavtokenizer" in dict and dict["ttswavtokenizer"]) else "")
ttsgpu_var.set(dict["ttsgpu"] if ("ttsgpu" in dict) else 0) ttsgpu_var.set(dict["ttsgpu"] if ("ttsgpu" in dict) else 0)
ttsmaxlen_var.set(str(dict["ttsmaxlen"]) if ("ttsmaxlen" in dict and dict["ttsmaxlen"]) else str(4096))
importvars_in_progress = False importvars_in_progress = False
gui_changed_modelfile() gui_changed_modelfile()
@ -5646,6 +5652,7 @@ if __name__ == '__main__':
ttsparsergroup.add_argument("--ttsmodel", metavar=('[filename]'), help="Specify the OuteTTS Text-To-Speech GGUF model.", default="") ttsparsergroup.add_argument("--ttsmodel", metavar=('[filename]'), help="Specify the OuteTTS Text-To-Speech GGUF model.", default="")
ttsparsergroup.add_argument("--ttswavtokenizer", metavar=('[filename]'), help="Specify the WavTokenizer GGUF model.", default="") ttsparsergroup.add_argument("--ttswavtokenizer", metavar=('[filename]'), help="Specify the WavTokenizer GGUF model.", default="")
ttsparsergroup.add_argument("--ttsgpu", help="Use the GPU for TTS.", action='store_true') ttsparsergroup.add_argument("--ttsgpu", help="Use the GPU for TTS.", action='store_true')
ttsparsergroup.add_argument("--ttsmaxlen", help="Limit number of audio tokens generated with TTS.", type=int, default=4096)
ttsparsergroup.add_argument("--ttsthreads", metavar=('[threads]'), help="Use a different number of threads for TTS if specified. Otherwise, has the same value as --threads.", type=int, default=0) ttsparsergroup.add_argument("--ttsthreads", metavar=('[threads]'), help="Use a different number of threads for TTS if specified. Otherwise, has the same value as --threads.", type=int, default=0)
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!') deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')

View file

@ -478,6 +478,7 @@ static int cts_offset = 151672;
static int space_id = 151670; static int space_id = 151670;
static int code_terminate_id = 151670; static int code_terminate_id = 151670;
static int nthreads = 4; static int nthreads = 4;
static int tts_max_len = 4096;
bool ttstype_load_model(const tts_load_model_inputs inputs) bool ttstype_load_model(const tts_load_model_inputs inputs)
{ {
@ -522,6 +523,8 @@ bool ttstype_load_model(const tts_load_model_inputs inputs)
nthreads = inputs.threads; nthreads = inputs.threads;
tts_max_len = inputs.ttsmaxlen;
tts_model_params.use_mmap = false; tts_model_params.use_mmap = false;
tts_model_params.use_mlock = false; tts_model_params.use_mlock = false;
tts_model_params.n_gpu_layers = inputs.gpulayers; //offload if possible tts_model_params.n_gpu_layers = inputs.gpulayers; //offload if possible
@ -871,7 +874,7 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
// main loop // main loop
n_decode = 0; n_decode = 0;
n_predict = 4096; //max 4096 tokens n_predict = tts_max_len; //max 4096 tokens
while (n_decode <= n_predict) while (n_decode <= n_predict)
{ {

View file

@ -824,7 +824,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
for (const auto & file : files) { for (const auto & file : files) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn())); std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
mmaps_used.emplace_back(mapping->size(), 0); mmaps_used.emplace_back(mapping->size(), 0);
if (mlock_mmaps) { if (mlock_mmaps) {
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock()); std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());

View file

@ -1470,8 +1470,13 @@ struct llama_vocab::impl {
std::vector<llama_token> cache_special_tokens; std::vector<llama_token> cache_special_tokens;
std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true); std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
struct pair_hash {
std::map<std::pair<std::string, std::string>, int> bpe_ranks; size_t operator()(const std::pair<std::string, std::string> & p) const {
return std::hash<std::string>{}(p.first) ^ //create some hash for pair
(std::hash<std::string>{}(p.second) << 1);
}
};
std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;
// set of all tokens that cause "end of generation" // set of all tokens that cause "end of generation"
std::set<llama_token> special_eog_ids; std::set<llama_token> special_eog_ids;

View file

@ -8468,74 +8468,33 @@ static enum ggml_status llama_graph_compute(
return status; return status;
} }
// decode a batch of tokens by evaluating the transformer static int llama_prepare_sbatch(
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
static int llama_decode_impl(
llama_context & lctx, llama_context & lctx,
llama_batch inp_batch) { const llama_batch & batch,
uint32_t & n_outputs) {
lctx.is_encoding = false;
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporary allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;
const auto & model = lctx.model; const auto & model = lctx.model;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams; const auto & cparams = lctx.cparams;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT const uint32_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) { if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) { for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1; return -1;
} }
} }
} }
GGML_ASSERT_CONTINUE(n_tokens_all <= cparams.n_batch); GGML_ASSERT_CONTINUE(n_tokens_all <= cparams.n_batch);
//GGML_ASSERT_CONTINUE((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); //GGML_ASSERT_CONTINUE((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all; lctx.n_queued_tokens += n_tokens_all;
auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = vocab.n_tokens();
uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
lctx.embd_seq.clear(); lctx.embd_seq.clear();
// count outputs // count outputs
@ -8551,7 +8510,7 @@ static int llama_decode_impl(
} }
lctx.sbatch.from_batch(batch, n_embd, lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent, /* simple_split */ !lctx.kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all); /* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer // reserve output buffer
@ -8560,32 +8519,47 @@ static int llama_decode_impl(
return -2; return -2;
}; };
while (lctx.sbatch.n_tokens > 0) { return 0;
llama_ubatch ubatch; }
if (kv_self.recurrent) {
static int llama_prepare_ubatch(
llama_context & lctx,
llama_kv_slot_restorer & kv_slot_restorer,
llama_ubatch & ubatch,
const uint32_t n_outputs,
const uint32_t n_tokens_all) {
GGML_ASSERT(lctx.sbatch.n_tokens > 0);
auto & kv_self = lctx.kv_self;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (lctx.kv_self.recurrent) {
if (embd_pooled) { if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet) // Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch); ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
} else { } else {
// recurrent model architectures are easier to implement // recurrent model architectures are easier to implement
// with equal-length sequences // with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch); ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
} }
} else { } else {
ubatch = lctx.sbatch.split_simple(n_ubatch); ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
} }
const uint32_t n_tokens = ubatch.n_tokens;
// count the outputs in this u_batch // count the outputs in this u_batch
{ {
int32_t n_outputs_new = 0; int32_t n_outputs_new = 0;
if (n_outputs == n_tokens_all) { if (n_outputs == n_tokens_all) {
n_outputs_new = n_tokens; n_outputs_new = ubatch.n_tokens;
} else { } else {
GGML_ASSERT(ubatch.output); GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < n_tokens; i++) { for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0); n_outputs_new += int32_t(ubatch.output[i] != 0);
} }
} }
@ -8593,18 +8567,13 @@ static int llama_decode_impl(
lctx.n_outputs = n_outputs_new; lctx.n_outputs = n_outputs_new;
} }
int n_threads = (n_tokens < 32) ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
GGML_ASSERT(n_threads > 0);
// non-causal masks do not use the KV cache // non-causal masks do not use the KV cache
if (hparams.causal_attn) { if (hparams.causal_attn) {
llama_kv_cache_update(&lctx); llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head -> // if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it // better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) { if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0; kv_self.head = 0;
} }
@ -8624,6 +8593,74 @@ static int llama_decode_impl(
} }
} }
return 0;
}
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
static int llama_decode_impl(
llama_context & lctx,
llama_batch inp_batch) {
lctx.is_encoding = false;
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
const llama_batch & batch = batch_allocr.batch;
const auto & model = lctx.model;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = vocab.n_tokens();
uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0;
{
const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
if (ret != 0) {
return ret;
}
}
while (lctx.sbatch.n_tokens > 0) {
llama_ubatch ubatch;
{
const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
if (ret != 0) {
return ret;
}
}
const int n_threads = ubatch.n_tokens < 32 ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = ubatch.n_tokens < 32 ? lctx.threadpool : lctx.threadpool_batch;
GGML_ASSERT(n_threads > 0);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched.get()); ggml_backend_sched_reset(lctx.sched.get());
@ -8676,7 +8713,7 @@ static int llama_decode_impl(
// update the kv ring buffer // update the kv ring buffer
{ {
kv_self.head += n_tokens; kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index. // Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) { if (kv_self.head >= kv_self.size) {