mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-13 10:29:43 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .github/workflows/build.yml # .gitignore # flake.lock # ggml/CMakeLists.txt # ggml/src/CMakeLists.txt
This commit is contained in:
commit
bf35652ef7
17 changed files with 1680 additions and 1300 deletions
|
@ -48,7 +48,7 @@ int main(int argc, char ** argv) {
|
||||||
// save state (rng, logits, embedding and kv_cache) to file
|
// save state (rng, logits, embedding and kv_cache) to file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
||||||
const size_t written = llama_state_get_data(ctx, state_mem.data());
|
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
|
||||||
|
|
||||||
FILE *fp_write = fopen("dump_state.bin", "wb");
|
FILE *fp_write = fopen("dump_state.bin", "wb");
|
||||||
fwrite(state_mem.data(), 1, written, fp_write);
|
fwrite(state_mem.data(), 1, written, fp_write);
|
||||||
|
@ -100,13 +100,16 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// load state (rng, logits, embedding and kv_cache) from file
|
// load state (rng, logits, embedding and kv_cache) from file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
|
std::vector<uint8_t> state_mem;
|
||||||
|
|
||||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||||
|
fseek(fp_read, 0, SEEK_END);
|
||||||
|
state_mem.resize(ftell(fp_read));
|
||||||
|
fseek(fp_read, 0, SEEK_SET);
|
||||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||||
fclose(fp_read);
|
fclose(fp_read);
|
||||||
|
|
||||||
if (read != llama_state_set_data(ctx2, state_mem.data())) {
|
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
|
||||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
@ -160,13 +163,16 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// load state (rng, logits, embedding and kv_cache) from file
|
// load state (rng, logits, embedding and kv_cache) from file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
|
std::vector<uint8_t> state_mem;
|
||||||
|
|
||||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||||
|
fseek(fp_read, 0, SEEK_END);
|
||||||
|
state_mem.resize(ftell(fp_read));
|
||||||
|
fseek(fp_read, 0, SEEK_SET);
|
||||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||||
fclose(fp_read);
|
fclose(fp_read);
|
||||||
|
|
||||||
if (read != llama_state_set_data(ctx3, state_mem.data())) {
|
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
|
||||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
@ -183,7 +189,7 @@ int main(int argc, char ** argv) {
|
||||||
{
|
{
|
||||||
// save kv of seq 0
|
// save kv of seq 0
|
||||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
||||||
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
|
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
|
||||||
if (ncopy != seq_store.size()) {
|
if (ncopy != seq_store.size()) {
|
||||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
|
@ -197,7 +203,7 @@ int main(int argc, char ** argv) {
|
||||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||||
|
|
||||||
// restore kv into seq 1
|
// restore kv into seq 1
|
||||||
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
|
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
|
||||||
if (nset != seq_store.size()) {
|
if (nset != seq_store.size()) {
|
||||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
|
|
|
@ -27,255 +27,11 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
#include <hip/hip_runtime.h>
|
#include "vendors/hip.h"
|
||||||
#include <hipblas/hipblas.h>
|
|
||||||
#include <hip/hip_fp16.h>
|
|
||||||
#ifdef __HIP_PLATFORM_AMD__
|
|
||||||
// for rocblas_initialize()
|
|
||||||
#include "rocblas/rocblas.h"
|
|
||||||
#endif // __HIP_PLATFORM_AMD__
|
|
||||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
|
||||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
|
||||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
|
||||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
|
||||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
|
||||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
|
||||||
#define CUBLAS_OP_T HIPBLAS_OP_T
|
|
||||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
|
||||||
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
|
||||||
#define CUDA_R_16F HIPBLAS_R_16F
|
|
||||||
#define CUDA_R_32F HIPBLAS_R_32F
|
|
||||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
|
||||||
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
|
||||||
#define cublasCreate hipblasCreate
|
|
||||||
#define cublasDestroy hipblasDestroy
|
|
||||||
#define cublasGemmEx hipblasGemmEx
|
|
||||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
|
||||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
|
||||||
#define cublasHandle_t hipblasHandle_t
|
|
||||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
|
||||||
#define cublasSetStream hipblasSetStream
|
|
||||||
#define cublasSgemm hipblasSgemm
|
|
||||||
#define cublasStatus_t hipblasStatus_t
|
|
||||||
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
|
||||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
|
||||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
|
||||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
|
||||||
#define cudaDeviceProp hipDeviceProp_t
|
|
||||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
|
||||||
#define cudaError_t hipError_t
|
|
||||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
|
||||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
|
||||||
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
|
||||||
#define cudaEventDisableTiming hipEventDisableTiming
|
|
||||||
#define cudaEventRecord hipEventRecord
|
|
||||||
#define cudaEventSynchronize hipEventSynchronize
|
|
||||||
#define cudaEvent_t hipEvent_t
|
|
||||||
#define cudaEventDestroy hipEventDestroy
|
|
||||||
#define cudaFree hipFree
|
|
||||||
#define cudaFreeHost hipHostFree
|
|
||||||
#define cudaGetDevice hipGetDevice
|
|
||||||
#define cudaGetDeviceCount hipGetDeviceCount
|
|
||||||
#define cudaGetDeviceProperties hipGetDeviceProperties
|
|
||||||
#define cudaGetErrorString hipGetErrorString
|
|
||||||
#define cudaGetLastError hipGetLastError
|
|
||||||
#define cudaHostRegister hipHostRegister
|
|
||||||
#define cudaHostRegisterPortable hipHostRegisterPortable
|
|
||||||
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
|
||||||
#define cudaHostUnregister hipHostUnregister
|
|
||||||
#define cudaLaunchHostFunc hipLaunchHostFunc
|
|
||||||
#define cudaMalloc hipMalloc
|
|
||||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
|
||||||
#define cudaMemcpy hipMemcpy
|
|
||||||
#define cudaMemcpyAsync hipMemcpyAsync
|
|
||||||
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
|
||||||
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
|
||||||
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
|
||||||
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
|
||||||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
|
||||||
#define cudaMemcpyKind hipMemcpyKind
|
|
||||||
#define cudaMemset hipMemset
|
|
||||||
#define cudaMemsetAsync hipMemsetAsync
|
|
||||||
#define cudaMemGetInfo hipMemGetInfo
|
|
||||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
|
||||||
#define cudaSetDevice hipSetDevice
|
|
||||||
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
|
||||||
#define cudaStreamDestroy hipStreamDestroy
|
|
||||||
#define cudaStreamFireAndForget hipStreamFireAndForget
|
|
||||||
#define cudaStreamNonBlocking hipStreamNonBlocking
|
|
||||||
#define cudaStreamPerThread hipStreamPerThread
|
|
||||||
#define cudaStreamSynchronize hipStreamSynchronize
|
|
||||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
|
||||||
#define cudaStream_t hipStream_t
|
|
||||||
#define cudaSuccess hipSuccess
|
|
||||||
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
|
||||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
|
||||||
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
|
||||||
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
|
||||||
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
|
||||||
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
|
||||||
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
|
||||||
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
|
||||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
|
||||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
|
||||||
#elif defined(GGML_USE_MUSA)
|
#elif defined(GGML_USE_MUSA)
|
||||||
#include <musa_runtime.h>
|
#include "vendors/musa.h"
|
||||||
#include <musa.h>
|
|
||||||
#include <mublas.h>
|
|
||||||
#include <musa_fp16.h>
|
|
||||||
// XXX: Keep the following order the same as hipBLAS
|
|
||||||
// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
|
|
||||||
// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
|
|
||||||
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
|
||||||
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
|
||||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
|
||||||
#define CUBLAS_OP_N MUBLAS_OP_N
|
|
||||||
#define CUBLAS_OP_T MUBLAS_OP_T
|
|
||||||
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
|
||||||
// #define CUBLAS_TF32_TENSOR_OP_MATH 0
|
|
||||||
#define CUDA_R_16F MUSA_R_16F
|
|
||||||
#define CUDA_R_32F MUSA_R_32F
|
|
||||||
// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
|
||||||
// #define cublasComputeType_t mublasComputeType_t
|
|
||||||
#define cublasCreate mublasCreate
|
|
||||||
#define cublasDestroy mublasDestroy
|
|
||||||
#define cublasGemmEx mublasGemmEx
|
|
||||||
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
|
||||||
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
|
||||||
#define cublasHandle_t mublasHandle_t
|
|
||||||
// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
|
||||||
#define cublasSetMathMode mublasSetMathMode
|
|
||||||
#define cublasSetStream mublasSetStream
|
|
||||||
#define cublasSgemm mublasSgemm
|
|
||||||
#define cublasStatus_t mublasStatus_t
|
|
||||||
#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
|
|
||||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
|
||||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
|
||||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
|
||||||
#define cudaDeviceProp musaDeviceProp
|
|
||||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
|
||||||
#define cudaError_t musaError_t
|
|
||||||
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
|
||||||
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
|
||||||
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
|
||||||
#define cudaEventDisableTiming musaEventDisableTiming
|
|
||||||
#define cudaEventRecord musaEventRecord
|
|
||||||
#define cudaEventSynchronize musaEventSynchronize
|
|
||||||
#define cudaEvent_t musaEvent_t
|
|
||||||
#define cudaEventDestroy musaEventDestroy
|
|
||||||
#define cudaFree musaFree
|
|
||||||
#define cudaFreeHost musaFreeHost
|
|
||||||
#define cudaGetDevice musaGetDevice
|
|
||||||
#define cudaGetDeviceCount musaGetDeviceCount
|
|
||||||
#define cudaGetDeviceProperties musaGetDeviceProperties
|
|
||||||
#define cudaGetErrorString musaGetErrorString
|
|
||||||
#define cudaGetLastError musaGetLastError
|
|
||||||
#define cudaHostRegister musaHostRegister
|
|
||||||
#define cudaHostRegisterPortable musaHostRegisterPortable
|
|
||||||
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
|
||||||
#define cudaHostUnregister musaHostUnregister
|
|
||||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
|
||||||
#define cudaMalloc musaMalloc
|
|
||||||
#define cudaMallocHost musaMallocHost
|
|
||||||
#define cudaMemcpy musaMemcpy
|
|
||||||
#define cudaMemcpyAsync musaMemcpyAsync
|
|
||||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
|
||||||
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
|
||||||
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
|
||||||
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
|
||||||
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
|
||||||
#define cudaMemcpyKind musaMemcpyKind
|
|
||||||
#define cudaMemset musaMemset
|
|
||||||
#define cudaMemsetAsync musaMemsetAsync
|
|
||||||
#define cudaMemGetInfo musaMemGetInfo
|
|
||||||
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
|
||||||
#define cudaSetDevice musaSetDevice
|
|
||||||
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
|
||||||
#define cudaStreamDestroy musaStreamDestroy
|
|
||||||
#define cudaStreamFireAndForget musaStreamFireAndForget
|
|
||||||
#define cudaStreamNonBlocking musaStreamNonBlocking
|
|
||||||
#define cudaStreamPerThread musaStreamPerThread
|
|
||||||
#define cudaStreamSynchronize musaStreamSynchronize
|
|
||||||
#define cudaStreamWaitEvent musaStreamWaitEvent
|
|
||||||
#define cudaStream_t musaStream_t
|
|
||||||
#define cudaSuccess musaSuccess
|
|
||||||
|
|
||||||
// XXX: Other CUDA => MUSA mapping
|
|
||||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
|
||||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
|
||||||
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
|
||||||
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
|
||||||
#define CUdevice MUdevice
|
|
||||||
#define CUdeviceptr MUdeviceptr
|
|
||||||
#define CUmemAccessDesc MUmemAccessDesc
|
|
||||||
#define CUmemAllocationProp MUmemAllocationProp
|
|
||||||
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
|
||||||
#define cuDeviceGet muDeviceGet
|
|
||||||
#define cuDeviceGetAttribute muDeviceGetAttribute
|
|
||||||
#define cuMemAddressFree muMemAddressFree
|
|
||||||
#define cuMemAddressReserve muMemAddressReserve
|
|
||||||
#define cuMemCreate muMemCreate
|
|
||||||
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
|
||||||
#define cuMemMap muMemMap
|
|
||||||
#define cuMemRelease muMemRelease
|
|
||||||
#define cuMemSetAccess muMemSetAccess
|
|
||||||
#define cuMemUnmap muMemUnmap
|
|
||||||
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
|
||||||
#define cudaFuncSetAttribute musaFuncSetAttribute
|
|
||||||
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
|
||||||
#define make_cudaExtent make_musaExtent
|
|
||||||
#define make_cudaPitchedPtr make_musaPitchedPtr
|
|
||||||
|
|
||||||
// XXX: USE_CUDA_GRAPH
|
|
||||||
#define CUDA_SUCCESS MUSA_SUCCESS
|
|
||||||
#define CUresult MUresult
|
|
||||||
#define cuGetErrorString muGetErrorString
|
|
||||||
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
|
||||||
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
|
||||||
#define cudaGraphDestroy musaGraphDestroy
|
|
||||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
|
||||||
#define cudaGraphExec_t musaGraphExec_t
|
|
||||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
|
||||||
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
|
||||||
#define cudaGraphGetNodes musaGraphGetNodes
|
|
||||||
#define cudaGraphInstantiate musaGraphInstantiate
|
|
||||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
|
||||||
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
|
||||||
#define cudaGraphLaunch musaGraphLaunch
|
|
||||||
#define cudaGraphNodeGetType musaGraphNodeGetType
|
|
||||||
#define cudaGraphNode_t musaGraphNode_t
|
|
||||||
#define cudaGraphNodeType musaGraphNodeType
|
|
||||||
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
|
||||||
#define cudaGraph_t musaGraph_t
|
|
||||||
#define cudaKernelNodeParams musaKernelNodeParams
|
|
||||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
|
||||||
#define cudaStreamEndCapture musaStreamEndCapture
|
|
||||||
|
|
||||||
// XXX: cuBLAS => muBLAS mapping
|
|
||||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
|
||||||
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
|
||||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
|
||||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
|
||||||
#define cublasComputeType_t cudaDataType_t
|
|
||||||
|
|
||||||
// XXX: Clang builtins mapping
|
|
||||||
#define __vsub4 __vsub4_musa
|
|
||||||
#define __vcmpeq4 __vcmpeq4_musa
|
|
||||||
#define __vcmpne4 __vcmpne4_musa
|
|
||||||
#else
|
#else
|
||||||
#include <cuda_runtime.h>
|
#include "vendors/cuda.h"
|
||||||
#include <cuda.h>
|
|
||||||
#include <cublas_v2.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
|
|
||||||
#if CUDART_VERSION < 11020
|
|
||||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
|
||||||
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
|
||||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
|
||||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
|
||||||
#define cublasComputeType_t cudaDataType_t
|
|
||||||
#endif // CUDART_VERSION < 11020
|
|
||||||
|
|
||||||
#endif // defined(GGML_USE_HIPBLAS)
|
#endif // defined(GGML_USE_HIPBLAS)
|
||||||
|
|
||||||
#define STRINGIZE_IMPL(...) #__VA_ARGS__
|
#define STRINGIZE_IMPL(...) #__VA_ARGS__
|
||||||
|
@ -318,11 +74,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
|
||||||
|
|
||||||
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
||||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||||
#ifndef GGML_USE_MUSA
|
|
||||||
return cublasGetStatusString(err);
|
return cublasGetStatusString(err);
|
||||||
#else
|
|
||||||
return mublasStatus_to_string(err);
|
|
||||||
#endif // GGML_USE_MUSA
|
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||||
|
@ -364,129 +116,7 @@ typedef half2 dfloat2;
|
||||||
#else
|
#else
|
||||||
typedef float dfloat; // dequantize float
|
typedef float dfloat; // dequantize float
|
||||||
typedef float2 dfloat2;
|
typedef float2 dfloat2;
|
||||||
#endif //GGML_CUDA_F16
|
#endif // GGML_CUDA_F16
|
||||||
|
|
||||||
#if defined(GGML_USE_MUSA)
|
|
||||||
#ifndef __has_builtin
|
|
||||||
#define __has_builtin(x) 0
|
|
||||||
#endif
|
|
||||||
|
|
||||||
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
|
||||||
return __vsubss4(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
|
||||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
|
||||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
|
||||||
unsigned int c;
|
|
||||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
|
||||||
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
|
||||||
}
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
|
||||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
|
||||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
|
||||||
unsigned int c;
|
|
||||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
|
||||||
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
|
||||||
}
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
#endif // defined(GGML_USE_MUSA)
|
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
|
||||||
#define __CUDA_ARCH__ 1300
|
|
||||||
|
|
||||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
|
||||||
defined(__gfx1150__) || defined(__gfx1151__)
|
|
||||||
#define RDNA3
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
|
||||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
|
||||||
#define RDNA2
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(__gfx1010__) || defined(__gfx1012__)
|
|
||||||
#define RDNA1
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef __has_builtin
|
|
||||||
#define __has_builtin(x) 0
|
|
||||||
#endif
|
|
||||||
|
|
||||||
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
|
||||||
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
|
||||||
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
|
||||||
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
|
||||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
|
||||||
#if __has_builtin(__builtin_elementwise_sub_sat)
|
|
||||||
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
|
||||||
return reinterpret_cast<const int &>(c);
|
|
||||||
#else
|
|
||||||
int8x4_t c;
|
|
||||||
int16_t tmp;
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; i++) {
|
|
||||||
tmp = va[i] - vb[i];
|
|
||||||
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
|
|
||||||
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
|
||||||
c[i] = tmp;
|
|
||||||
}
|
|
||||||
return reinterpret_cast<int &>(c);
|
|
||||||
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
|
|
||||||
return __vsubss4(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
|
|
||||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
|
||||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
|
||||||
unsigned int c;
|
|
||||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
|
||||||
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
|
||||||
}
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
|
|
||||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
|
||||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
|
||||||
unsigned int c;
|
|
||||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
|
||||||
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
|
||||||
}
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
|
||||||
// __shfl_xor() for half2 was added in ROCm 5.6
|
|
||||||
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
|
|
||||||
typedef union half2_b32 {
|
|
||||||
half2 val;
|
|
||||||
int b32;
|
|
||||||
} half2_b32_t;
|
|
||||||
half2_b32_t tmp;
|
|
||||||
tmp.val = var;
|
|
||||||
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
|
|
||||||
return tmp.val;
|
|
||||||
}
|
|
||||||
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
|
||||||
#endif // defined(GGML_USE_HIPBLAS)
|
|
||||||
|
|
||||||
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||||
#define FP16_AVAILABLE
|
#define FP16_AVAILABLE
|
||||||
|
|
14
ggml/src/ggml-cuda/vendors/cuda.h
vendored
Normal file
14
ggml/src/ggml-cuda/vendors/cuda.h
vendored
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#if CUDART_VERSION < 11020
|
||||||
|
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||||
|
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
||||||
|
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||||
|
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||||
|
#define cublasComputeType_t cudaDataType_t
|
||||||
|
#endif // CUDART_VERSION < 11020
|
177
ggml/src/ggml-cuda/vendors/hip.h
vendored
Normal file
177
ggml/src/ggml-cuda/vendors/hip.h
vendored
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#include <hipblas/hipblas.h>
|
||||||
|
#include <hip/hip_fp16.h>
|
||||||
|
#ifdef __HIP_PLATFORM_AMD__
|
||||||
|
// for rocblas_initialize()
|
||||||
|
#include "rocblas/rocblas.h"
|
||||||
|
#endif // __HIP_PLATFORM_AMD__
|
||||||
|
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||||
|
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||||
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||||
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||||
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||||
|
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||||
|
#define CUBLAS_OP_T HIPBLAS_OP_T
|
||||||
|
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||||
|
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
||||||
|
#define CUDA_R_16F HIPBLAS_R_16F
|
||||||
|
#define CUDA_R_32F HIPBLAS_R_32F
|
||||||
|
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||||
|
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
||||||
|
#define cublasCreate hipblasCreate
|
||||||
|
#define cublasDestroy hipblasDestroy
|
||||||
|
#define cublasGemmEx hipblasGemmEx
|
||||||
|
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||||
|
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||||
|
#define cublasHandle_t hipblasHandle_t
|
||||||
|
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||||
|
#define cublasSetStream hipblasSetStream
|
||||||
|
#define cublasSgemm hipblasSgemm
|
||||||
|
#define cublasStatus_t hipblasStatus_t
|
||||||
|
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
||||||
|
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||||
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||||
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||||
|
#define cudaDeviceProp hipDeviceProp_t
|
||||||
|
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||||
|
#define cudaError_t hipError_t
|
||||||
|
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||||
|
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||||
|
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
||||||
|
#define cudaEventDisableTiming hipEventDisableTiming
|
||||||
|
#define cudaEventRecord hipEventRecord
|
||||||
|
#define cudaEventSynchronize hipEventSynchronize
|
||||||
|
#define cudaEvent_t hipEvent_t
|
||||||
|
#define cudaEventDestroy hipEventDestroy
|
||||||
|
#define cudaFree hipFree
|
||||||
|
#define cudaFreeHost hipHostFree
|
||||||
|
#define cudaGetDevice hipGetDevice
|
||||||
|
#define cudaGetDeviceCount hipGetDeviceCount
|
||||||
|
#define cudaGetDeviceProperties hipGetDeviceProperties
|
||||||
|
#define cudaGetErrorString hipGetErrorString
|
||||||
|
#define cudaGetLastError hipGetLastError
|
||||||
|
#define cudaHostRegister hipHostRegister
|
||||||
|
#define cudaHostRegisterPortable hipHostRegisterPortable
|
||||||
|
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
||||||
|
#define cudaHostUnregister hipHostUnregister
|
||||||
|
#define cudaLaunchHostFunc hipLaunchHostFunc
|
||||||
|
#define cudaMalloc hipMalloc
|
||||||
|
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||||
|
#define cudaMemcpy hipMemcpy
|
||||||
|
#define cudaMemcpyAsync hipMemcpyAsync
|
||||||
|
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
||||||
|
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
||||||
|
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
||||||
|
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||||
|
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||||
|
#define cudaMemcpyKind hipMemcpyKind
|
||||||
|
#define cudaMemset hipMemset
|
||||||
|
#define cudaMemsetAsync hipMemsetAsync
|
||||||
|
#define cudaMemGetInfo hipMemGetInfo
|
||||||
|
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||||
|
#define cudaSetDevice hipSetDevice
|
||||||
|
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
||||||
|
#define cudaStreamDestroy hipStreamDestroy
|
||||||
|
#define cudaStreamFireAndForget hipStreamFireAndForget
|
||||||
|
#define cudaStreamNonBlocking hipStreamNonBlocking
|
||||||
|
#define cudaStreamPerThread hipStreamPerThread
|
||||||
|
#define cudaStreamSynchronize hipStreamSynchronize
|
||||||
|
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||||
|
#define cudaStream_t hipStream_t
|
||||||
|
#define cudaSuccess hipSuccess
|
||||||
|
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
||||||
|
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||||
|
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
||||||
|
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
||||||
|
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
||||||
|
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
||||||
|
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
||||||
|
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
||||||
|
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||||
|
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||||
|
|
||||||
|
#define __CUDA_ARCH__ 1300
|
||||||
|
|
||||||
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||||
|
defined(__gfx1150__) || defined(__gfx1151__)
|
||||||
|
#define RDNA3
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
||||||
|
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
||||||
|
#define RDNA2
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__gfx1010__) || defined(__gfx1012__)
|
||||||
|
#define RDNA1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef __has_builtin
|
||||||
|
#define __has_builtin(x) 0
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
||||||
|
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
||||||
|
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
||||||
|
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
||||||
|
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||||
|
#if __has_builtin(__builtin_elementwise_sub_sat)
|
||||||
|
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
||||||
|
return reinterpret_cast<const int &>(c);
|
||||||
|
#else
|
||||||
|
int8x4_t c;
|
||||||
|
int16_t tmp;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
tmp = va[i] - vb[i];
|
||||||
|
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
|
||||||
|
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
||||||
|
c[i] = tmp;
|
||||||
|
}
|
||||||
|
return reinterpret_cast<int &>(c);
|
||||||
|
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
|
||||||
|
return __vsubss4(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
|
||||||
|
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||||
|
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||||
|
unsigned int c;
|
||||||
|
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
||||||
|
}
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
|
||||||
|
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||||
|
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||||
|
unsigned int c;
|
||||||
|
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
||||||
|
}
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
||||||
|
// __shfl_xor() for half2 was added in ROCm 5.6
|
||||||
|
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
|
||||||
|
typedef union half2_b32 {
|
||||||
|
half2 val;
|
||||||
|
int b32;
|
||||||
|
} half2_b32_t;
|
||||||
|
half2_b32_t tmp;
|
||||||
|
tmp.val = var;
|
||||||
|
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
|
||||||
|
return tmp.val;
|
||||||
|
}
|
||||||
|
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
171
ggml/src/ggml-cuda/vendors/musa.h
vendored
Normal file
171
ggml/src/ggml-cuda/vendors/musa.h
vendored
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <musa_runtime.h>
|
||||||
|
#include <musa.h>
|
||||||
|
#include <mublas.h>
|
||||||
|
#include <musa_fp16.h>
|
||||||
|
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||||
|
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||||
|
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
||||||
|
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
||||||
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
||||||
|
#define CUBLAS_OP_N MUBLAS_OP_N
|
||||||
|
#define CUBLAS_OP_T MUBLAS_OP_T
|
||||||
|
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
||||||
|
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
||||||
|
#define CUDA_R_16F MUSA_R_16F
|
||||||
|
#define CUDA_R_32F MUSA_R_32F
|
||||||
|
#define cublasComputeType_t cudaDataType_t
|
||||||
|
#define cublasCreate mublasCreate
|
||||||
|
#define cublasDestroy mublasDestroy
|
||||||
|
#define cublasGemmEx mublasGemmEx
|
||||||
|
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
||||||
|
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
||||||
|
#define cublasHandle_t mublasHandle_t
|
||||||
|
#define cublasSetMathMode mublasSetMathMode
|
||||||
|
#define cublasSetStream mublasSetStream
|
||||||
|
#define cublasSgemm mublasSgemm
|
||||||
|
#define cublasStatus_t mublasStatus_t
|
||||||
|
#define cublasGetStatusString mublasStatus_to_string
|
||||||
|
#define cudaDataType_t musaDataType_t
|
||||||
|
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||||
|
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||||
|
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||||
|
#define cudaDeviceProp musaDeviceProp
|
||||||
|
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||||
|
#define cudaError_t musaError_t
|
||||||
|
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
||||||
|
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
||||||
|
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
||||||
|
#define cudaEventDisableTiming musaEventDisableTiming
|
||||||
|
#define cudaEventRecord musaEventRecord
|
||||||
|
#define cudaEventSynchronize musaEventSynchronize
|
||||||
|
#define cudaEvent_t musaEvent_t
|
||||||
|
#define cudaEventDestroy musaEventDestroy
|
||||||
|
#define cudaFree musaFree
|
||||||
|
#define cudaFreeHost musaFreeHost
|
||||||
|
#define cudaGetDevice musaGetDevice
|
||||||
|
#define cudaGetDeviceCount musaGetDeviceCount
|
||||||
|
#define cudaGetDeviceProperties musaGetDeviceProperties
|
||||||
|
#define cudaGetErrorString musaGetErrorString
|
||||||
|
#define cudaGetLastError musaGetLastError
|
||||||
|
#define cudaHostRegister musaHostRegister
|
||||||
|
#define cudaHostRegisterPortable musaHostRegisterPortable
|
||||||
|
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
||||||
|
#define cudaHostUnregister musaHostUnregister
|
||||||
|
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||||
|
#define cudaMalloc musaMalloc
|
||||||
|
#define cudaMallocHost musaMallocHost
|
||||||
|
#define cudaMemcpy musaMemcpy
|
||||||
|
#define cudaMemcpyAsync musaMemcpyAsync
|
||||||
|
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||||
|
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
||||||
|
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
||||||
|
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
||||||
|
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
||||||
|
#define cudaMemcpyKind musaMemcpyKind
|
||||||
|
#define cudaMemset musaMemset
|
||||||
|
#define cudaMemsetAsync musaMemsetAsync
|
||||||
|
#define cudaMemGetInfo musaMemGetInfo
|
||||||
|
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
||||||
|
#define cudaSetDevice musaSetDevice
|
||||||
|
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
||||||
|
#define cudaStreamDestroy musaStreamDestroy
|
||||||
|
#define cudaStreamFireAndForget musaStreamFireAndForget
|
||||||
|
#define cudaStreamNonBlocking musaStreamNonBlocking
|
||||||
|
#define cudaStreamPerThread musaStreamPerThread
|
||||||
|
#define cudaStreamSynchronize musaStreamSynchronize
|
||||||
|
#define cudaStreamWaitEvent musaStreamWaitEvent
|
||||||
|
#define cudaStream_t musaStream_t
|
||||||
|
#define cudaSuccess musaSuccess
|
||||||
|
|
||||||
|
// Additional mappings for MUSA virtual memory pool
|
||||||
|
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||||
|
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
||||||
|
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
||||||
|
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
||||||
|
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
||||||
|
#define CUdevice MUdevice
|
||||||
|
#define CUdeviceptr MUdeviceptr
|
||||||
|
#define CUmemAccessDesc MUmemAccessDesc
|
||||||
|
#define CUmemAllocationProp MUmemAllocationProp
|
||||||
|
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
||||||
|
#define cuDeviceGet muDeviceGet
|
||||||
|
#define cuDeviceGetAttribute muDeviceGetAttribute
|
||||||
|
#define cuMemAddressFree muMemAddressFree
|
||||||
|
#define cuMemAddressReserve muMemAddressReserve
|
||||||
|
#define cuMemCreate muMemCreate
|
||||||
|
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
||||||
|
#define cuMemMap muMemMap
|
||||||
|
#define cuMemRelease muMemRelease
|
||||||
|
#define cuMemSetAccess muMemSetAccess
|
||||||
|
#define cuMemUnmap muMemUnmap
|
||||||
|
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
||||||
|
#define cudaFuncSetAttribute musaFuncSetAttribute
|
||||||
|
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
||||||
|
#define make_cudaExtent make_musaExtent
|
||||||
|
#define make_cudaPitchedPtr make_musaPitchedPtr
|
||||||
|
|
||||||
|
// Additional mappings for MUSA graphs
|
||||||
|
#define CUDA_SUCCESS MUSA_SUCCESS
|
||||||
|
#define CUresult MUresult
|
||||||
|
#define cuGetErrorString muGetErrorString
|
||||||
|
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
||||||
|
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
||||||
|
#define cudaGraphDestroy musaGraphDestroy
|
||||||
|
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||||
|
#define cudaGraphExec_t musaGraphExec_t
|
||||||
|
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||||
|
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
||||||
|
#define cudaGraphGetNodes musaGraphGetNodes
|
||||||
|
#define cudaGraphInstantiate musaGraphInstantiate
|
||||||
|
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||||
|
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
||||||
|
#define cudaGraphLaunch musaGraphLaunch
|
||||||
|
#define cudaGraphNodeGetType musaGraphNodeGetType
|
||||||
|
#define cudaGraphNode_t musaGraphNode_t
|
||||||
|
#define cudaGraphNodeType musaGraphNodeType
|
||||||
|
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
||||||
|
#define cudaGraph_t musaGraph_t
|
||||||
|
#define cudaKernelNodeParams musaKernelNodeParams
|
||||||
|
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||||
|
#define cudaStreamEndCapture musaStreamEndCapture
|
||||||
|
|
||||||
|
// XXX: Clang builtins mapping
|
||||||
|
#define __vsub4 __vsub4_musa
|
||||||
|
#define __vcmpeq4 __vcmpeq4_musa
|
||||||
|
#define __vcmpne4 __vcmpne4_musa
|
||||||
|
|
||||||
|
#ifndef __has_builtin
|
||||||
|
#define __has_builtin(x) 0
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
||||||
|
return __vsubss4(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
||||||
|
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||||
|
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||||
|
unsigned int c;
|
||||||
|
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
||||||
|
}
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
||||||
|
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||||
|
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||||
|
unsigned int c;
|
||||||
|
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
||||||
|
}
|
||||||
|
return c;
|
||||||
|
}
|
|
@ -6450,22 +6450,22 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
// compute mask for subtraction
|
// compute mask for subtraction
|
||||||
vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
||||||
vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
|
vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
|
||||||
vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
|
vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
|
||||||
m <<= 1;
|
m <<= 1;
|
||||||
|
|
||||||
vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
||||||
vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
|
vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
|
||||||
vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
|
vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
|
||||||
m <<= 1;
|
m <<= 1;
|
||||||
|
|
||||||
vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
||||||
vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
|
vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
|
||||||
vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
|
vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
|
||||||
m <<= 1;
|
m <<= 1;
|
||||||
|
|
||||||
vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
||||||
vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
|
vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
|
||||||
vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
|
vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
|
||||||
m <<= 1;
|
m <<= 1;
|
||||||
|
|
||||||
// load Q8 and take product with Q3
|
// load Q8 and take product with Q3
|
||||||
|
@ -7721,13 +7721,13 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
|
vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
|
||||||
vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
||||||
vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
|
vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
|
||||||
vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
|
vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl);
|
||||||
m <<= 1;
|
m <<= 1;
|
||||||
|
|
||||||
vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
|
vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
|
||||||
vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
|
||||||
vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
|
vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
|
||||||
vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
|
vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl);
|
||||||
m <<= 1;
|
m <<= 1;
|
||||||
|
|
||||||
vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
|
vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
|
||||||
|
|
|
@ -3981,6 +3981,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||||
ggml_sycl_func_t func;
|
ggml_sycl_func_t func;
|
||||||
|
|
||||||
switch (tensor->op) {
|
switch (tensor->op) {
|
||||||
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
|
func = ggml_sycl_op_conv_transpose_1d;
|
||||||
|
break;
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
func = ggml_sycl_repeat;
|
func = ggml_sycl_repeat;
|
||||||
break;
|
break;
|
||||||
|
@ -4105,6 +4108,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
func = ggml_sycl_argsort;
|
func = ggml_sycl_argsort;
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
func = ggml_sycl_op_timestep_embedding;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -5090,6 +5096,15 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
|
||||||
|
|
||||||
GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
||||||
switch (op->op) {
|
switch (op->op) {
|
||||||
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
|
{
|
||||||
|
ggml_type src0_type = op->src[0]->type;
|
||||||
|
ggml_type src1_type = op->src[1]->type;
|
||||||
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(op)) {
|
switch (ggml_get_unary_op(op)) {
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
|
@ -5213,6 +5228,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
#include "concat.hpp"
|
#include "concat.hpp"
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
|
#include "conv.hpp"
|
||||||
#include "convert.hpp"
|
#include "convert.hpp"
|
||||||
#include "dequantize.hpp"
|
#include "dequantize.hpp"
|
||||||
#include "dmmv.hpp"
|
#include "dmmv.hpp"
|
||||||
|
@ -23,5 +24,6 @@
|
||||||
#include "rope.hpp"
|
#include "rope.hpp"
|
||||||
#include "norm.hpp"
|
#include "norm.hpp"
|
||||||
#include "softmax.hpp"
|
#include "softmax.hpp"
|
||||||
|
#include "tsembd.hpp"
|
||||||
|
|
||||||
#endif // GGML_SYCL_BACKEND_HPP
|
#endif // GGML_SYCL_BACKEND_HPP
|
||||||
|
|
99
ggml/src/ggml-sycl/conv.cpp
Normal file
99
ggml/src/ggml-sycl/conv.cpp
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
//
|
||||||
|
// MIT license
|
||||||
|
// Copyright (C) 2024 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "conv.hpp"
|
||||||
|
|
||||||
|
static void conv_transpose_1d_kernel(
|
||||||
|
const int s0, const int output_size,
|
||||||
|
const int src0_ne0, const int src0_ne1, const int src0_ne2,
|
||||||
|
const int src1_ne0, const int dst_ne0,
|
||||||
|
const float * src0, const float * src1, float * dst,
|
||||||
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
int global_index = item_ct1.get_local_id(2) +
|
||||||
|
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||||
|
if (global_index >= output_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_index = global_index / dst_ne0;
|
||||||
|
|
||||||
|
float accumulator = 0;
|
||||||
|
|
||||||
|
for (int c = 0; c < src0_ne2; c++) {
|
||||||
|
int idx = global_index % dst_ne0;
|
||||||
|
|
||||||
|
int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
|
||||||
|
int input_offset = src1_ne0 * c;
|
||||||
|
|
||||||
|
for (int i = 0; i < src1_ne0; i++) {
|
||||||
|
if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int weight_idx = idx - i*s0;
|
||||||
|
|
||||||
|
float kernel_weight = src0[kernel_offset + weight_idx];
|
||||||
|
float input_value = src1[input_offset+i];
|
||||||
|
|
||||||
|
accumulator += kernel_weight * input_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst[global_index] = accumulator;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void conv_transpose_1d_f32_f32_sycl(
|
||||||
|
const int s0, const int output_size,
|
||||||
|
const int src0_ne0, const int src0_ne1, const int src0_ne2,
|
||||||
|
const int src1_ne0, const int dst_ne0,
|
||||||
|
const float *src0, const float *src1, float *dst,
|
||||||
|
const queue_ptr& stream) {
|
||||||
|
|
||||||
|
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
|
||||||
|
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
|
||||||
|
const sycl::range<3> block_nums(1, 1, num_blocks);
|
||||||
|
stream->parallel_for(
|
||||||
|
sycl::nd_range<3>(
|
||||||
|
block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
conv_transpose_1d_kernel(
|
||||||
|
s0, output_size,
|
||||||
|
src0_ne0, src0_ne1, src0_ne2,
|
||||||
|
src1_ne0, dst_ne0,
|
||||||
|
src0, src1, dst, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||||
|
const ggml_tensor *src1, ggml_tensor *dst) {
|
||||||
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
const float * src1_d = (const float *)src1->data;
|
||||||
|
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
|
||||||
|
const int32_t * opts = (const int32_t *)dst->op_params;
|
||||||
|
|
||||||
|
const int s0 = opts[0];
|
||||||
|
|
||||||
|
const int64_t output_size = ggml_nelements(dst);
|
||||||
|
|
||||||
|
conv_transpose_1d_f32_f32_sycl(s0, output_size,
|
||||||
|
src0->ne[0], src0->ne[1], src0->ne[2],
|
||||||
|
src1->ne[0], dst->ne[0],
|
||||||
|
src0_d, src1_d, dst_d, stream);
|
||||||
|
}
|
||||||
|
|
21
ggml/src/ggml-sycl/conv.hpp
Normal file
21
ggml/src/ggml-sycl/conv.hpp
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
//
|
||||||
|
// MIT license
|
||||||
|
// Copyright (C) 2024 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef GGML_SYCL_CONV_HPP
|
||||||
|
#define GGML_SYCL_CONV_HPP
|
||||||
|
|
||||||
|
#include "common.hpp"
|
||||||
|
|
||||||
|
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||||
|
const ggml_tensor *src1, ggml_tensor *dst);
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_CONV_HPP
|
|
@ -41,6 +41,8 @@
|
||||||
#define SYCL_ACC_BLOCK_SIZE 256
|
#define SYCL_ACC_BLOCK_SIZE 256
|
||||||
#define SYCL_IM2COL_BLOCK_SIZE 256
|
#define SYCL_IM2COL_BLOCK_SIZE 256
|
||||||
#define SYCL_POOL2D_BLOCK_SIZE 256
|
#define SYCL_POOL2D_BLOCK_SIZE 256
|
||||||
|
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
|
||||||
|
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
|
||||||
|
|
||||||
// dmmv = dequantize_mul_mat_vec
|
// dmmv = dequantize_mul_mat_vec
|
||||||
#ifndef GGML_SYCL_DMMV_X
|
#ifndef GGML_SYCL_DMMV_X
|
||||||
|
|
71
ggml/src/ggml-sycl/tsembd.cpp
Normal file
71
ggml/src/ggml-sycl/tsembd.cpp
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
//
|
||||||
|
// MIT license
|
||||||
|
// Copyright (C) 2024 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "tsembd.hpp"
|
||||||
|
|
||||||
|
static void timestep_embedding_f32(
|
||||||
|
const float * timesteps, float * dst, const int nb1,
|
||||||
|
const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
|
||||||
|
// item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
|
||||||
|
// item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
|
||||||
|
int i = item_ct1.get_group(1);
|
||||||
|
int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||||
|
float * embed_data = (float *)((char *)dst + i*nb1);
|
||||||
|
|
||||||
|
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
|
||||||
|
embed_data[dim] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
int half = dim / 2;
|
||||||
|
if (j >= half) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float timestep = timesteps[i];
|
||||||
|
float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
|
||||||
|
float arg = timestep * freq;
|
||||||
|
embed_data[j] = sycl::cos(arg);
|
||||||
|
embed_data[j + half] = sycl::sin(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void timestep_embedding_f32_sycl(
|
||||||
|
const float * x, float * dst, const int ne00, const int nb1,
|
||||||
|
const int dim, const int max_period, const queue_ptr& stream) {
|
||||||
|
// As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
|
||||||
|
int half_ceil = dim / 2;
|
||||||
|
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
|
||||||
|
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
|
||||||
|
sycl::range<3> gridDim(1, ne00, num_blocks);
|
||||||
|
stream->parallel_for(
|
||||||
|
sycl::nd_range<3>(
|
||||||
|
gridDim * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
timestep_embedding_f32(
|
||||||
|
x, dst, nb1, dim, max_period, item_ct1
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||||
|
const ggml_tensor *src1, ggml_tensor * dst) {
|
||||||
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
const int dim = dst->op_params[0];
|
||||||
|
const int max_period = dst->op_params[1];
|
||||||
|
|
||||||
|
timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
|
||||||
|
}
|
21
ggml/src/ggml-sycl/tsembd.hpp
Normal file
21
ggml/src/ggml-sycl/tsembd.hpp
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
//
|
||||||
|
// MIT license
|
||||||
|
// Copyright (C) 2024 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef GGML_SYCL_TSEMBD_HPP
|
||||||
|
#define GGML_SYCL_TSEMBD_HPP
|
||||||
|
|
||||||
|
#include "common.hpp"
|
||||||
|
|
||||||
|
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||||
|
const ggml_tensor *src1, ggml_tensor * dst);
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_TSEMBD_HPP
|
|
@ -30,6 +30,20 @@
|
||||||
|
|
||||||
#define ASYNCIO_CONCURRENCY 64
|
#define ASYNCIO_CONCURRENCY 64
|
||||||
|
|
||||||
|
// define prototypes
|
||||||
|
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str);
|
||||||
|
bool directory_exists(const std::string& path);
|
||||||
|
bool create_directory(const std::string& path);
|
||||||
|
std::string to_uppercase(const std::string& input);
|
||||||
|
bool string_ends_with(const std::string& str, const std::string& suffix);
|
||||||
|
std::string join_paths(const std::string& path1, const std::string& path2);
|
||||||
|
std::string basename(const std::string &path);
|
||||||
|
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16);
|
||||||
|
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b);
|
||||||
|
void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id);
|
||||||
|
void process_shaders(std::vector<std::future<void>>& tasks);
|
||||||
|
void write_output_files();
|
||||||
|
|
||||||
std::mutex lock;
|
std::mutex lock;
|
||||||
std::vector<std::pair<std::string, std::string>> shader_fnames;
|
std::vector<std::pair<std::string, std::string>> shader_fnames;
|
||||||
|
|
||||||
|
@ -38,7 +52,7 @@ std::string input_dir = "vulkan-shaders";
|
||||||
std::string output_dir = "/tmp";
|
std::string output_dir = "/tmp";
|
||||||
std::string target_hpp = "ggml-vulkan-shaders.hpp";
|
std::string target_hpp = "ggml-vulkan-shaders.hpp";
|
||||||
std::string target_cpp = "ggml-vulkan-shaders.cpp";
|
std::string target_cpp = "ggml-vulkan-shaders.cpp";
|
||||||
bool no_clean = false;
|
bool clean = true;
|
||||||
|
|
||||||
const std::vector<std::string> type_names = {
|
const std::vector<std::string> type_names = {
|
||||||
"f32",
|
"f32",
|
||||||
|
@ -464,8 +478,9 @@ void write_output_files() {
|
||||||
}
|
}
|
||||||
fprintf(src, "\n};\n\n");
|
fprintf(src, "\n};\n\n");
|
||||||
|
|
||||||
if (!no_clean) {
|
if (clean) {
|
||||||
std::remove(path.c_str());
|
std::remove(path.c_str());
|
||||||
|
// fprintf(stderr, "Removed: %s\n", path.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -481,6 +496,18 @@ int main(int argc, char** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (argc <= 1 || args.find("--help") != args.end()) {
|
||||||
|
std::cout << "Usage:\n"
|
||||||
|
"\tvulkan-shaders-gen [options]\n\n"
|
||||||
|
"Options:\n"
|
||||||
|
"\t--glslc <path> Path to glslc executable (default: /usr/bin/glslc)\n"
|
||||||
|
"\t--input-dir Directory containing shader sources (required)\n"
|
||||||
|
"\t--output-dir Output directory for generated SPIR-V files and optional C++ headers\n"
|
||||||
|
"\t--target-hpp <path> Path to generate a header file with shader declarations in C++ format\n"
|
||||||
|
"\t--target-cpp <path> Path to generate a source code file implementing the declared shaders (optional)\n"
|
||||||
|
"\t--no-clean Keep temporary SPIR-V files after build (default: remove them)\n";
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
||||||
if (args.find("--glslc") != args.end()) {
|
if (args.find("--glslc") != args.end()) {
|
||||||
GLSLC = args["--glslc"]; // Path to glslc
|
GLSLC = args["--glslc"]; // Path to glslc
|
||||||
}
|
}
|
||||||
|
@ -497,7 +524,7 @@ int main(int argc, char** argv) {
|
||||||
target_cpp = args["--target-cpp"]; // Path to generated cpp file
|
target_cpp = args["--target-cpp"]; // Path to generated cpp file
|
||||||
}
|
}
|
||||||
if (args.find("--no-clean") != args.end()) {
|
if (args.find("--no-clean") != args.end()) {
|
||||||
no_clean = true; // Keep temporary SPIR-V files in output-dir after build
|
clean = false; // Keep temporary SPIR-V files in output-dir after build
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!directory_exists(input_dir)) {
|
if (!directory_exists(input_dir)) {
|
||||||
|
|
|
@ -33,17 +33,15 @@
|
||||||
|
|
||||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||||
|
|
||||||
#define LLAMA_MAX_RNG_STATE (64*1024)
|
|
||||||
|
|
||||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 7
|
#define LLAMA_SESSION_VERSION 8
|
||||||
|
|
||||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||||
#define LLAMA_STATE_SEQ_VERSION 1
|
#define LLAMA_STATE_SEQ_VERSION 2
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@ -693,10 +691,11 @@ extern "C" {
|
||||||
// State / sessions
|
// State / sessions
|
||||||
//
|
//
|
||||||
|
|
||||||
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
// Returns the *actual* size in bytes of the state
|
||||||
// and kv_cache) - will often be smaller after compacting tokens
|
// (rng, logits, embedding and kv_cache)
|
||||||
LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
|
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
|
||||||
LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
|
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
|
||||||
|
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
|
||||||
"use llama_state_get_size instead");
|
"use llama_state_get_size instead");
|
||||||
|
|
||||||
// Copies the state to the specified destination address.
|
// Copies the state to the specified destination address.
|
||||||
|
@ -704,7 +703,8 @@ extern "C" {
|
||||||
// Returns the number of bytes copied
|
// Returns the number of bytes copied
|
||||||
LLAMA_API size_t llama_state_get_data(
|
LLAMA_API size_t llama_state_get_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst);
|
uint8_t * dst,
|
||||||
|
size_t size);
|
||||||
LLAMA_API DEPRECATED(size_t llama_copy_state_data(
|
LLAMA_API DEPRECATED(size_t llama_copy_state_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst),
|
uint8_t * dst),
|
||||||
|
@ -714,7 +714,8 @@ extern "C" {
|
||||||
// Returns the number of bytes read
|
// Returns the number of bytes read
|
||||||
LLAMA_API size_t llama_state_set_data(
|
LLAMA_API size_t llama_state_set_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const uint8_t * src);
|
const uint8_t * src,
|
||||||
|
size_t size);
|
||||||
LLAMA_API DEPRECATED(size_t llama_set_state_data(
|
LLAMA_API DEPRECATED(size_t llama_set_state_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const uint8_t * src),
|
const uint8_t * src),
|
||||||
|
@ -756,6 +757,7 @@ extern "C" {
|
||||||
LLAMA_API size_t llama_state_seq_get_data(
|
LLAMA_API size_t llama_state_seq_get_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst,
|
uint8_t * dst,
|
||||||
|
size_t size,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
||||||
|
@ -765,6 +767,7 @@ extern "C" {
|
||||||
LLAMA_API size_t llama_state_seq_set_data(
|
LLAMA_API size_t llama_state_seq_set_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const uint8_t * src,
|
const uint8_t * src,
|
||||||
|
size_t size,
|
||||||
llama_seq_id dest_seq_id);
|
llama_seq_id dest_seq_id);
|
||||||
|
|
||||||
LLAMA_API size_t llama_state_seq_save_file(
|
LLAMA_API size_t llama_state_seq_save_file(
|
||||||
|
|
380
klite.embd
380
klite.embd
|
@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
|
||||||
-->
|
-->
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
const LITEVER = 159;
|
const LITEVER = 160;
|
||||||
const urlParams = new URLSearchParams(window.location.search);
|
const urlParams = new URLSearchParams(window.location.search);
|
||||||
const localflag = true;
|
const localflag = true;
|
||||||
const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";
|
const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";
|
||||||
|
@ -3384,10 +3384,7 @@ Current version indicated by LITEVER below.
|
||||||
let is_local = false;
|
let is_local = false;
|
||||||
|
|
||||||
if (url) {
|
if (url) {
|
||||||
is_local = (url.toLowerCase().includes("localhost") ||
|
is_local = is_local_url(url);
|
||||||
url.toLowerCase().includes("127.0.0.1") ||
|
|
||||||
url.toLowerCase().includes("192.168.") ||
|
|
||||||
!url.toLowerCase().includes("."));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((uses_cors_proxy||proxy_by_default) && !is_local) {
|
if ((uses_cors_proxy||proxy_by_default) && !is_local) {
|
||||||
|
@ -3540,6 +3537,170 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function oai_api_sync_req(targetep,oai_payload,oaiheaders)
|
||||||
|
{
|
||||||
|
fetch(targetep, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: oaiheaders,
|
||||||
|
body: JSON.stringify(oai_payload),
|
||||||
|
referrerPolicy: 'no-referrer',
|
||||||
|
})
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data) => {
|
||||||
|
console.log("sync finished response: " + JSON.stringify(data));
|
||||||
|
if (custom_oai_key != "" && data.choices != null && data.choices.length > 0) {
|
||||||
|
let dch = data.choices[0];
|
||||||
|
if (dch.text) {
|
||||||
|
synchro_polled_response = dch.text;
|
||||||
|
}
|
||||||
|
else if (dch.message) {
|
||||||
|
synchro_polled_response = dch.message.content;
|
||||||
|
|
||||||
|
if(localsettings.opmode==1 && gametext_arr.length>0 && synchro_polled_response!="")
|
||||||
|
{
|
||||||
|
synchro_polled_response = cleanup_story_completion(synchro_polled_response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
console.error("Error, unknown OAI response");
|
||||||
|
clear_poll_flags();
|
||||||
|
render_gametext();
|
||||||
|
msgbox("Error, unknown OAI response");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
//error occurred, maybe captcha failed
|
||||||
|
console.error("error occurred in OAI generation");
|
||||||
|
clear_poll_flags();
|
||||||
|
render_gametext();
|
||||||
|
msgbox("Error occurred during text generation: " + formatError(data));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Error:', error);
|
||||||
|
clear_poll_flags();
|
||||||
|
render_gametext();
|
||||||
|
msgbox("Error while submitting prompt: " + error);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function oai_api_stream_sse(sub_endpt,submit_payload,submit_headers)
|
||||||
|
{
|
||||||
|
synchro_pending_stream = "";
|
||||||
|
let reqOpt =
|
||||||
|
{method: 'POST',
|
||||||
|
headers: submit_headers,
|
||||||
|
body: JSON.stringify(submit_payload)};
|
||||||
|
if(globalabortcontroller)
|
||||||
|
{
|
||||||
|
reqOpt.signal = globalabortcontroller.signal;
|
||||||
|
}
|
||||||
|
fetch(sub_endpt, reqOpt)
|
||||||
|
.then(x => {
|
||||||
|
if(x.ok)
|
||||||
|
{
|
||||||
|
return x;
|
||||||
|
}else{
|
||||||
|
return x.text().then(errdat => {
|
||||||
|
throw new Error('Error while SSE streaming: ' + errdat);
|
||||||
|
return null;
|
||||||
|
}).catch(err => {
|
||||||
|
throw new Error('Error while SSE streaming: ' + (x.statusText) + '\n' + err);
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.then(resp => {
|
||||||
|
resp.body
|
||||||
|
.pipeThrough(new TextDecoderStream())
|
||||||
|
.pipeThrough(new TransformStream({
|
||||||
|
start(ctrl) {
|
||||||
|
ctrl.buf = '';
|
||||||
|
},
|
||||||
|
transform(chunk, ctrl) {
|
||||||
|
ctrl.buf += chunk;
|
||||||
|
let evs = [];
|
||||||
|
let m;
|
||||||
|
while ((m = /^data: (.*)\n\n/m.exec(ctrl.buf)) !== null) {
|
||||||
|
try{evs.push({data: JSON.parse(m[1])});} catch (e) {}
|
||||||
|
ctrl.buf = ctrl.buf.substring(m.index + m[0].length);
|
||||||
|
}
|
||||||
|
if (evs.length) {
|
||||||
|
ctrl.enqueue(evs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.pipeTo(new WritableStream({
|
||||||
|
write(chunk) {
|
||||||
|
let was_empty = (synchro_pending_stream=="");
|
||||||
|
//cut stream if aborted
|
||||||
|
if(pending_response_id && pending_response_id != "-1" && pending_response_id != "")
|
||||||
|
{
|
||||||
|
for (let event of chunk) {
|
||||||
|
if (event.data && event.data.choices && event.data.choices.length>0) {
|
||||||
|
if(event.data.choices[0].text)
|
||||||
|
{
|
||||||
|
synchro_pending_stream += event.data.choices[0].text;
|
||||||
|
}else if(event.data.choices[0].delta && event.data.choices[0].delta.content)
|
||||||
|
{
|
||||||
|
synchro_pending_stream += event.data.choices[0].delta.content;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(event.data.choices[0].finish_reason=="stop")
|
||||||
|
{
|
||||||
|
last_stop_reason = "stop";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(was_empty && synchro_pending_stream!="")
|
||||||
|
{
|
||||||
|
render_gametext(false);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
update_pending_stream_displays();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
close() { //end of stream
|
||||||
|
synchro_polled_response = synchro_pending_stream;
|
||||||
|
let need_clean_output = (synchro_polled_response!="" && localsettings.opmode==1 && gametext_arr.length>0 && document.getElementById("useoaichatcompl").checked);
|
||||||
|
if(need_clean_output)
|
||||||
|
{
|
||||||
|
synchro_polled_response = cleanup_story_completion(synchro_polled_response);
|
||||||
|
}
|
||||||
|
synchro_pending_stream = "";
|
||||||
|
poll_pending_response();
|
||||||
|
//handle gen failures
|
||||||
|
if(resp.status==503)
|
||||||
|
{
|
||||||
|
msgbox("Error while submitting prompt: Server appears to be busy.");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
abort(error) {
|
||||||
|
console.error('Error:', error);
|
||||||
|
if(error.name!="AbortError") //aborts are silent. slightly diff logic
|
||||||
|
{
|
||||||
|
flush_streaming_text();
|
||||||
|
msgbox("Error while submitting prompt: " + error);
|
||||||
|
}
|
||||||
|
clear_poll_flags();
|
||||||
|
render_gametext();
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Error:', error);
|
||||||
|
if(error.name!="AbortError") //aborts are silent. slightly diff logic
|
||||||
|
{
|
||||||
|
flush_streaming_text();
|
||||||
|
msgbox("Error while submitting prompt: " + error);
|
||||||
|
}
|
||||||
|
clear_poll_flags();
|
||||||
|
render_gametext();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
function kobold_api_stream_sse(sub_endpt,submit_payload)
|
function kobold_api_stream_sse(sub_endpt,submit_payload)
|
||||||
{
|
{
|
||||||
synchro_pending_stream = "";
|
synchro_pending_stream = "";
|
||||||
|
@ -3572,7 +3733,7 @@ Current version indicated by LITEVER below.
|
||||||
ctrl.buf += chunk;
|
ctrl.buf += chunk;
|
||||||
let evs = [];
|
let evs = [];
|
||||||
let m;
|
let m;
|
||||||
while ((m = /^event: (.*)\ndata: (.*)\n\n/.exec(ctrl.buf)) !== null) {
|
while ((m = /^event: (.*)\ndata: (.*)\n\n/m.exec(ctrl.buf)) !== null) {
|
||||||
evs.push({event: m[1], data: JSON.parse(m[2])});
|
evs.push({event: m[1], data: JSON.parse(m[2])});
|
||||||
ctrl.buf = ctrl.buf.substring(m.index + m[0].length);
|
ctrl.buf = ctrl.buf.substring(m.index + m[0].length);
|
||||||
}
|
}
|
||||||
|
@ -5152,19 +5313,32 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function is_local_url(target_url)
|
||||||
|
{
|
||||||
|
let is_local = (target_url.toLowerCase().includes("localhost")
|
||||||
|
|| target_url.toLowerCase().includes("127.0.0.1")
|
||||||
|
|| target_url.toLowerCase().includes("192.168.")
|
||||||
|
|| target_url.toLowerCase().includes("10.0.0.")
|
||||||
|
|| target_url.toLowerCase().includes("://10.0.")
|
||||||
|
|| !target_url.toLowerCase().includes(".")); //hostname without dots cannot be wan accessible
|
||||||
|
return is_local;
|
||||||
|
}
|
||||||
|
|
||||||
|
function is_browser_supports_sse()
|
||||||
|
{
|
||||||
|
return (self.TransformStream!=null && self.TextDecoderStream!=null && self.WritableStream!=null);
|
||||||
|
}
|
||||||
function is_using_custom_ep()
|
function is_using_custom_ep()
|
||||||
{
|
{
|
||||||
return (custom_oai_key!=""||custom_kobold_endpoint!=""||custom_claude_key!=""||custom_palm_key!=""||custom_cohere_key!="");
|
return (custom_oai_key!=""||custom_kobold_endpoint!=""||custom_claude_key!=""||custom_palm_key!=""||custom_cohere_key!="");
|
||||||
}
|
}
|
||||||
|
|
||||||
function is_using_kcpp_with_streaming()
|
function is_using_kcpp_with_streaming()
|
||||||
{
|
{
|
||||||
return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.30") >= 0);
|
return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.30") >= 0);
|
||||||
}
|
}
|
||||||
function is_using_kcpp_with_sse() //need 1.39 for multibyte fix
|
function is_using_kcpp_with_sse() //need 1.39 for multibyte fix
|
||||||
{
|
{
|
||||||
let browsersupported = (self.TransformStream!=null && self.TextDecoderStream!=null && self.WritableStream!=null);
|
return (is_browser_supports_sse() && custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.40") >= 0);
|
||||||
return (browsersupported && custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.40") >= 0);
|
|
||||||
}
|
}
|
||||||
function is_using_kcpp_with_mirostat()
|
function is_using_kcpp_with_mirostat()
|
||||||
{
|
{
|
||||||
|
@ -7224,6 +7398,12 @@ Current version indicated by LITEVER below.
|
||||||
onDoneCallback([]);
|
onDoneCallback([]);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if(cached_worker_list!=null && cached_worker_list.length>1 && performance.now() < stale_cached_worker_time)
|
||||||
|
{
|
||||||
|
console.log("Reuse cached worker list");
|
||||||
|
onDoneCallback(cached_worker_list);
|
||||||
|
return;
|
||||||
|
}
|
||||||
multifetch(worker_endpoints,(resArr,errArr)=>{
|
multifetch(worker_endpoints,(resArr,errArr)=>{
|
||||||
|
|
||||||
if(resArr && resArr.length>0)
|
if(resArr && resArr.length>0)
|
||||||
|
@ -7246,6 +7426,8 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cached_worker_list = wdata;
|
||||||
|
stale_cached_worker_time = performance.now() + 30000; //cache worker list for 30s
|
||||||
if (onDoneCallback != null) {
|
if (onDoneCallback != null) {
|
||||||
onDoneCallback(wdata);
|
onDoneCallback(wdata);
|
||||||
}
|
}
|
||||||
|
@ -7875,34 +8057,49 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function get_oai_model_dropdown()
|
||||||
|
{
|
||||||
|
let ddval = document.getElementById("customapidropdown").value;
|
||||||
|
switch(ddval)
|
||||||
|
{
|
||||||
|
case "3":
|
||||||
|
return document.getElementById("custom_openrouter_model");
|
||||||
|
case "7":
|
||||||
|
return document.getElementById("custom_mistralai_model");
|
||||||
|
default:
|
||||||
|
return document.getElementById("custom_oai_model");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function ep_should_always_use_chat_completions()
|
||||||
|
{
|
||||||
|
let epchoice = document.getElementById("customapidropdown").value;
|
||||||
|
return (epchoice==7);
|
||||||
|
}
|
||||||
|
|
||||||
function select_custom_oai_model()
|
function select_custom_oai_model()
|
||||||
{
|
{
|
||||||
let isOpenrouter = (document.getElementById("customapidropdown").value==3);
|
|
||||||
let isMistralai = (document.getElementById("customapidropdown").value==7);
|
|
||||||
inputBox("Enter custom model name","Custom Model Name",localsettings.saved_oai_custommodel,"", ()=>{
|
inputBox("Enter custom model name","Custom Model Name",localsettings.saved_oai_custommodel,"", ()=>{
|
||||||
let coai = getInputBoxValue().trim();
|
let coai = getInputBoxValue().trim();
|
||||||
let dropdown = (isMistralai?document.getElementById("custom_mistralai_model"):(isOpenrouter?document.getElementById("custom_openrouter_model"):document.getElementById("custom_oai_model")));
|
let dropdown = get_oai_model_dropdown();
|
||||||
let mdlopt = (isMistralai?document.getElementById("custom_mistralai_model_option"):(isOpenrouter?"custom_openrouter_model_option":"custom_oai_model_option"));
|
var mdlopt = dropdown.querySelector('option.custom_model_option');
|
||||||
if(coai!="")
|
if(coai!="")
|
||||||
{
|
{
|
||||||
document.getElementById(mdlopt).value = coai;
|
mdlopt.value = coai;
|
||||||
document.getElementById(mdlopt).innerText = coai;
|
mdlopt.innerText = coai;
|
||||||
document.getElementById(mdlopt).style.display = "";
|
mdlopt.style.display = "";
|
||||||
dropdown.selectedIndex = dropdown.options.length - 1;
|
dropdown.selectedIndex = dropdown.options.length - 1;
|
||||||
}
|
}
|
||||||
oai_model_change(isOpenrouter||isMistralai);
|
oai_model_change(ep_should_always_use_chat_completions());
|
||||||
},false);
|
},false);
|
||||||
}
|
}
|
||||||
function oai_model_change(autotoggle_check = false)
|
function oai_model_change(autotoggle_check = false)
|
||||||
{
|
{
|
||||||
let isOpenrouter = (document.getElementById("customapidropdown").value==3);
|
let dropdown = get_oai_model_dropdown();
|
||||||
let isMistralai = (document.getElementById("customapidropdown").value==7);
|
|
||||||
let dropdown = (isMistralai?document.getElementById("custom_mistralai_model"):(isOpenrouter?document.getElementById("custom_openrouter_model"):document.getElementById("custom_oai_model")));
|
|
||||||
let non_completions = (dropdown.value.includes("davinci-002") || dropdown.value.includes("text-davinci-003") || dropdown.value.includes("text-davinci-002")
|
let non_completions = (dropdown.value.includes("davinci-002") || dropdown.value.includes("text-davinci-003") || dropdown.value.includes("text-davinci-002")
|
||||||
|| dropdown.value.includes("text-davinci-001") || dropdown.value.includes("gpt-3.5-turbo-instruct") || dropdown.value == "davinci");
|
|| dropdown.value.includes("text-davinci-001") || dropdown.value.includes("gpt-3.5-turbo-instruct") || dropdown.value == "davinci");
|
||||||
if(autotoggle_check)
|
if(autotoggle_check)
|
||||||
{
|
{
|
||||||
if(isMistralai || isOpenrouter || dropdown.selectedIndex==dropdown.options.length-1)
|
if(ep_should_always_use_chat_completions() || dropdown.selectedIndex==dropdown.options.length-1)
|
||||||
{
|
{
|
||||||
document.getElementById("useoaichatcompl").checked = true;
|
document.getElementById("useoaichatcompl").checked = true;
|
||||||
} else {
|
} else {
|
||||||
|
@ -7919,6 +8116,11 @@ Current version indicated by LITEVER below.
|
||||||
{
|
{
|
||||||
desired_oai_ep = desired_oai_ep.slice(0, -1);
|
desired_oai_ep = desired_oai_ep.slice(0, -1);
|
||||||
}
|
}
|
||||||
|
if(!desired_oai_ep.includes("://")) //user did not add http/https
|
||||||
|
{
|
||||||
|
let is_local = is_local_url(desired_oai_ep);
|
||||||
|
desired_oai_ep = (is_local?"http://":"https://") + desired_oai_ep;
|
||||||
|
}
|
||||||
if (document.getElementById("oaiaddversion").checked)
|
if (document.getElementById("oaiaddversion").checked)
|
||||||
{
|
{
|
||||||
if(desired_oai_ep!="" && desired_oai_ep.length > 4 && !desired_oai_ep.slice(-4).toLowerCase().includes("/v") && !desired_oai_ep.toLowerCase().includes("/v1/")) {
|
if(desired_oai_ep!="" && desired_oai_ep.length > 4 && !desired_oai_ep.slice(-4).toLowerCase().includes("/v") && !desired_oai_ep.toLowerCase().includes("/v1/")) {
|
||||||
|
@ -7939,6 +8141,8 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let isOpenrouter = (document.getElementById("customapidropdown").value==3);
|
||||||
|
let dropdown = get_oai_model_dropdown();
|
||||||
fetch((desired_oai_ep + oai_models_endpoint), {
|
fetch((desired_oai_ep + oai_models_endpoint), {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: oaiheaders,
|
headers: oaiheaders,
|
||||||
|
@ -7957,9 +8161,6 @@ Current version indicated by LITEVER below.
|
||||||
|
|
||||||
if (!data.error && data.data && data.data.length > 0)
|
if (!data.error && data.data && data.data.length > 0)
|
||||||
{
|
{
|
||||||
let isOpenrouter = (document.getElementById("customapidropdown").value==3);
|
|
||||||
let isMistralai = (document.getElementById("customapidropdown").value==7);
|
|
||||||
let dropdown = (isMistralai?document.getElementById("custom_mistralai_model"):(isOpenrouter?document.getElementById("custom_openrouter_model"):document.getElementById("custom_oai_model")));
|
|
||||||
var lastOption = dropdown.lastElementChild;
|
var lastOption = dropdown.lastElementChild;
|
||||||
for (var i = dropdown.options.length - 1; i >= 0; i--) {
|
for (var i = dropdown.options.length - 1; i >= 0; i--) {
|
||||||
var option = dropdown.options[i];
|
var option = dropdown.options[i];
|
||||||
|
@ -8041,12 +8242,13 @@ Current version indicated by LITEVER below.
|
||||||
else if(epchoice==2 || epchoice==3 || epchoice==7)
|
else if(epchoice==2 || epchoice==3 || epchoice==7)
|
||||||
{
|
{
|
||||||
document.getElementById("oaicustom").classList.remove("hidden");
|
document.getElementById("oaicustom").classList.remove("hidden");
|
||||||
|
document.getElementById("openrouterdesc").classList.add("hidden");
|
||||||
|
document.getElementById("mistralaidesc").classList.add("hidden");
|
||||||
|
document.getElementById("oaidesc").classList.add("hidden");
|
||||||
if(epchoice==2)
|
if(epchoice==2)
|
||||||
{
|
{
|
||||||
document.getElementById("oaidesc").classList.remove("hidden");
|
document.getElementById("oaidesc").classList.remove("hidden");
|
||||||
document.getElementById("custom_oai_model").classList.remove("hidden");
|
document.getElementById("custom_oai_model").classList.remove("hidden");
|
||||||
document.getElementById("openrouterdesc").classList.add("hidden");
|
|
||||||
document.getElementById("mistralaidesc").classList.add("hidden");
|
|
||||||
document.getElementById("custom_oai_endpoint").classList.remove("hidden");
|
document.getElementById("custom_oai_endpoint").classList.remove("hidden");
|
||||||
document.getElementById("custom_oai_key").value = localsettings.saved_oai_key;
|
document.getElementById("custom_oai_key").value = localsettings.saved_oai_key;
|
||||||
if (localflag) {
|
if (localflag) {
|
||||||
|
@ -8057,19 +8259,15 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
else if(epchoice==7)
|
else if(epchoice==7)
|
||||||
{
|
{
|
||||||
document.getElementById("oaidesc").classList.add("hidden");
|
|
||||||
document.getElementById("custom_mistralai_model").classList.remove("hidden");
|
document.getElementById("custom_mistralai_model").classList.remove("hidden");
|
||||||
document.getElementById("openrouterdesc").classList.add("hidden");
|
|
||||||
document.getElementById("mistralaidesc").classList.remove("hidden");
|
document.getElementById("mistralaidesc").classList.remove("hidden");
|
||||||
document.getElementById("custom_oai_endpoint").classList.add("hidden");
|
document.getElementById("custom_oai_endpoint").classList.add("hidden");
|
||||||
document.getElementById("custom_oai_key").value = localsettings.saved_mistralai_key;
|
document.getElementById("custom_oai_key").value = localsettings.saved_mistralai_key;
|
||||||
document.getElementById("custom_oai_endpoint").value = default_mistralai_base;
|
document.getElementById("custom_oai_endpoint").value = default_mistralai_base;
|
||||||
}
|
}
|
||||||
else
|
else //openrouter supports autofetch
|
||||||
{
|
{
|
||||||
document.getElementById("oaidesc").classList.add("hidden");
|
|
||||||
document.getElementById("openrouterdesc").classList.remove("hidden");
|
document.getElementById("openrouterdesc").classList.remove("hidden");
|
||||||
document.getElementById("mistralaidesc").classList.add("hidden");
|
|
||||||
document.getElementById("custom_openrouter_model").classList.remove("hidden");
|
document.getElementById("custom_openrouter_model").classList.remove("hidden");
|
||||||
document.getElementById("custom_oai_endpoint").value = default_openrouter_base;
|
document.getElementById("custom_oai_endpoint").value = default_openrouter_base;
|
||||||
document.getElementById("custom_oai_endpoint").classList.add("hidden");
|
document.getElementById("custom_oai_endpoint").classList.add("hidden");
|
||||||
|
@ -8084,7 +8282,7 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
oai_model_change(epchoice==3 || epchoice==7 || force_autotoggle_chatcompl);
|
oai_model_change(ep_should_always_use_chat_completions() || force_autotoggle_chatcompl);
|
||||||
toggleoaichatcompl();
|
toggleoaichatcompl();
|
||||||
}
|
}
|
||||||
else if(epchoice==4)
|
else if(epchoice==4)
|
||||||
|
@ -8175,6 +8373,11 @@ Current version indicated by LITEVER below.
|
||||||
{
|
{
|
||||||
desiredkoboldendpoint = desiredkoboldendpoint.split("/api")[0];
|
desiredkoboldendpoint = desiredkoboldendpoint.split("/api")[0];
|
||||||
}
|
}
|
||||||
|
if(!desiredkoboldendpoint.includes("://")) //user did not add http/https
|
||||||
|
{
|
||||||
|
let is_local = is_local_url(desiredkoboldendpoint);
|
||||||
|
desiredkoboldendpoint = (is_local?"http://":"https://") + desiredkoboldendpoint;
|
||||||
|
}
|
||||||
|
|
||||||
let urls1 = [
|
let urls1 = [
|
||||||
apply_proxy_url(desiredkoboldendpoint + kobold_custom_mdl_endpoint),
|
apply_proxy_url(desiredkoboldendpoint + kobold_custom_mdl_endpoint),
|
||||||
|
@ -8395,10 +8598,7 @@ Current version indicated by LITEVER below.
|
||||||
//if it still fails, then show error
|
//if it still fails, then show error
|
||||||
console.log("Error: " + error);
|
console.log("Error: " + error);
|
||||||
|
|
||||||
let is_local = (custom_kobold_endpoint.toLowerCase().includes("localhost")
|
let is_local = is_local_url(custom_kobold_endpoint);
|
||||||
|| custom_kobold_endpoint.toLowerCase().includes("127.0.0.1")
|
|
||||||
|| custom_kobold_endpoint.toLowerCase().includes("192.168.")
|
|
||||||
|| !custom_kobold_endpoint.toLowerCase().includes(".")); //hostname without dots cannot be wan accessible
|
|
||||||
|
|
||||||
if (uses_cors_proxy || is_local) {
|
if (uses_cors_proxy || is_local) {
|
||||||
if(is_local && sublocalpathname!="")
|
if(is_local && sublocalpathname!="")
|
||||||
|
@ -8445,6 +8645,11 @@ Current version indicated by LITEVER below.
|
||||||
{
|
{
|
||||||
desired_oai_ep = desired_oai_ep.slice(0, -1);
|
desired_oai_ep = desired_oai_ep.slice(0, -1);
|
||||||
}
|
}
|
||||||
|
if(!desired_oai_ep.includes("://")) //user did not add http/https
|
||||||
|
{
|
||||||
|
let is_local = is_local_url(desired_oai_ep);
|
||||||
|
desired_oai_ep = (is_local?"http://":"https://") + desired_oai_ep;
|
||||||
|
}
|
||||||
if (document.getElementById("oaiaddversion").checked)
|
if (document.getElementById("oaiaddversion").checked)
|
||||||
{
|
{
|
||||||
if(desired_oai_ep!="" && desired_oai_ep.length > 4 && !desired_oai_ep.slice(-4).toLowerCase().includes("/v") && !desired_oai_ep.toLowerCase().includes("/v1/")) {
|
if(desired_oai_ep!="" && desired_oai_ep.length > 4 && !desired_oai_ep.slice(-4).toLowerCase().includes("/v") && !desired_oai_ep.toLowerCase().includes("/v1/")) {
|
||||||
|
@ -8480,9 +8685,7 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
localsettings.saved_oai_role = document.getElementById("oairoledropdown").value;
|
localsettings.saved_oai_role = document.getElementById("oairoledropdown").value;
|
||||||
localsettings.saved_oai_jailbreak2 = document.getElementById("jailbreakprompttext2").value;
|
localsettings.saved_oai_jailbreak2 = document.getElementById("jailbreakprompttext2").value;
|
||||||
let isOpenrouter = (document.getElementById("customapidropdown").value==3);
|
let dropdown = get_oai_model_dropdown();
|
||||||
let isMistralai = (document.getElementById("customapidropdown").value==7);
|
|
||||||
let dropdown = (isMistralai?document.getElementById("custom_mistralai_model"):(isOpenrouter?document.getElementById("custom_openrouter_model"):document.getElementById("custom_oai_model")));
|
|
||||||
custom_oai_model = dropdown.value.trim();
|
custom_oai_model = dropdown.value.trim();
|
||||||
localsettings.saved_oai_custommodel = custom_oai_model;
|
localsettings.saved_oai_custommodel = custom_oai_model;
|
||||||
selected_models = [{ "performance": 100.0, "queued": 0.0, "eta": 0, "name": custom_oai_model, "count": 1 }];
|
selected_models = [{ "performance": 100.0, "queued": 0.0, "eta": 0, "name": custom_oai_model, "count": 1 }];
|
||||||
|
@ -8751,7 +8954,9 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
|
|
||||||
var cached_model_list = null;
|
var cached_model_list = null;
|
||||||
|
var cached_worker_list = null;
|
||||||
var stale_cached_model_time = performance.now();
|
var stale_cached_model_time = performance.now();
|
||||||
|
var stale_cached_worker_time = performance.now();
|
||||||
function fetch_models(onDoneCallback)
|
function fetch_models(onDoneCallback)
|
||||||
{
|
{
|
||||||
if(localflag)
|
if(localflag)
|
||||||
|
@ -8786,7 +8991,7 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
|
|
||||||
cached_model_list = mdls;
|
cached_model_list = mdls;
|
||||||
stale_cached_model_time = performance.now() + 30000; //cache model list for 1m
|
stale_cached_model_time = performance.now() + 30000; //cache model list for 30s
|
||||||
onDoneCallback(mdls);
|
onDoneCallback(mdls);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -11738,6 +11943,26 @@ Current version indicated by LITEVER below.
|
||||||
|
|
||||||
function cleanup_story_completion(resp)
|
function cleanup_story_completion(resp)
|
||||||
{
|
{
|
||||||
|
if(gametext_arr.length>0)
|
||||||
|
{
|
||||||
|
//fix duplicate sentences
|
||||||
|
const sentenceEndings = /[.!?]/g;
|
||||||
|
let lastsentences = gametext_arr[gametext_arr.length-1].split(sentenceEndings);
|
||||||
|
lastsentences = lastsentences.map(lastsentences => lastsentences.trim()); //remove whitespace
|
||||||
|
lastsentences = lastsentences.filter(lastsentences => lastsentences.length > 0);
|
||||||
|
if(lastsentences.length>0)
|
||||||
|
{
|
||||||
|
let lastsentence = lastsentences[lastsentences.length - 1];
|
||||||
|
if(lastsentence.length>10 && resp.trim().startsWith(lastsentence)) //only match if its long enough and matches verbatim
|
||||||
|
{
|
||||||
|
let foundindex = resp.indexOf(lastsentence);
|
||||||
|
if (foundindex !== -1 && foundindex<5) {
|
||||||
|
resp = resp.substring(foundindex+lastsentence.length); //remove duplicated part
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//fix response lacking space
|
||||||
if(!gametext_arr[gametext_arr.length-1].endsWith(" ") && !gametext_arr[gametext_arr.length-1].endsWith("\n"))
|
if(!gametext_arr[gametext_arr.length-1].endsWith(" ") && !gametext_arr[gametext_arr.length-1].endsWith("\n"))
|
||||||
{
|
{
|
||||||
if(/^\.\.\.[a-zA-Z0-9]/.test(resp))
|
if(/^\.\.\.[a-zA-Z0-9]/.test(resp))
|
||||||
|
@ -11749,6 +11974,7 @@ Current version indicated by LITEVER below.
|
||||||
resp = " "+resp;
|
resp = " "+resp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return resp;
|
return resp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11911,9 +12137,9 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
else {
|
else
|
||||||
|
{
|
||||||
//apply custom logit bias for official OAI only
|
//apply custom logit bias for official OAI only
|
||||||
|
|
||||||
let needbaneos = (custom_oai_endpoint.toLowerCase().includes("api.openai.com") && determine_if_ban_eos(input_was_empty));
|
let needbaneos = (custom_oai_endpoint.toLowerCase().includes("api.openai.com") && determine_if_ban_eos(input_was_empty));
|
||||||
|
|
||||||
if(needbaneos)
|
if(needbaneos)
|
||||||
|
@ -11926,6 +12152,9 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
oai_payload.prompt = submit_payload.prompt;
|
oai_payload.prompt = submit_payload.prompt;
|
||||||
|
|
||||||
|
//lets try adding stop sequences, limit to first 4
|
||||||
|
oai_payload.stop = get_stop_sequences().slice(0, 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
last_request_str = JSON.stringify(oai_payload);
|
last_request_str = JSON.stringify(oai_payload);
|
||||||
|
@ -11945,49 +12174,15 @@ Current version indicated by LITEVER below.
|
||||||
oaiheaders["HTTP-Referer"] = "https://lite.koboldai.net";
|
oaiheaders["HTTP-Referer"] = "https://lite.koboldai.net";
|
||||||
}
|
}
|
||||||
|
|
||||||
fetch(targetep, {
|
if(is_browser_supports_sse() && document.getElementById("oaistreaming").checked)
|
||||||
method: 'POST',
|
|
||||||
headers: oaiheaders,
|
|
||||||
body: JSON.stringify(oai_payload),
|
|
||||||
referrerPolicy: 'no-referrer',
|
|
||||||
})
|
|
||||||
.then((response) => response.json())
|
|
||||||
.then((data) => {
|
|
||||||
console.log("sync finished response: " + JSON.stringify(data));
|
|
||||||
if (custom_oai_key != "" && data.choices != null && data.choices.length > 0) {
|
|
||||||
let dch = data.choices[0];
|
|
||||||
if (dch.text) {
|
|
||||||
synchro_polled_response = dch.text;
|
|
||||||
}
|
|
||||||
else if (dch.message) {
|
|
||||||
synchro_polled_response = dch.message.content;
|
|
||||||
|
|
||||||
if(localsettings.opmode==1 && gametext_arr.length>0 && synchro_polled_response!="")
|
|
||||||
{
|
{
|
||||||
synchro_polled_response = cleanup_story_completion(synchro_polled_response);
|
oai_payload.stream = true;
|
||||||
|
oai_api_stream_sse(targetep,oai_payload,oaiheaders);
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
oai_api_sync_req(targetep,oai_payload,oaiheaders);
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
console.error("Error, unknown OAI response");
|
|
||||||
clear_poll_flags();
|
|
||||||
render_gametext();
|
|
||||||
msgbox("Error, unknown OAI response");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
//error occurred, maybe captcha failed
|
|
||||||
console.error("error occurred in OAI generation");
|
|
||||||
clear_poll_flags();
|
|
||||||
render_gametext();
|
|
||||||
msgbox("Error occurred during text generation: " + formatError(data));
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
console.error('Error:', error);
|
|
||||||
clear_poll_flags();
|
|
||||||
render_gametext();
|
|
||||||
msgbox("Error while submitting prompt: " + error);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
else if (custom_claude_key != "")//handle for Claude
|
else if (custom_claude_key != "")//handle for Claude
|
||||||
{
|
{
|
||||||
|
@ -17069,7 +17264,7 @@ Current version indicated by LITEVER below.
|
||||||
<option value="gpt-4-turbo">gpt-4-turbo</option>
|
<option value="gpt-4-turbo">gpt-4-turbo</option>
|
||||||
<option value="gpt-4o">gpt-4o</option>
|
<option value="gpt-4o">gpt-4o</option>
|
||||||
<option value="gpt-4-32k">gpt-4-32k</option>
|
<option value="gpt-4-32k">gpt-4-32k</option>
|
||||||
<option style="display:none;" id="custom_oai_model_option" value="custom">[Custom]</option>
|
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
|
||||||
</select>
|
</select>
|
||||||
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_openrouter_model" onchange="oai_model_change(true)">
|
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_openrouter_model" onchange="oai_model_change(true)">
|
||||||
<option value="openai/gpt-3.5-turbo">openai/gpt-3.5-turbo</option>
|
<option value="openai/gpt-3.5-turbo">openai/gpt-3.5-turbo</option>
|
||||||
|
@ -17078,8 +17273,8 @@ Current version indicated by LITEVER below.
|
||||||
<option value="mistralai/mistral-7b-instruct" selected="selected">mistralai/mistral-7b-instruct</option>
|
<option value="mistralai/mistral-7b-instruct" selected="selected">mistralai/mistral-7b-instruct</option>
|
||||||
<option value="gryphe/mythomax-l2-13b">gryphe/mythomax-l2-13b</option>
|
<option value="gryphe/mythomax-l2-13b">gryphe/mythomax-l2-13b</option>
|
||||||
<option value="huggingfaceh4/zephyr-7b-beta">huggingfaceh4/zephyr-7b-beta</option>
|
<option value="huggingfaceh4/zephyr-7b-beta">huggingfaceh4/zephyr-7b-beta</option>
|
||||||
<option value="anthropic/claude-2">anthropic/claude-2</option>
|
<option value="anthropic/claude-2.0">anthropic/claude-2.0</option>
|
||||||
<option style="display:none;" id="custom_openrouter_model_option" value="custom">[Custom]</option>
|
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
|
||||||
</select>
|
</select>
|
||||||
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_mistralai_model" onchange="oai_model_change(true)">
|
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_mistralai_model" onchange="oai_model_change(true)">
|
||||||
<option value="open-mistral-7b">open-mistral-7b</option>
|
<option value="open-mistral-7b">open-mistral-7b</option>
|
||||||
|
@ -17089,15 +17284,18 @@ Current version indicated by LITEVER below.
|
||||||
<option value="open-mixtral-8x22b">open-mixtral-8x22b</option>
|
<option value="open-mixtral-8x22b">open-mixtral-8x22b</option>
|
||||||
<option value="mistral-medium">mistral-medium</option>
|
<option value="mistral-medium">mistral-medium</option>
|
||||||
<option value="mistral-large-latest">mistral-large-latest</option>
|
<option value="mistral-large-latest">mistral-large-latest</option>
|
||||||
<option style="display:none;" id="custom_mistralai_model_option" value="custom">[Custom]</option>
|
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
|
||||||
</select>
|
</select>
|
||||||
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="oaifetchlist" onclick="oai_fetch_models()">Fetch List</button>
|
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="oaifetchlist" onclick="oai_fetch_models()">Fetch List</button>
|
||||||
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="oaiusecustom" onclick="select_custom_oai_model()">Use Custom</button>
|
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="oaiusecustom" onclick="select_custom_oai_model()">Use Custom</button>
|
||||||
<input type="checkbox" id="oaiaddversion" onchange="" checked>
|
<div style="display:inline-flex">
|
||||||
<div class="box-label" title="Add endpoint version">Add Endpoint Version</div>
|
<div><input type="checkbox" id="oaiaddversion" title="Add Endpoint Version Number" onchange="" checked>
|
||||||
<input type="checkbox" id="useoaichatcompl" onchange="toggleoaichatcompl()">
|
<div class="box-label">Add Version Num</div></div>
|
||||||
<div class="box-label" id="useoaichatcompllabel">Use ChatCompletions API</div>
|
<div><input type="checkbox" id="oaistreaming" title="Enable SSE Streaming" onchange="">
|
||||||
|
<div class="box-label">Streaming</div></div>
|
||||||
|
<div><input type="checkbox" id="useoaichatcompl" title="Use ChatCompletions API" onchange="toggleoaichatcompl()">
|
||||||
|
<div class="box-label" id="useoaichatcompllabel">ChatCompletions API</div></div>
|
||||||
|
</div>
|
||||||
<span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();">
|
<span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();">
|
||||||
<br>
|
<br>
|
||||||
Main Message Role:
|
Main Message Role:
|
||||||
|
|
1334
src/llama.cpp
1334
src/llama.cpp
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue