Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)

Commit b59b5dbbd1: Merge commit '456af35eb7' into concedo_experimental

# Conflicts:
#   ggml/src/ggml-sycl/getrows.cpp
#   src/CMakeLists.txt
#   tools/llama-bench/llama-bench.cpp

28 changed files with 1403 additions and 496 deletions
Makefile | 2

@@ -675,7 +675,7 @@ embeddings_default.o: otherarch/embeddings_adapter.cpp
 	$(CXX) $(CXXFLAGS) -c $< -o $@

 # idiotic "for easier compilation"
-GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache-unified.cpp src/llama-kv-cache-unified-iswa.cpp src/llama-kv-cache-recurrent.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
+GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache-unified.cpp src/llama-kv-cache-unified-iswa.cpp src/llama-memory-hybrid.cpp src/llama-memory-recurrent.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
 gpttype_adapter_failsafe.o: $(GPTTYPE_ADAPTER)
 	$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) -c $< -o $@
 gpttype_adapter.o: $(GPTTYPE_ADAPTER)
@@ -714,11 +714,17 @@ bool fs_validate_filename(const std::string & filename) {
     // disable C++17 deprecation warning for std::codecvt_utf8
 #    pragma clang diagnostic push
 #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic push
+#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif

     std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;

 #if defined(__clang__)
 #    pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic pop
 #endif

     filename_utf32 = converter.from_bytes(filename);
@@ -6389,8 +6389,8 @@ def parse_args() -> argparse.Namespace:
         help="model is executed on big endian machine",
     )
     parser.add_argument(
-        "model", type=Path,
-        help="directory containing model file",
+        "model", type=str,
+        help="directory containing model file or huggingface repository ID (if --remote)",
         nargs="?",
     )
     parser.add_argument(
@@ -6493,18 +6493,20 @@ def main() -> None:
     else:
         logging.basicConfig(level=logging.INFO)

-    dir_model = args.model
-
     if args.remote:
+        hf_repo_id = args.model
         from huggingface_hub import snapshot_download
         local_dir = snapshot_download(
-            repo_id=str(dir_model),
+            repo_id=hf_repo_id,
             allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
+    else:
+        hf_repo_id = None
+        dir_model = Path(args.model)

     if not dir_model.is_dir():
-        logger.error(f'Error: {args.model} is not a directory')
+        logger.error(f'Error: {dir_model} is not a directory')
         sys.exit(1)

     ftype_map: dict[str, gguf.LlamaFileType] = {
@@ -6524,9 +6526,9 @@ def main() -> None:

     if args.outfile is not None:
         fname_out = args.outfile
-    elif args.remote:
+    elif hf_repo_id:
         # if remote, use the model ID as the output file name
-        fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
+        fname_out = Path("./" + hf_repo_id.replace("/", "-") + "-{ftype}.gguf")
     else:
         fname_out = dir_model

@@ -6555,7 +6557,7 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=str(args.model) if args.remote else None)
+            remote_hf_model_id=hf_repo_id)

     if args.vocab_only:
         logger.info("Exporting model vocab...")
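For reference, the `--remote` handling introduced in the hunks above keeps the raw Hugging Face repository ID (`hf_repo_id`) separate from the local directory the converter actually reads (`dir_model`). A minimal sketch of that resolution step, assuming `huggingface_hub` is installed; the function name is illustrative, not part of the script:

```python
from pathlib import Path

def resolve_model_source(model_arg: str, remote: bool) -> tuple[str | None, Path]:
    """Return (hf_repo_id, dir_model) the way the hunks above do."""
    if remote:
        from huggingface_hub import snapshot_download
        # only config/tokenizer files are fetched; tensor data stays remote
        local_dir = snapshot_download(
            repo_id=model_arg,
            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
        return model_arg, Path(local_dir)
    return None, Path(model_arg)
```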
docs/build-s390x.md | 157 (new file)

> [!IMPORTANT]
> This build documentation is specific only to IBM Z & LinuxONE mainframes (s390x). You can find the build documentation for other architectures in [build.md](build.md).

# Build llama.cpp locally (for s390x)

The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).

The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.

**To get the code:**

```bash
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
```

## CPU Build with BLAS

Building llama.cpp with BLAS support is highly recommended as it has been shown to provide performance improvements.

```bash
cmake -S . -B build \
    -DCMAKE_BUILD_TYPE=Release \
    -DGGML_BLAS=ON \
    -DGGML_BLAS_VENDOR=OpenBLAS

cmake --build build --config Release -j $(nproc)
```

**Notes**:
- For faster repeated compilation, install [ccache](https://ccache.dev/)
- By default, VXE/VXE2 is enabled. To disable it (not recommended):

    ```bash
    cmake -S . -B build \
        -DCMAKE_BUILD_TYPE=Release \
        -DGGML_BLAS=ON \
        -DGGML_BLAS_VENDOR=OpenBLAS \
        -DGGML_VXE=OFF

    cmake --build build --config Release -j $(nproc)
    ```

- For debug builds:

    ```bash
    cmake -S . -B build \
        -DCMAKE_BUILD_TYPE=Debug \
        -DGGML_BLAS=ON \
        -DGGML_BLAS_VENDOR=OpenBLAS

    cmake --build build --config Debug -j $(nproc)
    ```

- For static builds, add `-DBUILD_SHARED_LIBS=OFF`:

    ```bash
    cmake -S . -B build \
        -DCMAKE_BUILD_TYPE=Release \
        -DGGML_BLAS=ON \
        -DGGML_BLAS_VENDOR=OpenBLAS \
        -DBUILD_SHARED_LIBS=OFF

    cmake --build build --config Release -j $(nproc)
    ```

## Getting GGUF Models

All models need to be converted to Big-Endian. You can achieve this in three ways:

1. **Use pre-converted models verified for use on IBM Z & LinuxONE (easiest)**

    You can find popular models pre-converted and verified at [s390x Ready Models](hf.co/collections/taronaeo/s390x-ready-models-672765393af438d0ccb72a08).

    These models and their respective tokenizers are verified to run correctly on IBM Z & LinuxONE.

2. **Convert a safetensors model to GGUF Big-Endian directly (recommended)**

    ```bash
    python3 convert_hf_to_gguf.py \
        --outfile model-name-be.f16.gguf \
        --outtype f16 \
        --bigendian \
        model-directory/
    ```

    For example:

    ```bash
    python3 convert_hf_to_gguf.py \
        --outfile granite-3.3-2b-instruct-be.f16.gguf \
        --outtype f16 \
        --bigendian \
        granite-3.3-2b-instruct/
    ```

3. **Convert an existing GGUF Little-Endian model to Big-Endian**

    ```bash
    python3 gguf-py/gguf/scripts/gguf_convert_endian.py model-name.f16.gguf BIG
    ```

    For example:

    ```bash
    python3 gguf-py/gguf/scripts/gguf_convert_endian.py granite-3.3-2b-instruct-le.f16.gguf BIG
    mv granite-3.3-2b-instruct-le.f16.gguf granite-3.3-2b-instruct-be.f16.gguf
    ```

**Notes:**
- The GGUF endian conversion script may not support all data types at the moment and may fail for some models/quantizations. When that happens, please try manually converting the safetensors model to GGUF Big-Endian via Step 2.

## IBM Accelerators

### 1. SIMD Acceleration

Only available on IBM z15 or later systems with the `-DGGML_VXE=ON` compile flag (turned on by default). No hardware acceleration is possible with llama.cpp on older systems such as IBM z14 or EC13; on those systems the APIs still run but fall back to a scalar implementation.

### 2. zDNN Accelerator

*Only available on IBM z16 or later systems. No direction at the moment.*

### 3. Spyre Accelerator

*No direction at the moment.*

## Performance Tuning

### 1. Virtualization Setup

It is strongly recommended to use only LPAR (Type-1) virtualization to get the most performance.

Note: Type-2 virtualization is not supported at the moment; while you can get it running, the performance will not be the best.

### 2. IFL (Core) Count

It is recommended to allocate a minimum of 8 shared IFLs to the LPAR. Increasing the IFL count past 8 shared IFLs will only improve Prompt Processing performance, not Token Generation.

Note: IFL count does not equate to vCPU count.

### 3. SMT vs NOSMT (Simultaneous Multithreading)

It is strongly recommended to disable SMT via the kernel boot parameters as it negatively affects performance. Please refer to your Linux distribution's guide on disabling SMT via kernel boot parameters.

### 4. BLAS vs NOBLAS

IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongly recommended to use BLAS.

## Getting Help on IBM Z & LinuxONE

1. **Bugs, Feature Requests**

    Please file an issue in llama.cpp and ensure that the title contains "s390x".

2. **Other Questions**

    Please reach out directly to [aionz@us.ibm.com](mailto:aionz@us.ibm.com).
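The endian requirement in the new document is the usual byte-order mismatch: s390x is big-endian, while GGUF files produced on x86/ARM hosts are little-endian. A quick illustration of what the conversion step is undoing (illustration only, not the `gguf_convert_endian.py` implementation):

```python
import struct

value = 3.14159
le = struct.pack("<f", value)  # little-endian bytes, as written on x86/ARM
be = struct.pack(">f", value)  # big-endian bytes, as s390x expects

assert le == be[::-1]          # same value, byte order reversed
assert struct.unpack(">f", be)[0] == struct.unpack("<f", le)[0]
```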
@@ -69,6 +69,9 @@
 #if defined(__clang__)
 #    pragma clang diagnostic push
 #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic push
+#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif

 namespace fs = std::filesystem;
@@ -91,6 +94,8 @@ static std::string path_str(const fs::path & path) {

 #if defined(__clang__)
 #    pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic pop
 #endif

 #ifdef _WIN32
@@ -383,9 +383,9 @@ typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
 typedef uint32_t uint32x4_t __attribute__((vector_size(16)));

 typedef float float32x4_t __attribute__((vector_size(16)));
-typedef double double64x2_t __attribute((vector_size(16)));
+typedef double double64x2_t __attribute__((vector_size(16)));

-typedef signed long long long64x2_t __attribute((vector_size(16)));
+typedef signed long long long64x2_t __attribute__((vector_size(16)));
 typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));

 typedef struct ggml_uint8x16x2_t {
@@ -62,7 +62,7 @@
 #define NOINLINE __attribute__((__noinline__))
 #endif

-#if defined(__ARM_NEON) || defined(__AVX512F__)
+#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
 #define VECTOR_REGISTERS 32
 #else
 #define VECTOR_REGISTERS 16
@@ -109,6 +109,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+#if defined(__VXE__) || defined(__VXE2__)
+inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
+inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
+inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
+#endif
+
 #if defined(__MMA__)
 typedef vector unsigned char vec_t;
 typedef __vector_quad acc_t;
@@ -162,6 +168,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
 #endif
 #endif

+#if defined(__VXE__) || defined(__VXE2__)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+    return vec_madd(a, b, c);
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM

@@ -178,6 +191,13 @@ inline float hsum(float16x8_t x) {
 }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+#if defined(__VXE__) || defined(__VXE2__)
+inline float hsum(float32x4_t x) {
+    float32x4_t tmp = x + vec_reve(x);
+    return tmp[0] + tmp[1];
+}
+#endif
+
 #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 inline float hsum(__m128 x) {
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
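The VXE `hsum` added above uses a reverse-and-add trick: adding a 4-lane vector to its lane-reversed copy (`vec_reve`) leaves every lane holding the sum of a complementary pair, so lanes 0 and 1 together already contain the full horizontal sum. A small sketch of the same idea, with plain Python lists standing in for vector registers:

```python
def hsum4(x):
    rev = x[::-1]                           # vec_reve: reverse the four lanes
    tmp = [a + b for a, b in zip(x, rev)]   # x + vec_reve(x)
    return tmp[0] + tmp[1]                  # (x0+x3) + (x1+x2) = full sum

assert hsum4([1.0, 2.0, 3.0, 4.0]) == 10.0
```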
@@ -227,6 +247,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
 #endif // _MSC_VER
 #endif // __ARM_NEON

+#if defined(__VXE__) || defined(__VXE2__)
+template <> inline float32x4_t load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(p[i]);
+    }
+
+    return vec_xl(0, (const float *)(tmp));
+}
+template <> inline float32x4_t load(const float * p) {
+    return vec_xl(0, p);
+}
+#endif
+
 #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 template <> inline __m128 load(const float *p) {
     return _mm_loadu_ps(p);
@@ -3319,6 +3354,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             (const float *)B, ldb,
             (float *)C, ldc};
         return tb.matmul(m, n);
+#elif defined(__VXE__) || defined(__VXE2__)
+        if (n < 4)
+            return false;
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
             return false;
@@ -3410,6 +3453,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             (float *)C, ldc};
         return tb.matmul(m, n);
     }
+#elif defined(__VXE__) || defined(__VXE2__)
+    if (n < 4)
+        return false;
+    if (Btype == GGML_TYPE_F16) {
+        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+            k, (const ggml_fp16_t *)A, lda,
+            (const ggml_fp16_t *)B, ldb,
+            (float *)C, ldc};
+        return tb.matmul(m, n);
+    }
 #endif
     return false;
 }
@@ -1,6 +1,11 @@
 #pragma once
 #include <stdint.h>
 #include <stdbool.h>

+#if defined(__VXE__) || defined(__VXE2__)
+#include <vecintrin.h>
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -944,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
     for (int i = 0; i < offset; ++i) { \
         x[i] = vec_add(x[i], x[offset + i]); \
     } \
-    res = vec_extract(x[0], 0) + \
-          vec_extract(x[0], 1) + \
-          vec_extract(x[0], 2) + \
-          vec_extract(x[0], 3); \
+    float32x4_t tmp = x[0] + vec_reve(x[0]); \
+    res = tmp[0] + tmp[1]; \
 }

 #define GGML_F32_VEC GGML_F32x4
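The rewritten `GGML_F32x4_REDUCE` above applies the same reverse-and-add reduction shown in the `hsum` sketch earlier, replacing the four separate `vec_extract` lane reads with one vector add and two scalar adds.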
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00);

                 ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
                 };

                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
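The host-side change above also picks the threadgroup width: starting from one SIMD group (32 threads), it doubles until it covers the row or hits the pipeline limit, then clamps to the row length. The same selection logic, restated as a small sketch (names are illustrative):

```python
def pick_threadgroup_width(ne00: int, max_threads: int, simd_width: int = 32) -> int:
    nth = simd_width
    while nth < ne00 and nth < max_threads:
        nth *= 2
    return min(nth, ne00)

assert pick_threadgroup_width(ne00=4096, max_threads=1024) == 1024
assert pick_threadgroup_width(ne00=10, max_threads=1024) == 10
```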
@@ -993,31 +993,61 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }

+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;

     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }

+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);

-    float row_sum = 0;
+    float sumf = 0;

-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }

-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }

+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,
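In the templated kernel above, each thread accumulates a strided slice of the row, `simd_sum` reduces within each SIMD group, the per-group partials pass through threadgroup memory, and a second `simd_sum` produces the row total; the `norm` parameter turns the same code into `kernel_mean` by dividing by `ne00`. A sequential model of that staging (the thread count is illustrative):

```python
def sum_or_mean_row(row, n_threads=32, norm=False):
    # stage 1: per-thread partial sums over a strided slice of the row
    partials = [sum(row[t::n_threads]) for t in range(n_threads)]
    # stage 2: combine the partials (simd_sum + shared-memory pass in the kernel)
    total = sum(partials)
    # norm=True corresponds to the kernel_mean instantiation
    return total / len(row) if norm else total

assert sum_or_mean_row([1.0, 2.0, 3.0, 4.0]) == 10.0
assert sum_or_mean_row([1.0, 2.0, 3.0, 4.0], norm=True) == 2.5
```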
@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },

     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -1816,3 +1817,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    // the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
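The new helpers classify architectures by the kind of memory they need. Restated as a plain lookup (architecture names lower-cased for illustration; this is not the project's API):

```python
# recurrent-state architectures, taken from the switch in llm_arch_is_recurrent above
RECURRENT_ARCHS = {"mamba", "rwkv6", "rwkv6qwen2", "rwkv7", "arwkv7"}

def is_recurrent(arch: str) -> bool:
    return arch in RECURRENT_ARCHS

def is_hybrid(arch: str) -> bool:
    # mirrors llm_arch_is_hybrid: no hybrid (attention + recurrent) architectures yet
    return False
```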
@@ -151,6 +151,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -439,3 +440,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);

 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid (const llm_arch & arch);
@@ -6,7 +6,8 @@

 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"

 #include <cassert>
 #include <cmath>
@@ -238,18 +239,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);

-    const int64_t n_kv = kv_state->get_n_kv();
+    const int64_t n_rs = mem_state->get_n_rs();

     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;

         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
@@ -403,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }

+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->get_state_recr()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -961,23 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }

-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

@@ -1047,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }

+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1291,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
-
-    {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1430,20 +1429,99 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-ggml_tensor * llm_graph_context::build_recurrent_state(
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+
+    {
+        const auto n_kv = kv_state->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t state_size,
         int32_t n_seqs,
+        uint32_t n_kv,
+        uint32_t kv_head,
+        uint32_t kv_size,
+        int32_t rs_zero,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
-
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);

     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1474,22 +1552,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     return output_states;
 }

-ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_rs = kv_state->get_n_rs();
+
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+    ggml_set_input(inp->s_copy);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
+ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
+        llm_graph_input_rs * inp,
         ggml_cgraph * gf,
-        ggml_tensor * state_copy,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);

     const auto token_shift_count = hparams.token_shift_count;

     const int64_t n_seqs = ubatch.n_seqs;

-    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
+    ggml_tensor * token_shift_all = kv_state->get_r_l(il);

-    ggml_tensor * token_shift = build_recurrent_state(
-        gf, token_shift_all, state_copy,
-        hparams.n_embd_k_s(), n_seqs);
+    ggml_tensor * token_shift = build_rs(
+        inp, gf, token_shift_all,
+        hparams.n_embd_r(), n_seqs);

     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1500,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);

     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1512,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
+        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
     );
 }
@@ -21,7 +21,8 @@ struct llama_memory_state_i;

 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
-class llama_kv_cache_recurrent_state;
+class llama_memory_recurrent_state;
+class llama_memory_hybrid_state;

 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -188,16 +189,16 @@ public:
     const llama_cparams & cparams;
 };

-class llm_graph_input_s_copy : public llm_graph_input_i {
+class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_copy() = default;
+    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    virtual ~llm_graph_input_rs() = default;

     void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * s_copy; // I32 [kv_size]

-    const llama_kv_cache_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * mem_state;
 };

 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -300,6 +301,33 @@ public:
     const llama_cross * cross = nullptr;
 };

+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_memory_hybrid_state * mem_state) :
+        hparams(hparams),
+        cparams(cparams),
+        mem_state(mem_state) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_memory_hybrid_state * mem_state;
+};
+
 //
 // llm_graph_result
 //
@@ -508,13 +536,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;

     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -589,21 +618,61 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;

+    ggml_tensor * build_attn(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
     //
     // recurrent
     //

-    ggml_tensor * build_recurrent_state(
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_memory_recurrent.
+    // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    // `llama_memory_recurrent`
+    ggml_tensor * build_rs(
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
+            int32_t state_size,
+            int32_t n_seqs,
+            uint32_t n_kv,
+            uint32_t kv_head,
+            uint32_t kv_size,
+            int32_t rs_zero,
+            bool avoid_copies = false) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
             bool avoid_copies = false) const;

     ggml_tensor * build_rwkv_token_shift_load(
+            llm_graph_input_rs * inp,
             ggml_cgraph * gf,
-            ggml_tensor * state_copy,
             const llama_ubatch & ubatch,
             int il) const;
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }

-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }

-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }

+bool llama_hparams::is_recurrent(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
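The renamed hparams helpers keep the same formulas as before: the rolling state (`n_embd_r`) is `token_shift_count * n_embd` for RWKV or `(d_conv - 1) * d_inner` for Mamba-style SSMs, and the recurrent state (`n_embd_s`) is `n_embd * wkv_head_size` or `d_state * d_inner`. A small sketch of those size computations (field names follow the diff; the concrete values are only examples):

```python
def n_embd_r(wkv_head_size, token_shift_count, n_embd, ssm_d_conv, ssm_d_inner):
    if wkv_head_size != 0:                 # RWKV: token-shift states
        return token_shift_count * n_embd
    return (ssm_d_conv - 1 if ssm_d_conv > 0 else 0) * ssm_d_inner  # Mamba conv states

def n_embd_s(wkv_head_size, n_embd, ssm_d_state, ssm_d_inner):
    if wkv_head_size != 0:                 # RWKV: wkv_states
        return n_embd * wkv_head_size
    return ssm_d_state * ssm_d_inner       # Mamba ssm states

# example: a Mamba-style layer with d_conv=4, d_inner=1536, d_state=16
assert n_embd_r(0, 0, 0, 4, 1536) == 3 * 1536
assert n_embd_s(0, 0, 16, 1536) == 16 * 1536
```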
@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_r() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_s() const;
 
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool is_recurrent(uint32_t il) const;
+
     bool is_swa(uint32_t il) const;
 };
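A minimal, self-contained sketch of how the renamed per-layer helpers are intended to combine (not part of this commit; the function and variable names below are illustrative stand-ins for llama_hparams::is_recurrent(), n_embd_r() and n_embd_s()):

#include <cstddef>
#include <cstdint>
#include <vector>

// illustrative only: count the recurrent-state elements a model needs,
// skipping attention layers, which are sized by the unified KV cache instead
static size_t recurrent_state_elems(const std::vector<bool> & recurrent_layer, // ~ recurrent_layer_arr / is_recurrent(il)
                                    uint32_t n_embd_r, uint32_t n_embd_s, uint32_t mem_size) {
    size_t total = 0;
    for (size_t il = 0; il < recurrent_layer.size(); ++il) {
        if (!recurrent_layer[il]) {
            continue; // attention layer
        }
        total += (size_t) n_embd_r * mem_size; // rolling state (conv / token-shift)
        total += (size_t) n_embd_s * mem_size; // recurrent state (ssm / wkv)
    }
    return total;
}

The llama_memory_recurrent constructor further down in this diff does the real version of this when it sizes the cache_r_l%d / cache_s_l%d tensors.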
@@ -197,21 +197,19 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_full();
-    state_swa  = kv->get_swa ()->init_full();
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        llama_kv_cache_unified_iswa * kv) :
+    state_base(kv->get_base()->init_full()),
+    state_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_update(lctx, optimize);
-    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        bool optimize) :
+    state_base(kv->get_base()->init_update(lctx, optimize)),
+    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
@@ -219,15 +217,13 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches)
-    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+        std::vector<llama_ubatch> ubatches) :
     sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
+    ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
+    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
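The same constructor pattern in miniature, as a hedged illustration with placeholder types (not from this commit): const members are initialized in the member-initializer list, and the combined status is derived there once both parts exist, instead of being assigned in the body.

#include <memory>
#include <utility>

struct part { int status; };

struct combined_state {
    // members are initialized in declaration order (p_a, p_b, status),
    // so status can safely read the two parts it depends on
    combined_state(std::unique_ptr<part> a, std::unique_ptr<part> b) :
        p_a(std::move(a)),
        p_b(std::move(b)),
        status(p_a->status != 0 ? p_a->status : p_b->status) { // "combine" the two statuses
    }

    const std::unique_ptr<part> p_a;
    const std::unique_ptr<part> p_b;
    const int status;
};

The companion header change just below follows the same constraint: status is moved after the two state members so that declaration order matches the order the initializer list relies on.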
@@ -117,8 +117,6 @@ public:
     const llama_kv_cache_unified_state * get_swa() const;
 
 private:
-    llama_memory_status status;
-
     //llama_kv_cache_unified_iswa * kv;
 
     llama_sbatch sbatch;
@@ -128,6 +126,8 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    llama_memory_state_ptr state_base;
-    llama_memory_state_ptr state_swa;
+    const llama_memory_state_ptr state_base;
+    const llama_memory_state_ptr state_swa;
+
+    const llama_memory_status status;
 };
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         const char * dev_name = "CPU";
 
@@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Read type of value
         int32_t v_type_i_ref;
src/llama-memory-hybrid.cpp (new file, 247 lines)
@@ -0,0 +1,247 @@
+#include "llama-memory-hybrid.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+//
+// llama_memory_hybrid
+//
+
+llama_memory_hybrid::llama_memory_hybrid(
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        /* layer filters */
+        layer_filter_cb && filter_attn,
+        layer_filter_cb && filter_recr) :
+    hparams(model.hparams),
+    mem_attn(new llama_kv_cache_unified(
+        model,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !model.hparams.is_recurrent(il); }
+            : filter_attn,
+        type_k,
+        type_v,
+        v_trans,
+        offload,
+        kv_size,
+        n_seq_max,
+        n_pad,
+        n_swa,
+        swa_type
+    )),
+    mem_recr(new llama_memory_recurrent(
+        model,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return model.hparams.is_recurrent(il); }
+            : filter_recr,
+        type_r,
+        type_s,
+        offload,
+        rs_size,
+        n_seq_max
+    )) {}
+
+llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!mem_recr->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = mem_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_memory_hybrid_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_state>(this);
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+}
+
+bool llama_memory_hybrid::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return mem_attn->get_can_shift();
+}
+
+void llama_memory_hybrid::clear(bool data) {
+    mem_attn->clear(data);
+    mem_recr->clear(data);
+}
+
+bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return mem_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
+}
+
+void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    mem_attn->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+}
+
+llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
+}
+
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    mem_attn->state_write(io, seq_id);
+    mem_recr->state_write(io, seq_id);
+}
+
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    mem_attn->state_read(io, seq_id);
+    mem_recr->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+    return mem_attn.get();
+}
+
+llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
+    return mem_recr.get();
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
+    state_attn(mem->get_mem_attn()->init_full()),
+    state_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_context * lctx,
+        bool optimize) :
+    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches) :
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
+    status(LLAMA_MEMORY_STATUS_SUCCESS) {
+}
+
+bool llama_memory_hybrid_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    state_attn->next();
+    state_recr->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_memory_hybrid_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    bool res = true;
+
+    res = res & state_attn->apply();
+    res = res & state_recr->apply();
+
+    return res;
+}
+
+std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_memory_hybrid_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
+    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
+}
+
+const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
+    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
+}
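As a rough, self-contained illustration of the seq_pos_min/seq_pos_max comments above (not from this commit; the struct and values are hypothetical): the hybrid cache can only serve positions present in both halves, so the usable range is the intersection of the attention range and the recurrent range, which is why the minimum takes the max and the maximum takes the min.

#include <algorithm>
#include <cstdio>

struct pos_range { long long pos_min, pos_max; };

// intersection of what the attention part and the recurrent part still hold
static pos_range intersect(const pos_range & attn, const pos_range & recr) {
    return {
        std::max(attn.pos_min, recr.pos_min),
        std::min(attn.pos_max, recr.pos_max),
    };
}

int main() {
    const pos_range attn = {  0, 100 }; // hypothetical attention cache contents
    const pos_range recr = { 32, 100 }; // hypothetical recurrent cache contents (older states dropped)
    const pos_range both = intersect(attn, recr);
    std::printf("usable positions: [%lld, %lld]\n", both.pos_min, both.pos_max);
}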
src/llama-memory-hybrid.h (new file, 143 lines)
@@ -0,0 +1,143 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache-unified.h"
+#include "llama-memory.h"
+#include "llama-memory-recurrent.h"
+
+#include <memory>
+#include <vector>
+
+//
+// llama_memory_hybrid
+//
+
+// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+// support models where each layer may be either attention-based or recurrent
+
+class llama_memory_hybrid : public llama_memory_i {
+public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    llama_memory_hybrid(
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        /* layer filters */
+        layer_filter_cb && filter_attn = nullptr,
+        layer_filter_cb && filter_recr = nullptr);
+
+    ~llama_memory_hybrid() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_state_ptr init_batch(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+
+    //
+    // llama_memory_hybrid specific API
+    //
+
+    llama_kv_cache_unified * get_mem_attn() const;
+    llama_memory_recurrent * get_mem_recr() const;
+
+private:
+    const llama_hparams & hparams;
+
+    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_memory_recurrent> mem_recr;
+};
+
+class llama_memory_hybrid_state : public llama_memory_state_i {
+public:
+    // init failure
+    explicit llama_memory_hybrid_state(llama_memory_status status);
+
+    // init full
+    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
+
+    // init update
+    explicit llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_context * lctx,
+        bool optimize);
+
+    // init success
+    llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches);
+
+    ~llama_memory_hybrid_state() = default;
+
+    bool next() override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_memory_hybrid_state
+    //
+
+    const llama_kv_cache_unified_state * get_state_attn() const;
+    const llama_memory_recurrent_state * get_state_recr() const;
+
+private:
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_state_ptr state_attn;
+    const llama_memory_state_ptr state_recr;
+
+    const llama_memory_status status;
+};
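A small stand-alone sketch of what the two layer filters are meant to express (illustrative only; the defaults installed in llama-memory-hybrid.cpp above do the equivalent split using hparams.is_recurrent()):

#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

using layer_filter_cb = std::function<bool(int32_t il)>;

// partition layers the way the hybrid memory does: attention layers go to the
// unified KV cache, recurrent layers go to llama_memory_recurrent
static std::pair<std::vector<int32_t>, std::vector<int32_t>>
split_layers(const std::vector<bool> & recurrent_layer_arr) { // ~ llama_hparams::recurrent_layer_arr
    layer_filter_cb filter_attn = [&](int32_t il) { return !recurrent_layer_arr[il]; };
    layer_filter_cb filter_recr = [&](int32_t il) { return  recurrent_layer_arr[il]; };

    std::pair<std::vector<int32_t>, std::vector<int32_t>> out;
    for (int32_t il = 0; il < (int32_t) recurrent_layer_arr.size(); ++il) {
        (filter_attn(il) ? out.first : out.second).push_back(il);
    }
    return out; // first: attention-cached layers, second: recurrent-cached layers
}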
@@ -1,4 +1,4 @@
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-recurrent.h"
 
 #include "llama-impl.h"
 #include "llama-io.h"
@@ -12,27 +12,28 @@
 #include <stdexcept>
 
 //
-// llama_kv_cache_recurrent
+// llama_memory_recurrent
 //
 
-llama_kv_cache_recurrent::llama_kv_cache_recurrent(
+llama_memory_recurrent::llama_memory_recurrent(
         const llama_model & model,
-        ggml_type type_k,
-        ggml_type type_v,
+        layer_filter_cb && filter,
+        ggml_type type_r,
+        ggml_type type_s,
         bool offload,
-        uint32_t kv_size,
+        uint32_t mem_size,
         uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
+            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
 
     head = 0;
-    size = kv_size;
+    size = mem_size;
     used = 0;
 
     cells.clear();
-    cells.resize(kv_size);
+    cells.resize(mem_size);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         return it->second;
     };
 
-    k_l.reserve(n_layer);
-    v_l.reserve(n_layer);
+    r_l.resize(n_layer);
+    s_l.resize(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
 
         const char * dev_name = "CPU";
 
@@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
-        k_l.push_back(k);
-        v_l.push_back(v);
+        ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
+        ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
+        ggml_format_name(r, "cache_r_l%d", i);
+        ggml_format_name(s, "cache_s_l%d", i);
+        r_l[i] = r;
+        s_l[i] = s;
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -107,17 +110,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     }
 
     {
-        const size_t memory_size_k = size_k_bytes();
-        const size_t memory_size_v = size_v_bytes();
+        const size_t memory_size_r = size_r_bytes();
+        const size_t memory_size_s = size_s_bytes();
 
-        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+                ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
+                ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
     }
 }
 
-void llama_kv_cache_recurrent::clear(bool data) {
+void llama_memory_recurrent::clear(bool data) {
     for (int32_t i = 0; i < (int32_t) size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
@@ -135,7 +138,7 @@ void llama_kv_cache_recurrent::clear(bool data) {
     }
 }
 
-bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     uint32_t new_head = size;
 
     if (p0 < 0) {
@@ -154,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
     if (0 <= seq_id) {
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            const kv_cell & cell = cells[tail_id];
+            const auto & cell = cells[tail_id];
             // partial intersection is invalid
             if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                 return false;
@@ -202,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
     return true;
 }
 
-void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     if (seq_id_src == seq_id_dst) {
         return;
     }
@@ -216,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
     }
 
     if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
-        kv_cell & tail_src = cells[seq_id_src];
-        kv_cell & tail_dst = cells[seq_id_dst];
+        auto & tail_src = cells[seq_id_src];
+        auto & tail_dst = cells[seq_id_dst];
         if (tail_dst.tail >= 0) {
             // clear destination seq_id if it wasn't empty
-            kv_cell & cell_dst = cells[tail_dst.tail];
+            auto & cell_dst = cells[tail_dst.tail];
 
             cell_dst.seq_id.erase(seq_id_dst);
             tail_dst.tail = -1;
@@ -231,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
             }
         }
         if (tail_src.tail >= 0) {
-            kv_cell & cell_src = cells[tail_src.tail];
+            auto & cell_src = cells[tail_src.tail];
 
             cell_src.seq_id.insert(seq_id_dst);
             tail_dst.tail = tail_src.tail;
@@ -239,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
     }
 }
 
-void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
+void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
     uint32_t new_head = size;
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -271,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
     }
 }
 
-void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     if (shift == 0) {
         return;
     }
@@ -293,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
     if (0 <= seq_id && seq_id < (int64_t) size) {
         const int32_t tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            kv_cell & cell = cells[tail_id];
+            auto & cell = cells[tail_id];
             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                 cell.pos += shift;
             }
@@ -301,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
         }
     }
 }
 
-void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     if (d == 1) {
         return;
     }
@@ -323,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
    if (0 <= seq_id && seq_id < (int64_t) size) {
        const int32_t tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
-            kv_cell & cell = cells[tail_id];
+            auto & cell = cells[tail_id];
            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                cell.pos /= d;
            }
@@ -331,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
     }
 }
 
-llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
     llama_pos result = std::numeric_limits<llama_pos>::max();
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -347,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     llama_pos result = -1;
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -359,7 +362,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
@@ -378,24 +381,24 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
     }
 
     if (!prepare(ubatches)) {
-        return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+llama_memory_state_ptr llama_memory_recurrent::init_full() {
+    return std::make_unique<llama_memory_recurrent_state>(this);
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
     GGML_UNUSED(lctx);
     GGML_UNUSED(optimize);
 
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+    return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
 }
 
-bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
+bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
     auto org_cells = cells;
@@ -419,7 +422,7 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
+bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
@@ -453,9 +456,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
             return false;
         }
         if (j > 0) {
-            kv_cell & seq = cells[seq_id];
+            auto & seq = cells[seq_id];
             if (seq.tail >= 0) {
-                kv_cell & cell = cells[seq.tail];
+                auto & cell = cells[seq.tail];
                 // clear cells from seq_ids that become shared
                 // (should not normally happen, but let's handle it anyway)
                 cell.seq_id.erase(seq_id);
@@ -475,7 +478,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         std::vector<int32_t> tails_verif;
         tails_verif.assign(size, -1);
         for (uint32_t i = 0; i < size; ++i) {
-            kv_cell & cell = cells[i];
+            auto & cell = cells[i];
             for (llama_seq_id seq_id : cell.seq_id) {
                 if (tails_verif[seq_id] != -1) {
                     LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -496,7 +499,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
         for (uint32_t i = 0; i < size; ++i) {
             if (next_empty_cell >= size) { next_empty_cell -= size; }
-            kv_cell & cell = cells[next_empty_cell];
+            auto & cell = cells[next_empty_cell];
             if (cell.is_empty()) { break; }
             next_empty_cell += 1;
         }
@@ -504,20 +507,20 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     // find usable cell range
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_seq_id seq_id = ubatch.seq_id[s][0];
-        kv_cell & seq_meta = cells[seq_id];
+        auto & seq_meta = cells[seq_id];
         bool has_cell = false;
         if (seq_meta.tail >= 0) {
-            kv_cell & cell = cells[seq_meta.tail];
+            auto & cell = cells[seq_meta.tail];
             GGML_ASSERT(cell.has_seq_id(seq_id));
             // does this seq_id "own" the cell?
             if (cell.seq_id.size() == 1) { has_cell = true; }
         }
         if (!has_cell) {
-            kv_cell & empty_cell = cells[next_empty_cell];
+            auto & empty_cell = cells[next_empty_cell];
             GGML_ASSERT(empty_cell.is_empty());
             // copy old tail into the empty cell
             if (seq_meta.tail >= 0) {
-                kv_cell & orig_cell = cells[seq_meta.tail];
+                auto & orig_cell = cells[seq_meta.tail];
                 empty_cell.pos = orig_cell.pos;
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
@@ -530,7 +533,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 for (uint32_t i = 0; i < size; ++i) {
                     next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
-                    kv_cell & cell = cells[next_empty_cell];
+                    auto & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
                 }
             }
@@ -544,8 +547,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         const int32_t dst_id = s + min;
         const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
-            kv_cell & dst_cell = cells[dst_id];
-            kv_cell & src_cell = cells[src_id];
+            auto & dst_cell = cells[dst_id];
+            auto & src_cell = cells[src_id];
 
             std::swap(dst_cell.pos, src_cell.pos);
             std::swap(dst_cell.src, src_cell.src);
@@ -567,7 +570,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
         const int32_t cell_id = s + min;
-        kv_cell & cell = cells[cell_id];
+        auto & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
             // What should happen when the pos backtracks or skips a value?
@@ -620,18 +623,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     head = min;
     n = max - min + 1;
     used = std::count_if(cells.begin(), cells.end(),
-        [](const kv_cell & cell){ return !cell.is_empty(); });
+        [](const mem_cell & cell){ return !cell.is_empty(); });
 
     // sanity check
     return n >= n_seqs;
 }
 
-bool llama_kv_cache_recurrent::get_can_shift() const {
+bool llama_memory_recurrent::get_can_shift() const {
     // shifting the pos is trivial for recurrent models
     return true;
 }
 
-size_t llama_kv_cache_recurrent::total_size() const {
+size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
     for (const auto & buf : bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
@@ -640,27 +643,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
     return size;
 }
 
-size_t llama_kv_cache_recurrent::size_k_bytes() const {
-    size_t size_k_bytes = 0;
+size_t llama_memory_recurrent::size_r_bytes() const {
+    size_t size_r_bytes = 0;
 
-    for (const auto & k : k_l) {
-        size_k_bytes += ggml_nbytes(k);
+    for (const auto & r : r_l) {
+        if (r != nullptr) {
+            size_r_bytes += ggml_nbytes(r);
+        }
     }
 
-    return size_k_bytes;
+    return size_r_bytes;
 }
 
-size_t llama_kv_cache_recurrent::size_v_bytes() const {
-    size_t size_v_bytes = 0;
+size_t llama_memory_recurrent::size_s_bytes() const {
    size_t size_s_bytes = 0;
 
-    for (const auto & v : v_l) {
-        size_v_bytes += ggml_nbytes(v);
+    for (const auto & s : s_l) {
+        if (s != nullptr) {
+            size_s_bytes += ggml_nbytes(s);
+        }
     }
 
-    return size_v_bytes;
+    return size_s_bytes;
 }
 
-void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
 
@@ -698,7 +705,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
     state_write_data(io, cell_ranges);
 }
 
-void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
     uint32_t cell_count;
     io.read_to(&cell_count, sizeof(cell_count));
 
@@ -717,7 +724,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
     }
 }
 
-void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             const auto & cell = cells[i];
@@ -736,11 +743,11 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
     }
 }
 
-void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
-    const uint32_t v_trans = 0;
+void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+    const uint32_t s_trans = 0;
     const uint32_t n_layer = hparams.n_layer;
 
-    io.write(&v_trans, sizeof(v_trans));
+    io.write(&s_trans, sizeof(s_trans));
     io.write(&n_layer, sizeof(n_layer));
 
     std::vector<uint8_t> tmp_buf;
@@ -748,75 +755,73 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Write key type
-        const int32_t k_type_i = (int32_t)k_l[il]->type;
-        io.write(&k_type_i, sizeof(k_type_i));
+        const int32_t r_type_i = (int32_t)r_l[il]->type;
+        io.write(&r_type_i, sizeof(r_type_i));
 
         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        io.write(&k_size_row, sizeof(k_size_row));
+        const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+        io.write(&r_size_row, sizeof(r_size_row));
 
         // Read each range of cells of k_size length each into tmp_buf and write out
         for (const auto & range : cell_ranges) {
             const size_t range_size = range.second - range.first;
-            const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
+            const size_t buf_size = range_size * r_size_row;
+            io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
         }
     }
 
-    if (!v_trans) {
+    if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            io.write(&v_type_i, sizeof(v_type_i));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            io.write(&s_type_i, sizeof(s_type_i));
 
             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
-            io.write(&v_size_row, sizeof(v_size_row));
+            const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+            io.write(&s_size_row, sizeof(s_size_row));
 
-            // Read each range of cells of v_size length each into tmp_buf and write out
+            // Read each range of cells of s_size length each into tmp_buf and write out
             for (const auto & range : cell_ranges) {
                 const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+                const size_t buf_size = range_size * s_size_row;
+                io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
             }
         }
     } else {
         // When v is transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = size;
+        const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Write value type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            io.write(&v_type_i, sizeof(v_type_i));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            io.write(&s_type_i, sizeof(s_type_i));
 
             // Write element size
-            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
-            io.write(&v_size_el, sizeof(v_size_el));
+            const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
+            io.write(&s_size_el, sizeof(s_size_el));
 
             // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+            io.write(&n_embd_s, sizeof(n_embd_s));
 
             // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            for (uint32_t j = 0; j < n_embd_s; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(v_l[il], src_offset, buf_size);
+                    const size_t src_offset = (range.first + j * mem_size) * s_size_el;
+                    const size_t buf_size = range_size * s_size_el;
+                    io.write_tensor(s_l[il], src_offset, buf_size);
                 }
             }
         }
     }
 }
 
-bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence
 
@@ -869,7 +874,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
         clear(true);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            kv_cell & cell = cells[i];
+            auto & cell = cells[i];
 
             llama_pos pos;
             uint32_t n_seq_id;
@@ -883,7 +888,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
                 llama_seq_id seq_id;
                 io.read_to(&seq_id, sizeof(seq_id));
 
-                // TODO: llama_kv_cache_recurrent should have a notion of max sequences
+                // TODO: llama_memory_recurrent should have a notion of max sequences
                 //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
                 if (seq_id < 0) {
                     //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -915,10 +920,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
     return true;
 }
 
-bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
-    uint32_t v_trans;
+bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    uint32_t s_trans;
     uint32_t n_layer;
-    io.read_to(&v_trans, sizeof(v_trans));
+    io.read_to(&s_trans, sizeof(s_trans));
     io.read_to(&n_layer, sizeof(n_layer));
 
     if (n_layer != hparams.n_layer) {
@@ -929,102 +934,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
         return false;
     }
-    if (false != (bool) v_trans) {
-        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+    if (false != (bool) s_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
         return false;
     }
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Read type of key
-        int32_t k_type_i_ref;
-        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) k_l[il]->type;
-        if (k_type_i != k_type_i_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+        int32_t r_type_i_ref;
+        io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
+        const int32_t r_type_i = (int32_t) r_l[il]->type;
+        if (r_type_i != r_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
            return false;
        }
 
        // Read row size of key
-        uint64_t k_size_row_ref;
-        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
+        uint64_t r_size_row_ref;
+        io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
+        const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+        if (r_size_row != r_size_row_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell_count) {
|
if (cell_count) {
|
||||||
// Read and set the keys for the whole cell range
|
// Read and set the keys for the whole cell range
|
||||||
ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!v_trans) {
|
if (!s_trans) {
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t s_type_i_ref;
|
||||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||||
if (v_type_i != v_type_i_ref) {
|
if (s_type_i != s_type_i_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read row size of value
|
// Read row size of value
|
||||||
uint64_t v_size_row_ref;
|
uint64_t s_size_row_ref;
|
||||||
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||||
const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||||
if (v_size_row != v_size_row_ref) {
|
if (s_size_row != s_size_row_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell_count) {
|
if (cell_count) {
|
||||||
// Read and set the values for the whole cell range
|
// Read and set the values for the whole cell range
|
||||||
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For each layer, read the values for each cell (transposed)
|
// For each layer, read the values for each cell (transposed)
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
const uint32_t n_embd_s = hparams.n_embd_s();
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t s_type_i_ref;
|
||||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||||
if (v_type_i != v_type_i_ref) {
|
if (s_type_i != s_type_i_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read element size of value
|
// Read element size of value
|
||||||
uint32_t v_size_el_ref;
|
uint32_t s_size_el_ref;
|
||||||
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||||
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||||
if (v_size_el != v_size_el_ref) {
|
if (s_size_el != s_size_el_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read GQA embedding size
|
// Read state embedding size
|
||||||
uint32_t n_embd_v_gqa_ref;
|
uint32_t n_embd_s_ref;
|
||||||
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
if (n_embd_s != n_embd_s_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell_count) {
|
if (cell_count) {
|
||||||
// For each row in the transposed matrix, read the values for the whole cell range
|
// For each row in the transposed matrix, read the values for the whole cell range
|
||||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||||
const size_t dst_offset = (head + j * size) * v_size_el;
|
const size_t dst_offset = (head + j * size) * s_size_el;
|
||||||
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1034,25 +1037,23 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_recurrent_state
|
// llama_memory_recurrent_state
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
|
llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
|
llama_memory_recurrent_state::llama_memory_recurrent_state(
|
||||||
llama_memory_status status,
|
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
|
||||||
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
|
llama_memory_recurrent_state::llama_memory_recurrent_state(
|
||||||
llama_memory_status status,
|
llama_memory_recurrent * mem,
|
||||||
llama_kv_cache_recurrent * kv,
|
|
||||||
llama_sbatch sbatch,
|
llama_sbatch sbatch,
|
||||||
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
||||||
|
|
||||||
llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
|
llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
|
||||||
|
|
||||||
bool llama_kv_cache_recurrent_state::next() {
|
bool llama_memory_recurrent_state::next() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
if (++i_next >= ubatches.size()) {
|
if (++i_next >= ubatches.size()) {
|
||||||
|
@ -1062,54 +1063,54 @@ bool llama_kv_cache_recurrent_state::next() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_recurrent_state::apply() {
|
bool llama_memory_recurrent_state::apply() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
kv->find_slot(ubatches[i_next]);
|
mem->find_slot(ubatches[i_next]);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
|
std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return sbatch.out_ids;
|
return sbatch.out_ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
|
llama_memory_status llama_memory_recurrent_state::get_status() const {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
|
const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return ubatches[i_next];
|
return ubatches[i_next];
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
|
uint32_t llama_memory_recurrent_state::get_n_rs() const {
|
||||||
return is_full ? kv->size : kv->n;
|
return is_full ? mem->size : mem->n;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_recurrent_state::get_head() const {
|
uint32_t llama_memory_recurrent_state::get_head() const {
|
||||||
return is_full ? 0 : kv->head;
|
return is_full ? 0 : mem->head;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
|
int32_t llama_memory_recurrent_state::get_rs_z() const {
|
||||||
return is_full ? 0 : kv->rs_z;
|
return is_full ? 0 : mem->rs_z;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_recurrent_state::get_size() const {
|
uint32_t llama_memory_recurrent_state::get_size() const {
|
||||||
return kv->size;
|
return mem->size;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
|
ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
|
||||||
return kv->k_l[il];
|
return mem->r_l[il];
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
|
ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
|
||||||
return kv->v_l[il];
|
return mem->s_l[il];
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
|
int32_t llama_memory_recurrent_state::s_copy(int i) const {
|
||||||
return kv->cells[i + kv->head].src0;
|
return mem->cells[i + mem->head].src0;
|
||||||
}
|
}
|
|
@ -8,22 +8,27 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_recurrent
|
// llama_memory_recurrent
|
||||||
//
|
//
|
||||||
|
|
||||||
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
|
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
|
||||||
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
|
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
|
||||||
class llama_kv_cache_recurrent : public llama_memory_i {
|
class llama_memory_recurrent : public llama_memory_i {
|
||||||
public:
|
public:
|
||||||
llama_kv_cache_recurrent(
|
|
||||||
|
// this callback is used to filter out layers that should not be included in the cache
|
||||||
|
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||||
|
|
||||||
|
llama_memory_recurrent(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
ggml_type type_k,
|
layer_filter_cb && filter,
|
||||||
ggml_type type_v,
|
ggml_type type_r,
|
||||||
|
ggml_type type_s,
|
||||||
bool offload,
|
bool offload,
|
||||||
uint32_t kv_size,
|
uint32_t mem_size,
|
||||||
uint32_t n_seq_max);
|
uint32_t n_seq_max);
|
||||||
|
|
||||||
~llama_kv_cache_recurrent() = default;
|
~llama_memory_recurrent() = default;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_i
|
// llama_memory_i
|
||||||
|
@ -51,7 +56,7 @@ public:
|
||||||
|
|
||||||
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
||||||
|
|
||||||
// find a contiguous slot of kv cells and emplace the ubatch there
|
// find a contiguous slot of memory cells and emplace the ubatch there
|
||||||
bool find_slot(const llama_ubatch & ubatch);
|
bool find_slot(const llama_ubatch & ubatch);
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
@ -72,7 +77,7 @@ public:
|
||||||
int32_t rs_z = -1;
|
int32_t rs_z = -1;
|
||||||
|
|
||||||
// TODO: optimize for recurrent state needs
|
// TODO: optimize for recurrent state needs
|
||||||
struct kv_cell {
|
struct mem_cell {
|
||||||
llama_pos pos = -1;
|
llama_pos pos = -1;
|
||||||
int32_t src = -1; // used to know where states should be copied from
|
int32_t src = -1; // used to know where states should be copied from
|
||||||
int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
|
int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
|
||||||
|
@ -88,15 +93,16 @@ public:
|
||||||
return seq_id.empty();
|
return seq_id.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_same_seq(const kv_cell & other) const {
|
bool is_same_seq(const mem_cell & other) const {
|
||||||
return seq_id == other.seq_id;
|
return seq_id == other.seq_id;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<kv_cell> cells;
|
std::vector<mem_cell> cells;
|
||||||
|
|
||||||
std::vector<ggml_tensor *> k_l; // per layer
|
// per layer
|
||||||
std::vector<ggml_tensor *> v_l;
|
std::vector<ggml_tensor *> r_l;
|
||||||
|
std::vector<ggml_tensor *> s_l;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//const llama_model & model;
|
//const llama_model & model;
|
||||||
|
@ -109,8 +115,8 @@ private:
|
||||||
|
|
||||||
size_t total_size() const;
|
size_t total_size() const;
|
||||||
|
|
||||||
size_t size_k_bytes() const;
|
size_t size_r_bytes() const;
|
||||||
size_t size_v_bytes() const;
|
size_t size_s_bytes() const;
|
||||||
|
|
||||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||||
|
@ -119,24 +125,22 @@ private:
|
||||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||||
};
|
};
|
||||||
|
|
||||||
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
|
class llama_memory_recurrent_state : public llama_memory_state_i {
|
||||||
public:
|
public:
|
||||||
// used for errors
|
// used for errors
|
||||||
llama_kv_cache_recurrent_state(llama_memory_status status);
|
llama_memory_recurrent_state(llama_memory_status status);
|
||||||
|
|
||||||
// used to create a full-cache state
|
// used to create a full-cache state
|
||||||
llama_kv_cache_recurrent_state(
|
llama_memory_recurrent_state(
|
||||||
llama_memory_status status,
|
llama_memory_recurrent * mem);
|
||||||
llama_kv_cache_recurrent * kv);
|
|
||||||
|
|
||||||
// used to create a state from a batch
|
// used to create a state from a batch
|
||||||
llama_kv_cache_recurrent_state(
|
llama_memory_recurrent_state(
|
||||||
llama_memory_status status,
|
llama_memory_recurrent * mem,
|
||||||
llama_kv_cache_recurrent * kv,
|
|
||||||
llama_sbatch sbatch,
|
llama_sbatch sbatch,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
virtual ~llama_kv_cache_recurrent_state();
|
virtual ~llama_memory_recurrent_state();
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_state_i
|
// llama_memory_state_i
|
||||||
|
@ -151,23 +155,23 @@ public:
|
||||||
const llama_ubatch & get_ubatch() const override;
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_recurrent_state specific API
|
// llama_memory_recurrent_state specific API
|
||||||
//
|
//
|
||||||
|
|
||||||
uint32_t get_n_kv() const;
|
uint32_t get_n_rs() const;
|
||||||
uint32_t get_head() const;
|
uint32_t get_head() const;
|
||||||
int32_t get_rs_z() const;
|
int32_t get_rs_z() const;
|
||||||
uint32_t get_size() const;
|
uint32_t get_size() const;
|
||||||
|
|
||||||
ggml_tensor * get_k_l(int32_t il) const;
|
ggml_tensor * get_r_l(int32_t il) const;
|
||||||
ggml_tensor * get_v_l(int32_t il) const;
|
ggml_tensor * get_s_l(int32_t il) const;
|
||||||
|
|
||||||
int32_t s_copy(int i) const;
|
int32_t s_copy(int i) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const llama_memory_status status;
|
const llama_memory_status status;
|
||||||
|
|
||||||
llama_kv_cache_recurrent * kv;
|
llama_memory_recurrent * mem;
|
||||||
|
|
||||||
llama_sbatch sbatch;
|
llama_sbatch sbatch;
|
||||||
|
|
|
@ -8,7 +8,8 @@
|
||||||
|
|
||||||
#include "llama-kv-cache-unified.h"
|
#include "llama-kv-cache-unified.h"
|
||||||
#include "llama-kv-cache-unified-iswa.h"
|
#include "llama-kv-cache-unified-iswa.h"
|
||||||
#include "llama-kv-cache-recurrent.h"
|
#include "llama-memory-hybrid.h"
|
||||||
|
#include "llama-memory-recurrent.h"
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
#include "ggml-cpp.h"
|
||||||
|
|
||||||
|
@ -475,6 +476,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
|
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
|
||||||
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
|
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
|
||||||
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
|
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
|
||||||
|
std::fill(
|
||||||
|
hparams.recurrent_layer_arr.begin(),
|
||||||
|
hparams.recurrent_layer_arr.end(),
|
||||||
|
llm_arch_is_recurrent(ml.get_arch()));
|
||||||
|
|
||||||
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
|
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
|
||||||
|
|
||||||
|
@ -9211,7 +9216,7 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
// {n_embd, n_tokens}
|
// {n_embd, n_tokens}
|
||||||
inpL = build_inp_embd(model.tok_embd);
|
inpL = build_inp_embd(model.tok_embd);
|
||||||
|
|
||||||
ggml_tensor * state_copy = build_inp_s_copy();
|
auto * rs_inp = build_rs_inp();
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
// norm
|
// norm
|
||||||
|
@ -9220,7 +9225,7 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
LLM_NORM_RMS, il);
|
LLM_NORM_RMS, il);
|
||||||
cb(cur, "attn_norm", il);
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
|
cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
// skip computing output for unused tokens
|
// skip computing output for unused tokens
|
||||||
|
@ -9258,12 +9263,12 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
|
|
||||||
// TODO: split
|
// TODO: split
|
||||||
ggml_tensor * build_mamba_layer(
|
ggml_tensor * build_mamba_layer(
|
||||||
|
llm_graph_input_rs * inp,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
ggml_tensor * cur,
|
ggml_tensor * cur,
|
||||||
ggml_tensor * state_copy,
|
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||||
|
|
||||||
const auto kv_head = kv_state->get_head();
|
const auto kv_head = kv_state->get_head();
|
||||||
|
|
||||||
|
@ -9283,17 +9288,17 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
GGML_ASSERT(ubatch.equal_seqs);
|
GGML_ASSERT(ubatch.equal_seqs);
|
||||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||||
|
|
||||||
ggml_tensor * conv_states_all = kv_state->get_k_l(il);
|
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
|
||||||
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
|
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
|
||||||
|
|
||||||
// (ab)using the KV cache to store the states
|
// (ab)using the KV cache to store the states
|
||||||
ggml_tensor * conv = build_recurrent_state(
|
ggml_tensor * conv = build_rs(
|
||||||
gf, conv_states_all, state_copy,
|
inp, gf, conv_states_all,
|
||||||
hparams.n_embd_k_s(), n_seqs);
|
hparams.n_embd_r(), n_seqs);
|
||||||
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
||||||
ggml_tensor * ssm = build_recurrent_state(
|
ggml_tensor * ssm = build_rs(
|
||||||
gf, ssm_states_all, state_copy,
|
inp, gf, ssm_states_all,
|
||||||
hparams.n_embd_v_s(), n_seqs);
|
hparams.n_embd_s(), n_seqs);
|
||||||
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
|
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
|
||||||
|
|
||||||
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||||
|
@ -12004,13 +12009,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * build_rwkv6_time_mix(
|
ggml_tensor * build_rwkv6_time_mix(
|
||||||
|
llm_graph_input_rs * inp,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
ggml_tensor * cur,
|
ggml_tensor * cur,
|
||||||
ggml_tensor * x_prev,
|
ggml_tensor * x_prev,
|
||||||
ggml_tensor * state_copy,
|
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||||
|
|
||||||
const auto n_tokens = ubatch.n_tokens;
|
const auto n_tokens = ubatch.n_tokens;
|
||||||
const auto n_seqs = ubatch.n_seqs;
|
const auto n_seqs = ubatch.n_seqs;
|
||||||
|
@ -12131,9 +12136,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||||
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
|
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * wkv_state = build_recurrent_state(
|
ggml_tensor * wkv_state = build_rs(
|
||||||
gf, kv_state->get_v_l(il), state_copy,
|
inp, gf, kv_state->get_s_l(il),
|
||||||
hparams.n_embd_v_s(), n_seqs);
|
hparams.n_embd_s(), n_seqs);
|
||||||
|
|
||||||
ggml_tensor * wkv_output;
|
ggml_tensor * wkv_output;
|
||||||
if (is_qrwkv) {
|
if (is_qrwkv) {
|
||||||
|
@ -12151,9 +12156,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||||
wkv_state,
|
wkv_state,
|
||||||
ggml_view_1d(
|
ggml_view_1d(
|
||||||
ctx0,
|
ctx0,
|
||||||
kv_state->get_v_l(il),
|
kv_state->get_s_l(il),
|
||||||
hparams.n_embd_v_s() * n_seqs,
|
hparams.n_embd_s() * n_seqs,
|
||||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
|
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
@ -12187,7 +12192,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||||
inpL = build_inp_embd(model.tok_embd);
|
inpL = build_inp_embd(model.tok_embd);
|
||||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||||
|
|
||||||
ggml_tensor * state_copy = build_inp_s_copy();
|
auto * rs_inp = build_rs_inp();
|
||||||
|
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
@ -12197,9 +12202,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||||
const llama_layer * layer = &model.layers[il];
|
const llama_layer * layer = &model.layers[il];
|
||||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||||
|
|
||||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||||
gf, state_copy, ubatch, il
|
|
||||||
);
|
|
||||||
|
|
||||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||||
|
@ -12214,7 +12217,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||||
1
|
1
|
||||||
);
|
);
|
||||||
|
|
||||||
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
|
||||||
|
|
||||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||||
cb(ffn_inp, "ffn_inp", il);
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
@ -12277,14 +12280,14 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||||
// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
|
// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
|
||||||
struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||||
llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
|
llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
|
||||||
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
|
GGML_ASSERT(n_embd == hparams.n_embd_r());
|
||||||
|
|
||||||
ggml_tensor * cur;
|
ggml_tensor * cur;
|
||||||
ggml_tensor * inpL;
|
ggml_tensor * inpL;
|
||||||
|
|
||||||
inpL = build_inp_embd(model.tok_embd);
|
inpL = build_inp_embd(model.tok_embd);
|
||||||
|
|
||||||
ggml_tensor * state_copy = build_inp_s_copy();
|
auto * rs_inp = build_rs_inp();
|
||||||
|
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
@ -12294,9 +12297,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||||
const llama_layer * layer = &model.layers[il];
|
const llama_layer * layer = &model.layers[il];
|
||||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||||
|
|
||||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||||
gf, state_copy, ubatch, il
|
|
||||||
);
|
|
||||||
|
|
||||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||||
cb(att_norm, "attn_norm", il);
|
cb(att_norm, "attn_norm", il);
|
||||||
|
@ -12308,7 +12309,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||||
1
|
1
|
||||||
);
|
);
|
||||||
|
|
||||||
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
|
||||||
|
|
||||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||||
|
@ -12396,14 +12397,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * build_rwkv7_time_mix(
|
ggml_tensor * build_rwkv7_time_mix(
|
||||||
|
llm_graph_input_rs * inp,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
ggml_tensor * cur,
|
ggml_tensor * cur,
|
||||||
ggml_tensor * x_prev,
|
ggml_tensor * x_prev,
|
||||||
ggml_tensor * state_copy,
|
|
||||||
ggml_tensor *& first_layer_value,
|
ggml_tensor *& first_layer_value,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||||
|
|
||||||
const auto n_tokens = ubatch.n_tokens;
|
const auto n_tokens = ubatch.n_tokens;
|
||||||
const auto n_seqs = ubatch.n_seqs;
|
const auto n_seqs = ubatch.n_seqs;
|
||||||
|
@ -12482,9 +12483,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||||
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
||||||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||||
|
|
||||||
ggml_tensor * wkv_state = build_recurrent_state(
|
ggml_tensor * wkv_state = build_rs(
|
||||||
gf, kv_state->get_v_l(il), state_copy,
|
inp, gf, kv_state->get_s_l(il),
|
||||||
hparams.n_embd_v_s(), n_seqs);
|
hparams.n_embd_s(), n_seqs);
|
||||||
|
|
||||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||||
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
|
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
|
||||||
|
@ -12497,9 +12498,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||||
wkv_state,
|
wkv_state,
|
||||||
ggml_view_1d(
|
ggml_view_1d(
|
||||||
ctx0,
|
ctx0,
|
||||||
kv_state->get_v_l(il),
|
kv_state->get_s_l(il),
|
||||||
hparams.n_embd_v_s() * n_seqs,
|
hparams.n_embd_s() * n_seqs,
|
||||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
|
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
@ -12540,7 +12541,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
||||||
inpL = build_inp_embd(model.tok_embd);
|
inpL = build_inp_embd(model.tok_embd);
|
||||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||||
|
|
||||||
ggml_tensor * state_copy = build_inp_s_copy();
|
auto * rs_inp = build_rs_inp();
|
||||||
|
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
@ -12550,9 +12551,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
||||||
const llama_layer * layer = &model.layers[il];
|
const llama_layer * layer = &model.layers[il];
|
||||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||||
|
|
||||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||||
gf, state_copy, ubatch, il
|
|
||||||
);
|
|
||||||
|
|
||||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||||
|
@ -12567,7 +12566,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
||||||
1
|
1
|
||||||
);
|
);
|
||||||
|
|
||||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
|
||||||
|
|
||||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||||
cb(ffn_inp, "ffn_inp", il);
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
@ -12625,7 +12624,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
||||||
|
|
||||||
struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
||||||
llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
|
llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
|
||||||
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
|
GGML_ASSERT(n_embd == hparams.n_embd_r());
|
||||||
|
|
||||||
ggml_tensor * cur;
|
ggml_tensor * cur;
|
||||||
ggml_tensor * inpL;
|
ggml_tensor * inpL;
|
||||||
|
@ -12633,7 +12632,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
||||||
|
|
||||||
inpL = build_inp_embd(model.tok_embd);
|
inpL = build_inp_embd(model.tok_embd);
|
||||||
|
|
||||||
ggml_tensor * state_copy = build_inp_s_copy();
|
auto * rs_inp = build_rs_inp();
|
||||||
|
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
@ -12643,9 +12642,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
||||||
const llama_layer * layer = &model.layers[il];
|
const llama_layer * layer = &model.layers[il];
|
||||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||||
|
|
||||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||||
gf, state_copy, ubatch, il
|
|
||||||
);
|
|
||||||
|
|
||||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||||
cb(att_norm, "attn_norm", il);
|
cb(att_norm, "attn_norm", il);
|
||||||
|
@ -12657,7 +12654,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
||||||
1
|
1
|
||||||
);
|
);
|
||||||
|
|
||||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
|
||||||
|
|
||||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||||
|
@ -13838,6 +13835,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||||
llama_memory_i * res;
|
llama_memory_i * res;
|
||||||
|
|
||||||
switch (arch) {
|
switch (arch) {
|
||||||
|
// Models that need specific instantiation should be handled in the
|
||||||
|
// switch statement
|
||||||
case LLM_ARCH_BERT:
|
case LLM_ARCH_BERT:
|
||||||
case LLM_ARCH_JINA_BERT_V2:
|
case LLM_ARCH_JINA_BERT_V2:
|
||||||
case LLM_ARCH_NOMIC_BERT:
|
case LLM_ARCH_NOMIC_BERT:
|
||||||
|
@ -13847,22 +13846,39 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||||
{
|
{
|
||||||
res = nullptr;
|
res = nullptr;
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_MAMBA:
|
// Models that need standard caching should rely on recurrent/hybrid
|
||||||
case LLM_ARCH_RWKV6:
|
// checks
|
||||||
case LLM_ARCH_RWKV6QWEN2:
|
default:
|
||||||
case LLM_ARCH_RWKV7:
|
|
||||||
case LLM_ARCH_ARWKV7:
|
|
||||||
{
|
{
|
||||||
res = new llama_kv_cache_recurrent(
|
if (llm_arch_is_recurrent(arch)) {
|
||||||
|
res = new llama_memory_recurrent(
|
||||||
*this,
|
*this,
|
||||||
|
nullptr,
|
||||||
GGML_TYPE_F32,
|
GGML_TYPE_F32,
|
||||||
GGML_TYPE_F32,
|
GGML_TYPE_F32,
|
||||||
cparams.offload_kqv,
|
cparams.offload_kqv,
|
||||||
std::max((uint32_t) 1, cparams.n_seq_max),
|
std::max((uint32_t) 1, cparams.n_seq_max),
|
||||||
cparams.n_seq_max);
|
cparams.n_seq_max);
|
||||||
} break;
|
} else if (llm_arch_is_hybrid(arch)) {
|
||||||
default:
|
const auto padding = llama_kv_cache_unified::get_padding(cparams);
|
||||||
{
|
|
||||||
|
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||||
|
|
||||||
|
res = new llama_memory_hybrid(
|
||||||
|
/* model */ *this,
|
||||||
|
/* attn_type_k */ params.type_k,
|
||||||
|
/* attn_type_v */ params.type_v,
|
||||||
|
/* attn_v_trans */ !cparams.flash_attn,
|
||||||
|
/* attn_kv_size */ cparams.n_ctx,
|
||||||
|
/* attn_n_pad */ padding,
|
||||||
|
/* attn_n_swa */ hparams.n_swa,
|
||||||
|
/* attn_swa_type */ hparams.swa_type,
|
||||||
|
/* recurrent_type_k */ GGML_TYPE_F32,
|
||||||
|
/* recurrent_type_v */ GGML_TYPE_F32,
|
||||||
|
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||||
|
/* n_seq_max */ cparams.n_seq_max,
|
||||||
|
/* offload */ cparams.offload_kqv);
|
||||||
|
} else {
|
||||||
const auto padding = llama_kv_cache_unified::get_padding(cparams);
|
const auto padding = llama_kv_cache_unified::get_padding(cparams);
|
||||||
|
|
||||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||||
|
@ -13901,6 +13917,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
@ -14477,14 +14494,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_model_is_recurrent(const llama_model * model) {
|
bool llama_model_is_recurrent(const llama_model * model) {
|
||||||
switch (model->arch) {
|
return llm_arch_is_recurrent(model->arch);
|
||||||
case LLM_ARCH_MAMBA: return true;
|
|
||||||
case LLM_ARCH_RWKV6: return true;
|
|
||||||
case LLM_ARCH_RWKV6QWEN2: return true;
|
|
||||||
case LLM_ARCH_RWKV7: return true;
|
|
||||||
case LLM_ARCH_ARWKV7: return true;
|
|
||||||
default: return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
|
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
|
||||||
|
|
|
@ -2299,9 +2299,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||||
//NOTE: Per token attributes are missing from the GGUF file.
|
//NOTE: Per token attributes are missing from the GGUF file.
|
||||||
//TODO: Extract attributes from GGUF file.
|
//TODO: Extract attributes from GGUF file.
|
||||||
{
|
{
|
||||||
auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool {
|
auto _contains_any = [] (const std::string & str, const std::vector<std::string_view> & substrs) -> bool {
|
||||||
for (const auto & substr : substrs) {
|
for (const auto & substr : substrs) {
|
||||||
if (str.find(substr) < std::string::npos) {
|
if (str.find(substr) != std::string::npos) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,8 @@ static bool old_mixtral_warning_showed = false;
|
||||||
#include "llama-sampling.cpp"
|
#include "llama-sampling.cpp"
|
||||||
#include "llama-kv-cache-unified.cpp"
|
#include "llama-kv-cache-unified.cpp"
|
||||||
#include "llama-kv-cache-unified-iswa.cpp"
|
#include "llama-kv-cache-unified-iswa.cpp"
|
||||||
#include "llama-kv-cache-recurrent.cpp"
|
#include "llama-memory-hybrid.cpp"
|
||||||
|
#include "llama-memory-recurrent.cpp"
|
||||||
#include "llama-model-loader.cpp"
|
#include "llama-model-loader.cpp"
|
||||||
#include "llama-model-saver.cpp"
|
#include "llama-model-saver.cpp"
|
||||||
#include "llama-model.cpp"
|
#include "llama-model.cpp"
|
||||||
|
|
|
@ -205,11 +205,16 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
|
||||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||||
# pragma clang diagnostic push
|
# pragma clang diagnostic push
|
||||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||||
|
#elif defined(__GNUC__)
|
||||||
|
# pragma GCC diagnostic push
|
||||||
|
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
|
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
|
||||||
#if defined(__clang__)
|
#if defined(__clang__)
|
||||||
# pragma clang diagnostic pop
|
# pragma clang diagnostic pop
|
||||||
|
#elif defined(__GNUC__)
|
||||||
|
# pragma GCC diagnostic pop
|
||||||
#endif
|
#endif
|
||||||
try {
|
try {
|
||||||
return conv.from_bytes(s);
|
return conv.from_bytes(s);
|
||||||
|
|