Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	CMakeLists.txt
#	cmake/common.cmake
#	docs/backend/SYCL.md
#	examples/main/README.md
#	examples/speculative/speculative.cpp
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-musa/CMakeLists.txt
#	ggml/src/ggml-sycl/CMakeLists.txt
#	ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
#	tests/test-backend-ops.cpp

Commit 0c90d2ebcf: 58 changed files with 4222 additions and 1537 deletions.

@@ -1,519 +0,0 @@
#!/bin/bash
#
# Options
IOS_MIN_OS_VERSION=16.4
MACOS_MIN_OS_VERSION=13.3
VISIONOS_MIN_OS_VERSION=1.0
TVOS_MIN_OS_VERSION=16.4

BUILD_SHARED_LIBS=OFF
LLAMA_BUILD_EXAMPLES=OFF
LLAMA_BUILD_TESTS=OFF
LLAMA_BUILD_SERVER=OFF
GGML_METAL=ON
GGML_METAL_EMBED_LIBRARY=ON
GGML_BLAS_DEFAULT=ON
GGML_METAL_USE_BF16=ON
GGML_OPENMP=OFF

COMMON_C_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
COMMON_CXX_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"

# Common options for all builds
COMMON_CMAKE_ARGS=(
    -DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_REQUIRED=NO
    -DCMAKE_XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY=""
    -DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED=NO
    -DCMAKE_XCODE_ATTRIBUTE_DEBUG_INFORMATION_FORMAT="dwarf-with-dsym"
    -DCMAKE_XCODE_ATTRIBUTE_GCC_GENERATE_DEBUGGING_SYMBOLS=YES
    -DCMAKE_XCODE_ATTRIBUTE_COPY_PHASE_STRIP=NO
    -DCMAKE_XCODE_ATTRIBUTE_STRIP_INSTALLED_PRODUCT=NO
    -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
    -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS}
    -DLLAMA_BUILD_EXAMPLES=${LLAMA_BUILD_EXAMPLES}
    -DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS}
    -DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER}
    -DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY}
    -DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT}
    -DGGML_METAL=${GGML_METAL}
    -DGGML_METAL_USE_BF16=${GGML_METAL_USE_BF16}
    -DGGML_NATIVE=OFF
    -DGGML_OPENMP=${GGML_OPENMP}
)

check_required_tool() {
    local tool=$1
    local install_message=$2

    if ! command -v $tool &> /dev/null; then
        echo "Error: $tool is required but not found."
        echo "$install_message"
        exit 1
    fi
}
echo "Checking for required tools..."
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"

set -e

## Clean up previous builds
rm -rf build-apple
rm -rf build-ios-sim
rm -rf build-ios-device
rm -rf build-macos
rm -rf build-visionos
rm -rf build-visionos-sim
rm -rf build-tvos-sim
rm -rf build-tvos-device
# Setup the xcframework build directory structure
setup_framework_structure() {
    local build_dir=$1
    local min_os_version=$2
    local platform=$3  # "ios", "macos", "visionos", or "tvos"
    local framework_name="llama"

    echo "Creating ${platform}-style framework structure for ${build_dir}"

    if [[ "$platform" == "macos" ]]; then
        # macOS versioned structure uses versioned directories
        mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Headers
        mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Modules
        mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Resources

        # Create symbolic links
        ln -sf A ${build_dir}/framework/${framework_name}.framework/Versions/Current
        ln -sf Versions/Current/Headers ${build_dir}/framework/${framework_name}.framework/Headers
        ln -sf Versions/Current/Modules ${build_dir}/framework/${framework_name}.framework/Modules
        ln -sf Versions/Current/Resources ${build_dir}/framework/${framework_name}.framework/Resources
        ln -sf Versions/Current/${framework_name} ${build_dir}/framework/${framework_name}.framework/${framework_name}

        # Set header and module paths
        local header_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Headers/
        local module_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Modules/
    else
        # iOS/VisionOS/tvOS use a flat structure
        mkdir -p ${build_dir}/framework/${framework_name}.framework/Headers
        mkdir -p ${build_dir}/framework/${framework_name}.framework/Modules

        # Remove any existing structure to ensure clean build
        rm -rf ${build_dir}/framework/${framework_name}.framework/Versions

        # Set header and module paths
        local header_path=${build_dir}/framework/${framework_name}.framework/Headers/
        local module_path=${build_dir}/framework/${framework_name}.framework/Modules/
    fi

    # Copy all required headers (common for all platforms)
    cp include/llama.h             ${header_path}
    cp ggml/include/ggml.h         ${header_path}
    cp ggml/include/ggml-alloc.h   ${header_path}
    cp ggml/include/ggml-backend.h ${header_path}
    cp ggml/include/ggml-metal.h   ${header_path}
    cp ggml/include/ggml-cpu.h     ${header_path}
    cp ggml/include/ggml-blas.h    ${header_path}
    cp ggml/include/gguf.h         ${header_path}

    # Create module map (common for all platforms)
    cat > ${module_path}module.modulemap << EOF
framework module llama {
    header "llama.h"
    header "ggml.h"
    header "ggml-alloc.h"
    header "ggml-backend.h"
    header "ggml-metal.h"
    header "ggml-cpu.h"
    header "ggml-blas.h"
    header "gguf.h"

    link "c++"
    link framework "Accelerate"
    link framework "Metal"
    link framework "Foundation"

    export *
}
EOF
    # Platform-specific settings for Info.plist
    local platform_name=""
    local sdk_name=""
    local supported_platform=""

    case "$platform" in
        "ios")
            platform_name="iphoneos"
            sdk_name="iphoneos${min_os_version}"
            supported_platform="iPhoneOS"
            local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
            local device_family='    <key>UIDeviceFamily</key>
    <array>
        <integer>1</integer>
        <integer>2</integer>
    </array>'
            ;;
        "macos")
            platform_name="macosx"
            sdk_name="macosx${min_os_version}"
            supported_platform="MacOSX"
            local plist_path="${build_dir}/framework/${framework_name}.framework/Versions/A/Resources/Info.plist"
            local device_family=""
            ;;
        "visionos")
            platform_name="xros"
            sdk_name="xros${min_os_version}"
            supported_platform="XRPlatform"
            local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
            local device_family=""
            ;;
        "tvos")
            platform_name="appletvos"
            sdk_name="appletvos${min_os_version}"
            supported_platform="AppleTVOS"
            local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
            local device_family='    <key>UIDeviceFamily</key>
    <array>
        <integer>3</integer>
    </array>'
            ;;
    esac

    # Create Info.plist
    cat > ${plist_path} << EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>CFBundleDevelopmentRegion</key>
    <string>en</string>
    <key>CFBundleExecutable</key>
    <string>llama</string>
    <key>CFBundleIdentifier</key>
    <string>org.ggml.llama</string>
    <key>CFBundleInfoDictionaryVersion</key>
    <string>6.0</string>
    <key>CFBundleName</key>
    <string>llama</string>
    <key>CFBundlePackageType</key>
    <string>FMWK</string>
    <key>CFBundleShortVersionString</key>
    <string>1.0</string>
    <key>CFBundleVersion</key>
    <string>1</string>
    <key>MinimumOSVersion</key>
    <string>${min_os_version}</string>
    <key>CFBundleSupportedPlatforms</key>
    <array>
        <string>${supported_platform}</string>
    </array>${device_family}
    <key>DTPlatformName</key>
    <string>${platform_name}</string>
    <key>DTSDKName</key>
    <string>${sdk_name}</string>
</dict>
</plist>
EOF
}
# Create dynamic libraries from static libraries.
combine_static_libraries() {
    local build_dir="$1"
    local release_dir="$2"
    local platform="$3"  # "ios", "macos", "visionos", or "tvos"
    local is_simulator="$4"
    local base_dir="$(pwd)"
    local framework_name="llama"

    # Determine output path based on platform
    local output_lib=""
    if [[ "$platform" == "macos" ]]; then
        # macOS uses versioned structure
        output_lib="${build_dir}/framework/${framework_name}.framework/Versions/A/${framework_name}"
    else
        # iOS, visionOS, and tvOS use a flat directory structure
        output_lib="${build_dir}/framework/${framework_name}.framework/${framework_name}"
    fi

    local libs=(
        "${base_dir}/${build_dir}/src/${release_dir}/libllama.a"
        "${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml.a"
        "${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-base.a"
        "${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a"
        "${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
        "${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
    )

    # Create temporary directory for processing
    local temp_dir="${base_dir}/${build_dir}/temp"
    mkdir -p "${temp_dir}"

    # Since we have multiple architectures libtool will find object files that do not
    # match the target architecture. We suppress these warnings.
    libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
    # Determine SDK, architectures, and install_name based on platform and simulator flag.
    local sdk=""
    local archs=""
    local min_version_flag=""
    local install_name=""

    case "$platform" in
        "ios")
            if [[ "$is_simulator" == "true" ]]; then
                sdk="iphonesimulator"
                archs="arm64 x86_64"
                min_version_flag="-mios-simulator-version-min=${IOS_MIN_OS_VERSION}"
            else
                sdk="iphoneos"
                archs="arm64"
                min_version_flag="-mios-version-min=${IOS_MIN_OS_VERSION}"
            fi
            install_name="@rpath/llama.framework/llama"
            ;;
        "macos")
            sdk="macosx"
            archs="arm64 x86_64"
            min_version_flag="-mmacosx-version-min=${MACOS_MIN_OS_VERSION}"
            install_name="@rpath/llama.framework/Versions/Current/llama"
            ;;
        "visionos")
            if [[ "$is_simulator" == "true" ]]; then
                sdk="xrsimulator"
                archs="arm64 x86_64"
                min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}-simulator"
            else
                sdk="xros"
                archs="arm64"
                min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}"
            fi
            # Use flat structure for visionOS, same as iOS
            install_name="@rpath/llama.framework/llama"
            ;;
        "tvos")
            if [[ "$is_simulator" == "true" ]]; then
                sdk="appletvsimulator"
                archs="arm64 x86_64"
                min_version_flag="-mtvos-simulator-version-min=${TVOS_MIN_OS_VERSION}"
            else
                sdk="appletvos"
                archs="arm64"
                min_version_flag="-mtvos-version-min=${TVOS_MIN_OS_VERSION}"
            fi
            install_name="@rpath/llama.framework/llama"
            ;;
    esac
    # Build architecture flags
    local arch_flags=""
    for arch in $archs; do
        arch_flags+=" -arch $arch"
    done

    # Create dynamic library
    echo "Creating dynamic library for ${platform}."
    xcrun -sdk $sdk clang++ -dynamiclib \
        -isysroot $(xcrun --sdk $sdk --show-sdk-path) \
        $arch_flags \
        $min_version_flag \
        -Wl,-force_load,"${temp_dir}/combined.a" \
        -framework Foundation -framework Metal -framework Accelerate \
        -install_name "$install_name" \
        -o "${base_dir}/${output_lib}"

    # Platform-specific post-processing for device builds
    if [[ "$is_simulator" == "false" ]]; then
        if command -v vtool &>/dev/null; then
            case "$platform" in
                "ios")
                    echo "Marking binary as a framework binary for iOS..."
                    vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
                        -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                    ;;
                "visionos")
                    echo "Marking binary as a framework binary for visionOS..."
                    vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
                        -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                    ;;
                "tvos")
                    echo "Marking binary as a framework binary for tvOS..."
                    vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
                        -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                    ;;
            esac
        else
            echo "Warning: vtool not found. Binary may not pass App Store validation."
        fi
    fi
echo "Creating properly formatted dSYM..."
|
|
||||||
# Create a separate directory for dSYMs for all platforms
|
|
||||||
mkdir -p "${base_dir}/${build_dir}/dSYMs"
|
|
||||||
|
|
||||||
# iOS and visionOS style dSYM (flat structure)
|
|
||||||
if [[ "$platform" == "ios" || "$platform" == "visionos" || "$platform" == "tvos" ]]; then
|
|
||||||
# Generate dSYM in the dSYMs directory
|
|
||||||
xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
|
|
||||||
|
|
||||||
# Create a copy of the binary that will be stripped
|
|
||||||
cp "${base_dir}/${output_lib}" "${temp_dir}/binary_to_strip"
|
|
||||||
|
|
||||||
# Strip debug symbols from the copy
|
|
||||||
xcrun strip -S "${temp_dir}/binary_to_strip" -o "${temp_dir}/stripped_lib"
|
|
||||||
|
|
||||||
# Replace the original with the stripped version
|
|
||||||
mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
|
|
||||||
else
|
|
||||||
# macOS style dSYM
|
|
||||||
# First strip debug info to a separate file
|
|
||||||
xcrun strip -S "${base_dir}/${output_lib}" -o "${temp_dir}/stripped_lib"
|
|
||||||
|
|
||||||
# Generate dSYM in the dSYMs directory
|
|
||||||
xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
|
|
||||||
|
|
||||||
# Replace original binary with stripped version
|
|
||||||
mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Remove any automatically generated dSYM files in the framework structure as they will
|
|
||||||
# otherwise case Invalid Bundle Structure validation errors.
|
|
||||||
if [ -d "${base_dir}/${output_lib}.dSYM" ]; then
|
|
||||||
echo "Removing generated dSYM file in framework structure: ${base_dir}/${output_lib}.dSYM"
|
|
||||||
rm -rf "${base_dir}/${output_lib}.dSYM"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
rm -rf "${temp_dir}"
|
|
||||||
}
|
|
||||||
|
|
||||||
echo "Building for iOS simulator..."
|
|
||||||
cmake -B build-ios-sim -G Xcode \
|
|
||||||
"${COMMON_CMAKE_ARGS[@]}" \
|
|
||||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
|
|
||||||
-DIOS=ON \
|
|
||||||
-DCMAKE_SYSTEM_NAME=iOS \
|
|
||||||
-DCMAKE_OSX_SYSROOT=iphonesimulator \
|
|
||||||
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
|
|
||||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphonesimulator \
|
|
||||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
|
||||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
|
||||||
-S .
|
|
||||||
cmake --build build-ios-sim --config Release -- -quiet
|
|
||||||
|
|
||||||
echo "Building for iOS devices..."
|
|
||||||
cmake -B build-ios-device -G Xcode \
|
|
||||||
"${COMMON_CMAKE_ARGS[@]}" \
|
|
||||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
|
|
||||||
-DCMAKE_OSX_SYSROOT=iphoneos \
|
|
||||||
-DCMAKE_OSX_ARCHITECTURES="arm64" \
|
|
||||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \
|
|
||||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
|
||||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
|
||||||
-S .
|
|
||||||
cmake --build build-ios-device --config Release -- -quiet
|
|
||||||
|
|
||||||
echo "Building for macOS..."
|
|
||||||
cmake -B build-macos -G Xcode \
|
|
||||||
"${COMMON_CMAKE_ARGS[@]}" \
|
|
||||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=${MACOS_MIN_OS_VERSION} \
|
|
||||||
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
|
|
||||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
|
||||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
|
||||||
-S .
|
|
||||||
cmake --build build-macos --config Release -- -quiet
|
|
||||||
|
|
||||||
echo "Building for visionOS..."
|
|
||||||
cmake -B build-visionos -G Xcode \
|
|
||||||
"${COMMON_CMAKE_ARGS[@]}" \
|
|
||||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
|
|
||||||
-DCMAKE_OSX_ARCHITECTURES="arm64" \
|
|
||||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
|
||||||
-DCMAKE_OSX_SYSROOT=xros \
|
|
||||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
|
|
||||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
|
|
||||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
|
|
||||||
-S .
|
|
||||||
cmake --build build-visionos --config Release -- -quiet
|
|
||||||
|
|
||||||
echo "Building for visionOS simulator..."
|
|
||||||
cmake -B build-visionos-sim -G Xcode \
|
|
||||||
"${COMMON_CMAKE_ARGS[@]}" \
|
|
||||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
|
|
||||||
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
|
|
||||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
|
||||||
-DCMAKE_OSX_SYSROOT=xrsimulator \
|
|
||||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
|
|
||||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
|
|
||||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
|
|
||||||
-S .
|
|
||||||
cmake --build build-visionos-sim --config Release -- -quiet
|
|
||||||
|
|
||||||
# Add tvOS builds (might need the same u_int definitions as watchOS and visionOS)
echo "Building for tvOS simulator..."
cmake -B build-tvos-sim -G Xcode \
    "${COMMON_CMAKE_ARGS[@]}" \
    -DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
    -DCMAKE_SYSTEM_NAME=tvOS \
    -DCMAKE_OSX_SYSROOT=appletvsimulator \
    -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
    -DGGML_METAL=ON \
    -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvsimulator \
    -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
    -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
    -S .
cmake --build build-tvos-sim --config Release -- -quiet

echo "Building for tvOS devices..."
cmake -B build-tvos-device -G Xcode \
    "${COMMON_CMAKE_ARGS[@]}" \
    -DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
    -DCMAKE_SYSTEM_NAME=tvOS \
    -DCMAKE_OSX_SYSROOT=appletvos \
    -DCMAKE_OSX_ARCHITECTURES="arm64" \
    -DGGML_METAL=ON \
    -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvos \
    -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
    -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
    -S .
cmake --build build-tvos-device --config Release -- -quiet
# Setup frameworks and copy binaries and headers
echo "Setting up framework structures..."
setup_framework_structure "build-ios-sim" ${IOS_MIN_OS_VERSION} "ios"
setup_framework_structure "build-ios-device" ${IOS_MIN_OS_VERSION} "ios"
setup_framework_structure "build-macos" ${MACOS_MIN_OS_VERSION} "macos"
setup_framework_structure "build-visionos" ${VISIONOS_MIN_OS_VERSION} "visionos"
setup_framework_structure "build-visionos-sim" ${VISIONOS_MIN_OS_VERSION} "visionos"
setup_framework_structure "build-tvos-sim" ${TVOS_MIN_OS_VERSION} "tvos"
setup_framework_structure "build-tvos-device" ${TVOS_MIN_OS_VERSION} "tvos"

# Create dynamic libraries from static libraries
echo "Creating dynamic libraries from static libraries..."
combine_static_libraries "build-ios-sim" "Release-iphonesimulator" "ios" "true"
combine_static_libraries "build-ios-device" "Release-iphoneos" "ios" "false"
combine_static_libraries "build-macos" "Release" "macos" "false"
combine_static_libraries "build-visionos" "Release-xros" "visionos" "false"
combine_static_libraries "build-visionos-sim" "Release-xrsimulator" "visionos" "true"
combine_static_libraries "build-tvos-sim" "Release-appletvsimulator" "tvos" "true"
combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
# Create XCFramework with correct debug symbols paths
echo "Creating XCFramework..."
xcodebuild -create-xcframework \
    -framework $(pwd)/build-ios-sim/framework/llama.framework \
    -debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
    -framework $(pwd)/build-ios-device/framework/llama.framework \
    -debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
    -framework $(pwd)/build-macos/framework/llama.framework \
    -debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
    -framework $(pwd)/build-visionos/framework/llama.framework \
    -debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
    -framework $(pwd)/build-visionos-sim/framework/llama.framework \
    -debug-symbols $(pwd)/build-visionos-sim/dSYMs/llama.dSYM \
    -framework $(pwd)/build-tvos-device/framework/llama.framework \
    -debug-symbols $(pwd)/build-tvos-device/dSYMs/llama.dSYM \
    -framework $(pwd)/build-tvos-sim/framework/llama.framework \
    -debug-symbols $(pwd)/build-tvos-sim/dSYMs/llama.dSYM \
    -output $(pwd)/build-apple/llama.xcframework
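Note (ours, not part of the diff): the deleted script above produced per-platform llama.framework bundles, stripped them, kept the debug info in separate dSYM bundles, and merged everything into a single build-apple/llama.xcframework with xcodebuild -create-xcframework; the module map it wrote exposed llama.h and the ggml headers as one framework module named llama, so an app could embed the xcframework directly in Xcode or reference it as a binary target from a package manifest.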
@ -180,7 +180,8 @@ class Model:
|
||||||
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
|
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
|
||||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||||
if len(extra) == 0 and len(missing_files) > 0:
|
if len(extra) == 0 and len(missing_files) > 0:
|
||||||
raise ValueError(f"Missing or incomplete model files: {missing_files}")
|
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||||
|
f"Missing tensors: {missing}")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||||
f"Missing tensors: {missing}\n"
|
f"Missing tensors: {missing}\n"
|
||||||
|
@@ -908,6 +909,40 @@ class Model:
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
@@ -1065,13 +1100,6 @@ class BloomModel(Model):
 
         tensors.append((self.map_tensor_name(name), data_torch))
 
-        if name == "word_embeddings.weight":
-            assert self.tensor_names is not None
-
-            # TODO: tie them at runtime, don't duplicate in the model file
-            if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
-
         return tensors
@@ -1713,6 +1741,25 @@ class LlamaModel(Model):
                     raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("Mistral3ForConditionalGeneration")
+class Mistral3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = Model.load_hparams(kwargs["dir_model"])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("language_model.", "")
+        if "multi_modal_projector" in name or "vision_tower" in name:
+            return []
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("DeciLMForCausalLM")
 class DeciModel(Model):
     model_arch = gguf.MODEL_ARCH.DECI
@@ -2370,10 +2417,6 @@ class GPT2Model(Model):
 
             tensors.append((new_name, data_torch))
 
-            # note: GPT2 output is tied to (same as) wte in original model
-            if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
-
         return tensors
@@ -2703,21 +2746,26 @@ class CodeShellModel(Model):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
+    _has_tok_embd = False
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
+        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
+        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
+
         new_name = self.map_tensor_name(name)
 
-        tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]
-
-        if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
-            assert self.tensor_names is not None
-
-            if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
-                # copy tok_embd.weight to output.weight
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
-
-        return tensors
+        # assuming token_embd.weight is seen before output.weight
+        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
+            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
+            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
+                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
+                self.tensor_names.remove("transformer.wte.weight")
+        elif new_name == tok_embd_name:
+            self._has_tok_embd = True
+
+        return [(new_name, data_torch)]
 
 
 @Model.register("InternLM2ForCausalLM")
@@ -3412,38 +3460,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6
 
     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()
 
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3582,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
             yield (new_name, data)
 
 
+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
+        name = name.replace("_lora.lora.0.weight", "1.weight")
+        name = name.replace("_lora.lora.2.weight", "2.weight")
+        name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easist way to make llama happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
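Worked example (ours, not from the diff): calc_lora_rank above rounds hidden_size^exponent * multiplier to a multiple of 32, i.e. rank = max(1, round(h^e * m / 32)) * 32. For a hypothetical hidden_size of 2048 with exponent 0.5 and multiplier 1.8 this gives round(sqrt(2048) * 1.8 / 32) * 32 = round(2.55) * 32 = 96, so the decay and ICLR LoRA ranks default to 96 when the checkpoint does not specify them.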
@@ -1872,6 +1872,10 @@ struct server_context {
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
 
+            // force F16 KV cache for the draft model for extra performance
+            params_dft.cache_type_k = GGML_TYPE_F16;
+            params_dft.cache_type_v = GGML_TYPE_F16;
+
             llama_init_dft = common_init_from_params(params_dft);
 
             model_dft = llama_init_dft.model.get();
@@ -1892,10 +1896,6 @@ struct server_context {
             cparams_dft = common_context_params_to_llama(params_dft);
             cparams_dft.n_batch = n_ctx_dft;
 
-            // force F16 KV cache for the draft model for extra performance
-            cparams_dft.type_k = GGML_TYPE_F16;
-            cparams_dft.type_v = GGML_TYPE_F16;
-
             // the context is not needed - we will create one for each slot
             llama_init_dft.context.reset();
         }
ggml/cmake/common.cmake (new file, 26 lines)
@@ -0,0 +1,26 @@
+function(ggml_get_flags CCID CCVER)
+    set(C_FLAGS "")
+    set(CXX_FLAGS "")
+
+    if (CCID MATCHES "Clang")
+        set(C_FLAGS   -Wunreachable-code-break -Wunreachable-code-return)
+        set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
+
+        if (
+            (CCID STREQUAL "Clang"      AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
+            (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
+        )
+            list(APPEND C_FLAGS -Wdouble-promotion)
+        endif()
+    elseif (CCID STREQUAL "GNU")
+        set(C_FLAGS   -Wdouble-promotion)
+        set(CXX_FLAGS -Wno-array-bounds)
+
+        if (CCVER VERSION_GREATER_EQUAL 8.1.0)
+            list(APPEND CXX_FLAGS -Wextra-semi)
+        endif()
+    endif()
+
+    set(GF_C_FLAGS   ${C_FLAGS}   PARENT_SCOPE)
+    set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
+endfunction()
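Usage note (our assumption, not shown in this hunk): the helper is intended to be invoked from the ggml CMake lists as ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}), after which the GF_C_FLAGS and GF_CXX_FLAGS variables it sets in the parent scope are appended to the per-target warning options.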
@@ -460,6 +460,7 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
+        GGML_OP_L2_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -508,6 +509,7 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_RWKV_WKV7,
 
         GGML_OP_UNARY,
@@ -1108,6 +1110,18 @@ extern "C" {
             int                   n_groups,
             float                 eps);
 
+    // l2 normalize along rows
+    // used in rwkv v7
+    GGML_API struct ggml_tensor * ggml_l2_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
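A minimal sketch (ours, not part of the diff) of exercising the new ggml_l2_norm through the public ggml API; the context size, the 8-element input, the eps value, and running the graph through ggml_graph_compute_with_ctx from ggml-cpu.h are all our own choices rather than anything this commit prescribes.

    #include <stdio.h>
    #include "ggml.h"
    #include "ggml-cpu.h"

    int main(void) {
        // small CPU-only context; no_alloc = false so tensor data lives in this buffer
        struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        // one row of 8 floats: 1, 2, ..., 8
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        for (int i = 0; i < 8; ++i) {
            ((float *) x->data)[i] = (float) (i + 1);
        }

        // y = x / max(||x||_2, eps), normalized along the row
        struct ggml_tensor * y = ggml_l2_norm(ctx, x, 1e-12f);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        printf("y[0] = %f\n", ((float *) y->data)[0]); // ~0.070014, since ||x|| = sqrt(204)
        ggml_free(ctx);
        return 0;
    }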
@@ -1903,6 +1917,16 @@ extern "C" {
             struct ggml_tensor  * state,
             float                 scale);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * w,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
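Shape note (ours, inferred from the CPU kernel added further down in this commit): r, w, k, v, a and b are the per-token RWKV-7 projections and state carries the recurrent WKV state per sequence; in that kernel the output tensor packs C = dst->ne[0] values per token for T = src[1]->ne[2] tokens, followed by the updated state, with the head count taken from src[1]->ne[1] and the number of sequences from src[6]->ne[1].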
@@ -8159,7 +8159,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
@ -8578,6 +8578,69 @@ static void ggml_compute_forward_group_norm(
}
}

// ggml_compute_forward_l2_norm

static void ggml_compute_forward_l2_norm_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(ggml_are_same_shape(src0, dst));

GGML_ASSERT(src0->nb[0] == sizeof(float));

const int ith = params->ith;
const int nth = params->nth;

GGML_TENSOR_UNARY_OP_LOCALS

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps >= 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}

float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

memcpy(y, x, ne00 * sizeof(float));

const float scale = 1.0f/fmaxf(sqrtf(sum), eps);

ggml_vec_scale_f32(ne00, y, scale);
}
}
}
}

static void ggml_compute_forward_l2_norm(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_l2_norm_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_mul_mat

static void ggml_compute_forward_mul_mat_one_chunk(
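For reference, the new CPU path above normalizes each row by its Euclidean norm, clamped from below by eps. A minimal standalone sketch of the same math (plain C++, illustrative helper name, not part of the patch):

// Sketch of the row-wise L2 normalization the new op performs
// (assumes a contiguous float row; names are illustrative, not from ggml).
#include <cmath>
#include <algorithm>

static void l2_normalize_row(float * y, const float * x, int n, float eps) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) x[i] * x[i];          // sum of squares
    }
    const float scale = 1.0f / std::max((float) std::sqrt(sum), eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale;                  // y = x / max(||x||, eps)
    }
}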
@ -13643,6 +13706,184 @@ static void ggml_compute_forward_gla(
}
}

// ggml_compute_forward_rwkv_wkv7

static void ggml_compute_forward_rwkv_wkv7_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t n_seqs = dst->src[6]->ne[1];
const int64_t head_size = C / HEADS;

float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;

const int ith = params->ith;
const int nth = params->nth;

if (ith >= HEADS) {
return;
}

const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;

float * r = (float *) dst->src[0]->data;
float * w = (float *) dst->src[1]->data;
float * k = (float *) dst->src[2]->data;
float * v = (float *) dst->src[3]->data;
float * a = (float *) dst->src[4]->data;
float * b = (float *) dst->src[5]->data;

int64_t t_stride = HEADS * head_size; // Same to C

int64_t h_stride = C / HEADS;
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
int64_t h_stride_2d = head_size * head_size;

#if defined(GGML_SIMD)
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

for (int64_t h = h_start; h < h_end; h++) {
int64_t h_offset = h * h_stride;
int64_t t_h_offset = t_offset + h_offset;
int64_t h_2d_offset = h * h_stride_2d;

for (int64_t ii = 0; ii < head_size; ii++) {
int64_t t_h_i_offset = t_h_offset + ii;
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;

GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);

float sa = 0;
{
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
}
}
GGML_F32_VEC_REDUCE(sa, sum);
}

GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

int64_t j = 0;
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
for (; j < head_size; j += GGML_F32_STEP) {
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
// kv + s * decay + sa * b
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
}
}
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);

// There shouldn't be left-overs though.
for (; j < head_size; j++) {
int64_t t_h_j_offset = t_h_offset + j;
int64_t h_2d_i_j_offset = h_2d_i_offset + j;

float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v[t_h_i_offset] * k_val;

float prev_state_val = state_prev[h_2d_i_j_offset];
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
}
}
}
}
#else
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

for (int64_t h = h_start; h < h_end; h++) {
int64_t h_offset = h * h_stride;
int64_t t_h_offset = t_offset + h_offset;
int64_t h_2d_offset = h * h_stride_2d;

for (int64_t i = 0; i < head_size; i++) {
int64_t t_h_i_offset = t_h_offset + i;
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

float v_val = v[t_h_i_offset];

float sa = 0, result = 0;
for (int64_t j = 0; j < head_size; j++) {
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
}

for (int64_t j = 0; j < head_size; j++) {
int64_t t_h_j_offset = t_h_offset + j;
int64_t h_2d_i_j_offset = h_2d_i_offset + j;

float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
result += state_cur[h_2d_i_j_offset] * r_val;
}
dst_data[t_h_i_offset] = result;
}
}
}
#endif
}

static void ggml_compute_forward_rwkv_wkv7(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_rwkv_wkv7_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_map_unary

static void ggml_compute_forward_map_unary_f32(
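The scalar path above makes the per-head recurrence explicit: for each output element the state row is decayed by w, receives the outer-product update k*v plus the a/b correction term, and is read out through r. A compact standalone restatement of that inner update (plain C++, illustrative names, not part of the patch):

// One head, one token: state is head_size x head_size, row i updated as
//   state[i][j] = state[i][j]*w[j] + k[j]*v[i] + (a . state[i]) * b[j]
//   out[i]      = sum_j state[i][j]*r[j]
#include <vector>
#include <cstddef>

static void wkv7_step(std::vector<float> & state, std::vector<float> & out,
                      const float * r, const float * w, const float * k,
                      const float * v, const float * a, const float * b,
                      size_t head_size) {
    for (size_t i = 0; i < head_size; i++) {
        float sa = 0.0f;
        for (size_t j = 0; j < head_size; j++) {
            sa += a[j] * state[i*head_size + j];      // a . state[i]
        }
        float y = 0.0f;
        for (size_t j = 0; j < head_size; j++) {
            float & s = state[i*head_size + j];
            s = s * w[j] + k[j] * v[i] + sa * b[j];   // decay + kv + correction
            y += s * r[j];                            // readout through r
        }
        out[i] = y;
    }
}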
@ -14209,6 +14450,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_group_norm(params, tensor);
} break;
case GGML_OP_L2_NORM:
{
ggml_compute_forward_l2_norm(params, tensor);
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor);

@ -14396,6 +14641,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_gla(params, tensor);
} break;
case GGML_OP_RWKV_WKV7:
{
ggml_compute_forward_rwkv_wkv7(params, tensor);
} break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;

@ -14621,6 +14870,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:

@ -14687,14 +14937,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
+case GGML_OP_RWKV_WKV6:
+case GGML_OP_GATED_LINEAR_ATTN:
+case GGML_OP_RWKV_WKV7:
{
n_tasks = n_threads;
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
-case GGML_OP_RWKV_WKV6:
-case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
};

-#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#define USE_CUDA_GRAPH
#endif

@ -38,7 +38,7 @@ bool g_mul_mat_q = true;
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml.h"
@ -265,6 +265,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA)
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
info.devices[id].warp_size = 32;
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;

@ -2201,6 +2203,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cuda_op_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_cuda_op_concat(ctx, dst);
break;

@ -2309,6 +2314,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -2615,13 +2623,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,

static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {

+#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
-#ifdef __HIP_PLATFORM_AMD__
-hipGraphNode_t errorNode;
-hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
-#else
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-#endif
+#else
+cudaGraphNode_t errorNode;
+cudaGraphExecUpdateResult result_info;
+cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000

if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
@ -3164,6 +3174,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;

@ -3218,6 +3229,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
}
}

// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
//     const int tid = threadIdx.x;

//     float tmp = 0.0f; // partial sum for thread in warp

//     for (int col = tid; col < ncols; col += block_size) {
//         const float xi = x[row*ncols + col];
//         tmp += xi * xi;
//     }

//     // sum up partial sums
//     tmp = warp_reduce_sum(tmp);
//     if (block_size > WARP_SIZE) {
//         __shared__ float s_sum[32];
//         int warp_id = threadIdx.x / WARP_SIZE;
//         int lane_id = threadIdx.x % WARP_SIZE;
//         if (lane_id == 0) {
//             s_sum[warp_id] = tmp;
//         }
//         __syncthreads();
//         tmp = s_sum[lane_id];
//         tmp = warp_reduce_sum(tmp);
//     }

//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));

//     for (int col = tid; col < ncols; col += block_size) {
//         dst[row*ncols + col] = scale * x[row*ncols + col];
//     }
// }

template <int block_size>
static __global__ void l2_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;

const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;

x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}

// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}

// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));

for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}

static void norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {

@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
}
}

static void l2_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;

@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d

rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_TENSOR_UNARY_OP_LOCALS;

float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);

const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;

l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
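A small aside on the launch heuristic in l2_norm_f32_cuda above: rows narrower than 1024 elements get a single warp per row, wider rows a full 1024-thread block with a shared-memory reduction. A hedged host-side restatement of that choice (plain C++, illustrative only; WARP_SIZE assumed to be 32 here):

#include <cstdio>

// Mirrors the block-size pick in l2_norm_f32_cuda (assumption: warp size 32).
static int pick_block_size(int ncols, int warp_size = 32) {
    // narrow rows: one warp is enough; wide rows: use the maximum block size
    return ncols < 1024 ? warp_size : 1024;
}

int main() {
    printf("ncols=256  -> block=%d\n", pick_block_size(256));   // 32
    printf("ncols=4096 -> block=%d\n", pick_block_size(4096));  // 1024
    return 0;
}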
@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/vendors/hip.h (vendored)
@ -112,7 +112,7 @@
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate

ggml/src/ggml-cuda/vendors/musa.h (vendored)
@ -119,7 +119,7 @@
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams

@ -132,6 +132,7 @@
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
#define cudaStreamEndCapture musaStreamEndCapture

typedef mt_bfloat16 nv_bfloat16;
ggml/src/ggml-cuda/wkv.cu (new file)
@ -0,0 +1,199 @@
#include "common.cuh"
#include "wkv.cuh"

template <int block_size>
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;

const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;

float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];

#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}

__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();

for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();

const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;

kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;

y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);

s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}

#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}

template <int block_size>
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;

const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;

float state[head_size];
__shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];

#ifndef GGML_USE_MUSA
#pragma unroll
#endif
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
}

for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
__syncthreads();

float sa = 0;
#pragma unroll
for (int j = 0; j < head_size; j += 4)
{
const float4& a = (float4&)(_a[j]);
const float4& s = (float4&)(state[j]);
sa += a.x * s.x;
sa += a.y * s.y;
sa += a.z * s.z;
sa += a.w * s.w;
}

const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& r = (float4&)(_r[j]);
const float4& w = (float4&)(_w[j]);
const float4& k = (float4&)(_k[j]);
const float4& b = (float4&)(_b[j]);
float4& s = (float4&)(state[j]);
float4 kv;

kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;

s.x = s.x * w.x + kv.x + sa * b.x;
s.y = s.y * w.y + kv.y + sa * b.y;
s.z = s.z * w.z + kv.z + sa * b.z;
s.w = s.w * w.w + kv.w + sa * b.w;

y += s.x * r.x;
y += s.y * r.y;
y += s.z * r.z;
y += s.w * r.w;
}
dst[t] = y;
}

#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
}
}

void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;

const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];

float * dst_d = (float *)dst->data;

cudaStream_t stream = ctx.stream();

GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);

if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
} else {
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}
}

void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * r_d = (const float *)dst->src[0]->data;
const float * w_d = (const float *)dst->src[1]->data;
const float * k_d = (const float *)dst->src[2]->data;
const float * v_d = (const float *)dst->src[3]->data;
const float * a_d = (const float *)dst->src[4]->data;
const float * b_d = (const float *)dst->src[5]->data;
const float * s_d = (const float *)dst->src[6]->data;

const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];

float * dst_d = (float *)dst->data;

cudaStream_t stream = ctx.stream();

GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);

if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
}
}
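Worth noting for the new wkv.cu: both launchers place one thread block per (sequence, head) pair and one thread per state column, so the grid is B*H blocks of C/H threads. A hedged host-side sketch of that mapping (plain C++, example shapes are illustrative, not taken from the patch):

#include <cstdint>
#include <cstdio>

int main() {
    // Example shapes: 2 sequences, 32 heads, head size 64 (so C = 2048).
    const int64_t B = 2, H = 32, head_size = 64;
    const int64_t C = H * head_size;

    const int64_t grid_blocks       = B * H;   // one block per (batch, head)
    const int64_t threads_per_block = C / H;   // one thread per state column

    printf("grid=%lld blocks, block=%lld threads\n",
           (long long) grid_blocks, (long long) threads_per_block);
    return 0;
}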
@ -3,3 +3,5 @@
#define CUDA_WKV_BLOCK_SIZE 64

void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@ -1,89 +0,0 @@
#include "common.cuh"
#include "wkv6.cuh"

static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;

const int head_size = CUDA_WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;

float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];

#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}

__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();

for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();

const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;

kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;

y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);

s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}

#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}

void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;

const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];

float * dst_d = (float *)dst->data;

cudaStream_t stream = ctx.stream();

GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64

rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}
@ -285,6 +285,13 @@ typedef struct {
float eps;
} ggml_metal_kargs_rms_norm;

typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_l2_norm;

typedef struct {
int64_t ne00;
int64_t ne01;

@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,

@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);

@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return true;

@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RWKV_WKV6:
{
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];

GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64);

size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;

id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];

[encoder setBytes:&B length:sizeof(B) atIndex:7];
[encoder setBytes:&T length:sizeof(T) atIndex:8];
[encoder setBytes:&C length:sizeof(C) atIndex:9];
[encoder setBytes:&H length:sizeof(H) atIndex:10];

[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
} break;
case GGML_OP_RWKV_WKV7:
{
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];

GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64);

size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
size_t offs_src6 = 0;

id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];

[encoder setBytes:&B length:sizeof(B) atIndex:8];
[encoder setBytes:&T length:sizeof(T) atIndex:9];
[encoder setBytes:&C length:sizeof(C) atIndex:10];
[encoder setBytes:&H length:sizeof(H) atIndex:11];

[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
} break;
case GGML_OP_MUL_MAT:
{
GGML_ASSERT(ne00 == ne10);

@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(

const int64_t nrows = ggml_nrows(src0);

[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_L2_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;

int nth = 32; // SIMD width

while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}

nth = MIN(nth, ne00/4);

ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
};

[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];

[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

const int64_t nrows = ggml_nrows(src0);

[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_GROUP_NORM:
@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_rwkv_wkv6_f32(
|
||||||
|
device const float * k,
|
||||||
|
device const float * v,
|
||||||
|
device const float * r,
|
||||||
|
device const float * tf,
|
||||||
|
device const float * td,
|
||||||
|
device const float * state_in,
|
||||||
|
device float * dst,
|
||||||
|
constant uint & B,
|
||||||
|
constant uint & T,
|
||||||
|
constant uint & C,
|
||||||
|
constant uint & H,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
|
const uint head_size = 64; // TODO: support head_size = 128
|
||||||
|
const uint batch_id = tgpig.x / H;
|
||||||
|
const uint head_id = tgpig.x % H;
|
||||||
|
const uint tid = tpitg.x;
|
||||||
|
|
||||||
|
if (batch_id >= B || head_id >= H) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint state_size = C * head_size;
|
||||||
|
const uint n_seq_tokens = T / B;
|
||||||
|
|
||||||
|
threadgroup float _k[head_size];
|
||||||
|
threadgroup float _r[head_size];
|
||||||
|
threadgroup float _tf[head_size];
|
||||||
|
threadgroup float _td[head_size];
|
||||||
|
|
||||||
|
float state[head_size];
|
||||||
|
|
||||||
|
for (uint i = 0; i < head_size; i++) {
|
||||||
|
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||||
|
+ i * head_size + tid];
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
_tf[tid] = tf[head_id * head_size + tid];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||||
|
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||||
|
|
||||||
|
for (uint t = start_t; t < end_t; t += C) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
_k[tid] = k[t];
|
||||||
|
_r[tid] = r[t];
|
||||||
|
_td[tid] = td[t];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
const float v_val = v[t];
|
||||||
|
float y = 0.0;
|
        for (uint j = 0; j < head_size; j += 4) {
            float4 k_vec  = float4(_k[j],  _k[j+1],  _k[j+2],  _k[j+3]);
            float4 r_vec  = float4(_r[j],  _r[j+1],  _r[j+2],  _r[j+3]);
            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            float4 s_vec  = float4(state[j], state[j+1], state[j+2], state[j+3]);

            float4 kv = k_vec * v_val;

            float4 temp = tf_vec * kv + s_vec;
            y += dot(r_vec, temp);

            s_vec = s_vec * td_vec + kv;
            state[j]   = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }

        dst[t] = y;
    }

    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
            + i * head_size + tid] = state[i];
    }
}

kernel void kernel_rwkv_wkv7_f32(
    device const float * r,
    device const float * w,
    device const float * k,
    device const float * v,
    device const float * a,
    device const float * b,
    device const float * state_in,
    device       float * dst,
    constant      uint & B,
    constant      uint & T,
    constant      uint & C,
    constant      uint & H,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {

    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id  = tgpig.x / H;
    const uint head_id   = tgpig.x % H;
    const uint tid       = tpitg.x;

    if (batch_id >= B || head_id >= H) {
        return;
    }

    const uint state_size   = C * head_size;
    const uint n_seq_tokens = T / B;

    threadgroup float _r[head_size];
    threadgroup float _w[head_size];
    threadgroup float _k[head_size];
    threadgroup float _a[head_size];
    threadgroup float _b[head_size];

    float state[head_size];

    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                            + tid * head_size + i];
    }

    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t   = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

    for (uint t = start_t; t < end_t; t += C) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _r[tid] = r[t];
        _w[tid] = w[t];
        _k[tid] = k[t];
        _a[tid] = a[t];
        _b[tid] = b[t];
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float v_val = v[t];
        float y = 0.0, sa = 0.0;

        float4 sa_vec(0.0);

        for (int j = 0; j < head_size; j += 4) {
            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
            sa_vec += a_vec * s_vec;
        }
        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];

        for (uint j = 0; j < head_size; j += 4) {
            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);

            float4 kv = k_vec * v_val;

            s_vec = s_vec * w_vec + kv + sa * b_vec;
            y += dot(s_vec, r_vec);

            state[j]   = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }

        dst[t] = y;
    }

    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
            + tid * head_size + i] = state[i];
    }
}
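For readers cross-checking the Metal kernel against the CPU path, the float4 loop above is a vectorization of the per-(head, channel) recurrence below. This is a minimal host-side sketch under that assumption; the helper name wkv7_step_ref is illustrative and not part of the patch.

#include <cstddef>

// One token step of the WKV7 recurrence for a fixed output channel:
//   sa       = sum_j a[j] * state[j]        (computed from the old state)
//   state[j] = state[j] * w[j] + k[j] * v + sa * b[j]
//   y        = sum_j r[j] * state[j]        (dot with the updated state)
static void wkv7_step_ref(const float * r, const float * w, const float * k,
                          const float * a, const float * b, float v,
                          float * state, float * y, size_t head_size) {
    float sa = 0.0f;
    for (size_t j = 0; j < head_size; ++j) {
        sa += a[j] * state[j];
    }
    float acc = 0.0f;
    for (size_t j = 0; j < head_size; ++j) {
        state[j] = state[j] * w[j] + k[j] * v + sa * b[j];
        acc     += r[j] * state[j];
    }
    *y = acc;
}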
kernel void kernel_argmax(
        device const void * x,
        device    int32_t * dst,

@@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
    }
}
kernel void kernel_l2_norm(
        constant ggml_metal_kargs_l2_norm & args,
        device const char * src0,
        device       char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint   tgpig[[threadgroup_position_in_grid]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort   ntg[[threads_per_threadgroup]]) {
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);

    float sumf = 0.0f;

    // parallel sum
    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
        sumf += dot(x[i00], x[i00]);
    }
    sumf = simd_sum(sumf);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    const float scale = 1.0f/sqrt(max(sumf, args.eps));

    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
        y[i00] = x[i00] * scale;
    }
}
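kernel_l2_norm reduces the per-row sum of squares twice: first inside each simdgroup with simd_sum, then across simdgroups through the shmem_f32 scratch buffer, before scaling the row. Functionally each row is transformed as in this sequential sketch (helper name illustrative, not part of the patch).

#include <algorithm>
#include <cmath>

// y = x / sqrt(max(sum(x*x), eps)), computed for a single row of n floats.
static void l2_norm_row_ref(const float * x, float * y, int n, float eps) {
    float sumf = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumf += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(std::max(sumf, eps));
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] * scale;
    }
}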
kernel void kernel_group_norm(
        device const float * src0,
        device       float * dst,
@@ -297,8 +297,27 @@ static int ggml_backend_opencl_n_devices = 0;
 struct ProfilingInfo {
     std::string op_name;
     std::string kernel_name;
-    // Kernel execution time in nanoseconds.
-    cl_ulong duration_ns;
+    cl_kernel kernel;
+    cl_event evt;
+
+    cl_ulong cmd_queued;
+    cl_ulong cmd_submit;
+    cl_ulong cmd_start;
+    cl_ulong cmd_end;
+    cl_ulong overhead_start;
+    cl_ulong overhead_end;
+
+    // For the times below, see spec for clGetEventProfilingInfo
+    // The time the kernel spent in the command queue - SUBMIT - QUEUED
+    cl_ulong cmd_queued_duration_ns;
+    // The time the kernel spent waiting for submission - START - SUBMIT
+    cl_ulong cmd_submit_duration_ns;
+    // Kernel execution time in nanoseconds - END - START
+    cl_ulong cmd_duration_ns;
+    // The time for the kernel to complete - COMPLETE - END
+    cl_ulong cmd_complete_duration_ns;
+    // Total time to finish the kernel - COMPLETE - QUEUED
+    cl_ulong cmd_total_duration_ns;
     // Global and local work sizes.
     size_t global_size[3];
     size_t local_size[3];
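The five cl_ulong timestamps mirror the event states reported by clGetEventProfilingInfo: QUEUED (enqueued on the host), SUBMIT (handed to the device), START/END (actual execution), and COMPLETE (the command and any child commands finished). They are only available when the command queue was created with profiling enabled, roughly as in this hedged sketch (error handling omitted; handles are assumed to come from the caller).

#include <CL/cl.h>

// Profiling must be requested at queue creation, otherwise
// clGetEventProfilingInfo returns CL_PROFILING_INFO_NOT_AVAILABLE.
static cl_command_queue make_profiling_queue(cl_context ctx, cl_device_id dev) {
    const cl_queue_properties props[] = { CL_QUEUE_PROPERTIES, CL_QUEUE_PROFILING_ENABLE, 0 };
    cl_int err = CL_SUCCESS;
    cl_command_queue q = clCreateCommandQueueWithProperties(ctx, dev, props, &err);
    return err == CL_SUCCESS ? q : nullptr;
}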
@@ -903,12 +922,56 @@ static void ggml_cl2_free(void) {
         return;
     }
 
+    // Populate profiling info
+    for (ProfilingInfo & info : g_profiling_info) {
+        cl_ulong cmd_queued;
+        cl_ulong cmd_submit;
+        cl_ulong cmd_start;
+        cl_ulong cmd_end;
+        cl_ulong cmd_complete;
+
+        CL_CHECK(clWaitForEvents(1, &info.evt));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL));
+        CL_CHECK(clReleaseEvent(info.evt));
+
+        char kernel_name[512];
+        CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME,
+            sizeof(kernel_name), kernel_name, NULL));
+        info.kernel_name = kernel_name;
+
+        info.cmd_queued = cmd_queued;
+        info.cmd_submit = cmd_submit;
+        info.cmd_start  = cmd_start;
+        info.cmd_end    = cmd_end;
+
+        info.cmd_queued_duration_ns   = cmd_submit   - cmd_queued;
+        info.cmd_submit_duration_ns   = cmd_start    - cmd_submit;
+        info.cmd_duration_ns          = cmd_end      - cmd_start;
+        info.cmd_complete_duration_ns = cmd_complete - cmd_end;
+        info.cmd_total_duration_ns    = cmd_complete - cmd_queued;
+    }
+
+    // Dump a csv
     float total_kernel_time = 0;
-    fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n");
+    fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
     for (const ProfilingInfo & info : g_profiling_info) {
-        total_kernel_time += info.duration_ns/1.e6f;
-        fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
-            info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
+        total_kernel_time += info.cmd_duration_ns/1.e6f;
+        fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
+            info.op_name.c_str(), info.kernel_name.c_str(),
+            info.cmd_queued_duration_ns/1.e6f,
+            info.cmd_submit_duration_ns/1.e6f,
+            info.cmd_duration_ns/1.e6f,
+            info.cmd_complete_duration_ns/1.e6f,
+            info.cmd_total_duration_ns/1.e6f,
            info.global_size[0], info.global_size[1], info.global_size[2],
            info.local_size[0], info.local_size[2], info.local_size[2],
            info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
@@ -916,6 +979,27 @@ static void ggml_cl2_free(void) {
     fclose(fperf);
 
     GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
+
+    // Dump a simple chrome trace
+    FILE* ftrace = fopen("cl_trace.json", "w");
+    if (!ftrace) {
+        GGML_LOG_ERROR("Failed to open cl_trace.json\n");
+        return;
+    }
+
+    fprintf(ftrace, "[\n");
+    for (const ProfilingInfo & info : g_profiling_info) {
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            info.kernel_name.c_str(), info.cmd_queued/1000);
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            info.kernel_name.c_str(), info.cmd_submit/1000);
+
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            info.kernel_name.c_str(), info.cmd_start/1000);
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            info.kernel_name.c_str(), info.cmd_end/1000);
+    }
+    fclose(ftrace);
 #endif
 }
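The cl_trace.json written above follows the Chrome trace-event format: paired "B"/"E" records open and close a slice, and "ts" is expected in microseconds, which is why the nanosecond counters are divided by 1000. The file can be loaded in chrome://tracing or Perfetto to view the queueing ("Host") and execution ("Device") spans per kernel.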
@@ -2062,25 +2146,14 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor)
 // Profiling utility
 //------------------------------------------------------------------------------
 #ifdef GGML_OPENCL_PROFILING
-void populateProfilingInfo(
+static void populateProfilingInfo(
         ProfilingInfo& info, cl_event evt, cl_kernel kernel,
         size_t global_size[3], size_t local_size[3],
         const ggml_tensor * tensor) {
-    cl_ulong start;
-    cl_ulong end;
-    CL_CHECK(clWaitForEvents(1, &evt));
-    CL_CHECK(clGetEventProfilingInfo(
-        evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
-    CL_CHECK(clGetEventProfilingInfo(
-        evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
-
-    char kernel_name[512];
-    CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
-        sizeof(kernel_name), kernel_name, NULL));
-
-    info.duration_ns = end - start;
-    info.op_name = tensor->name;
-    info.kernel_name = kernel_name;
+    info.op_name = tensor->name;
+    info.kernel  = kernel;
+    info.evt     = evt;
+
     info.local_size[0] = local_size[0];
     info.local_size[1] = local_size[1];
     info.local_size[2] = local_size[2];
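Net effect of this hunk: populateProfilingInfo no longer blocks on clWaitForEvents at dispatch time; it only records the kernel handle and the event, and the timing queries shown earlier in ggml_cl2_free run once at teardown, which should keep the profiling overhead out of the measured execution itself.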
@@ -26,7 +26,7 @@
 #include "softmax.hpp"
 #include "tsembd.hpp"
 #include "im2col.hpp"
-#include "wkv6.hpp"
+#include "wkv.hpp"
 #include "outprod.hpp"
 #include "element_wise.hpp"
 #include "cpy.hpp"
@@ -301,6 +301,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
     return opt;
 }
 
+namespace sycl_ex = sycl::ext::oneapi::experimental;
 struct ggml_backend_sycl_context {
     int device;
     std::string name;
@@ -392,6 +393,10 @@ struct ggml_backend_sycl_context {
         return pool(device);
     }
 
+#ifdef GGML_SYCL_GRAPH
+    std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
+#endif
+
     ggml_sycl_pool & host_pool(int device) {
         if (host_pools[device] == nullptr) {
             host_pools[device] = new_pool_for_host(stream(device, 0), device);
@@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
     stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
                                                sycl::range<3>(1, 1, WARP_SIZE),
                                            sycl::range<3>(1, 1, WARP_SIZE)),
-                         [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+                         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
                              dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
                          });
@@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
                                                           nrows, item_ct1);
             });
@@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
                     vx, y, dst, ncols, nrows, item_ct1);
             });
@@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
                     vx, y, dst, ncols, nrows, item_ct1);
             });
@@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
                     vx, y, dst, ncols, nrows, item_ct1);
             });
@@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
                     vx, y, dst, ncols, nrows, item_ct1);
             });
@@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
                     vx, y, dst, ncols, nrows, item_ct1);
             });
@@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
                     vx, y, dst, ncols, nrows, item_ct1);
             });
@@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
        });
 }
@@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
         default:
             printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
             GGML_ABORT("fatal error");
-            break;
     }
 
     GGML_UNUSED(src1);
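All of these hunks swap the DPC++-specific [[intel::reqd_sub_group_size(N)]] attribute for the SYCL 2020 spelling [[sycl::reqd_sub_group_size(N)]]; the semantics (pinning the kernel to a fixed sub-group size) are unchanged. A minimal standalone sketch of the standard attribute, with an assumed sub-group width of 32 rather than the project's WARP_SIZE/QK_WARP_SIZE constants, looks like this (illustrative only):

#include <sycl/sycl.hpp>

constexpr int kSubGroup = 32; // assumed width for this sketch

// Doubles n floats in place; x is assumed to be USM memory (e.g. sycl::malloc_shared)
// and n is assumed to be a multiple of kSubGroup.
static void scale_inplace(sycl::queue & q, float * x, size_t n) {
    q.parallel_for(sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(kSubGroup)),
                   [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(kSubGroup)]] {
        x[it.get_global_linear_id()] *= 2.0f;
    }).wait();
}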
@@ -1,7 +1,7 @@
 #include "common.hpp"
 #include "element_wise.hpp"
 
-void acc_f32(const float * x, const float * y, float * dst, const int ne,
+static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     const int ne10, const int ne11, const int ne12,
     const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
-void gelu_f32(const float * x, float * dst, const int k,
+static void gelu_f32(const float * x, float * dst, const int k,
               const sycl::nd_item<3> &item_ct1) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
                           sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
 }
 
-void silu_f32(const float * x, float * dst, const int k,
+static void silu_f32(const float * x, float * dst, const int k,
               const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
     dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
 }
 
-void gelu_quick_f32(const float *x, float *dst, int k,
+static void gelu_quick_f32(const float *x, float *dst, int k,
                     const sycl::nd_item<3> &item_ct1) {
     const float GELU_QUICK_COEF = -1.702f;
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
     dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
-void tanh_f32(const float *x, float *dst, int k,
+static void tanh_f32(const float *x, float *dst, int k,
              const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
     dst[i] = sycl::tanh((float)(x[i]));
 }
 
-void relu_f32(const float * x, float * dst, const int k,
+static void relu_f32(const float * x, float * dst, const int k,
              const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
     dst[i] = sycl::fmax((float)(x[i]), (float)0);
 }
 
-void sigmoid_f32(const float * x, float * dst, const int k,
+static void sigmoid_f32(const float * x, float * dst, const int k,
                  const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
     dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
 }
 
-void sqrt_f32(const float * x, float * dst, const int k,
+static void sqrt_f32(const float * x, float * dst, const int k,
              const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
     dst[i] = sycl::sqrt(x[i]);
 }
 
-void sin_f32(const float * x, float * dst, const int k,
+static void sin_f32(const float * x, float * dst, const int k,
             const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
     dst[i] = sycl::sin(x[i]);
 }
 
-void cos_f32(const float * x, float * dst, const int k,
+static void cos_f32(const float * x, float * dst, const int k,
             const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
     dst[i] = sycl::cos(x[i]);
 }
 
-void hardsigmoid_f32(const float * x, float * dst, const int k,
+static void hardsigmoid_f32(const float * x, float * dst, const int k,
                      const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
     dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-void hardswish_f32(const float * x, float * dst, const int k,
+static void hardswish_f32(const float * x, float * dst, const int k,
                    const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
     dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-void exp_f32(const float * x, float * dst, const int k,
+static void exp_f32(const float * x, float * dst, const int k,
             const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
     dst[i] = sycl::exp(x[i]);
 }
 
-void log_f32(const float * x, float * dst, const int k,
+static void log_f32(const float * x, float * dst, const int k,
             const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
     }
 }
 
-void neg_f32(const float * x, float * dst, const int k,
+static void neg_f32(const float * x, float * dst, const int k,
             const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
     dst[i] = -x[i];
 }
 
-void step_f32(const float * x, float * dst, const int k,
+static void step_f32(const float * x, float * dst, const int k,
              const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
     dst[i] = x[i] > 0.0f;
 }
 
-void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
+static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
                     const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
              sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
 }
 
-void sqr_f32(const float * x, float * dst, const int k,
+static void sqr_f32(const float * x, float * dst, const int k,
             const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
     dst[i] = x[i] * x[i];
 }
 
-void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
+static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
                  const int nb02, const int nb03, const int ne10, const int ne11,
                  const int ne12, const int ne13, const float sf0, const float sf1,
                  const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
@@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
     dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 }
 
-void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
+static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
              const sycl::nd_item<3> &item_ct1) {
     int nidx = item_ct1.get_local_id(2) +
                item_ct1.get_group(2) * item_ct1.get_local_range(2);
@@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
 
 
 
-void acc_f32_sycl(const float *x, const float *y, float *dst,
+static void acc_f32_sycl(const float *x, const float *y, float *dst,
                   const int n_elements, const int ne10, const int ne11,
                   const int ne12, const int nb1, const int nb2,
                   const int offset, queue_ptr stream) {
@@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
         });
 }
 
-void gelu_f32_sycl(const float *x, float *dst, const int k,
+static void gelu_f32_sycl(const float *x, float *dst, const int k,
                    queue_ptr stream) {
     const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
     stream->parallel_for(
@@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void silu_f32_sycl(const float *x, float *dst, const int k,
+static void silu_f32_sycl(const float *x, float *dst, const int k,
                    queue_ptr stream) {
     const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
     stream->parallel_for(
@@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
+static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
                          queue_ptr stream) {
     const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
     stream->parallel_for(
@@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void tanh_f32_sycl(const float *x, float *dst, const int k,
+static void tanh_f32_sycl(const float *x, float *dst, const int k,
                    queue_ptr stream) {
     const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
     stream->parallel_for(
@@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void relu_f32_sycl(const float *x, float *dst, const int k,
+static void relu_f32_sycl(const float *x, float *dst, const int k,
                    queue_ptr stream) {
     const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
     stream->parallel_for(
@@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
+static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
                           queue_ptr stream) {
     const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
     stream->parallel_for(
@@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void hardswish_f32_sycl(const float *x, float *dst, const int k,
+static void hardswish_f32_sycl(const float *x, float *dst, const int k,
                         queue_ptr stream) {
     const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
     stream->parallel_for(
@@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void exp_f32_sycl(const float *x, float *dst, const int k,
+static void exp_f32_sycl(const float *x, float *dst, const int k,
                   queue_ptr stream) {
     const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
     stream->parallel_for(
@@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void log_f32_sycl(const float *x, float *dst, const int k,
+static void log_f32_sycl(const float *x, float *dst, const int k,
                   queue_ptr stream) {
     const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
     stream->parallel_for(
@@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void neg_f32_sycl(const float *x, float *dst, const int k,
+static void neg_f32_sycl(const float *x, float *dst, const int k,
                   queue_ptr stream) {
     const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
     stream->parallel_for(
@@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void step_f32_sycl(const float *x, float *dst, const int k,
+static void step_f32_sycl(const float *x, float *dst, const int k,
                    queue_ptr stream) {
     const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
     stream->parallel_for(
@@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void sigmoid_f32_sycl(const float *x, float *dst, const int k,
+static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
                       queue_ptr stream) {
     const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
     stream->parallel_for(
@@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void sqrt_f32_sycl(const float *x, float *dst, const int k,
+static void sqrt_f32_sycl(const float *x, float *dst, const int k,
                    queue_ptr stream) {
     const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
     stream->parallel_for(
@@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void sin_f32_sycl(const float *x, float *dst, const int k,
+static void sin_f32_sycl(const float *x, float *dst, const int k,
                   queue_ptr stream) {
     const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
     stream->parallel_for(
@@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void cos_f32_sycl(const float *x, float *dst, const int k,
+static void cos_f32_sycl(const float *x, float *dst, const int k,
                   queue_ptr stream) {
     const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
     stream->parallel_for(
@@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
+static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
                          const float negative_slope,
                          queue_ptr stream) {
     const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
@@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void sqr_f32_sycl(const float *x, float *dst, const int k,
+static void sqr_f32_sycl(const float *x, float *dst, const int k,
                   queue_ptr stream) {
     const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
     stream->parallel_for(
@@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
                       const int nb02, const int nb03, const int ne10, const int ne11,
                       const int ne12, const int ne13, const float sf0, const float sf1,
                       const float sf2, const float sf3, queue_ptr stream) {
@@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
         });
 }
 
-void pad_f32_sycl(const float *x, float *dst, const int ne00,
+static void pad_f32_sycl(const float *x, float *dst, const int ne00,
                   const int ne01, const int ne02, const int ne0,
                   const int ne1, const int ne2, queue_ptr stream) {
     int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
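The element_wise.cpp hunks above are a uniform sweep that marks file-local kernels and launch helpers static, giving them internal linkage; presumably this silences missing-prototype warnings and keeps these symbols out of the library's export surface without changing behaviour.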
@@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
     const size_t nrows = ne01;
     const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
     stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
             k_get_rows_reorder<qk, qr, dq_reorder>(
                 src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
@@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
             // TODO: k-quants
             GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
             GGML_ABORT("fatal error");
-            break;
     }
 }
@@ -46,6 +46,7 @@
 static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
+int g_ggml_sycl_disable_graph = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -95,7 +96,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
     return info;
 }
 
-void print_device_detail(int id, sycl::device &device, std::string device_type) {
+static void print_device_detail(int id, sycl::device &device, std::string device_type) {
 
     dpct::device_info prop;
     SYCL_CHECK(CHECK_TRY_ERROR(
@@ -118,7 +119,7 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
-void print_device_opt_feature(int device_count) {
+static void print_device_opt_feature(int device_count) {
     GGML_LOG_INFO("SYCL Optimization Feature:\n");
     GGML_LOG_INFO(
         "|ID|        Device Type|Reorder|\n");
@@ -191,10 +192,12 @@ static void ggml_check_sycl() try {
     if (!initialized) {
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
         g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+        g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         GGML_LOG_INFO("Running with Environment Variables:\n");
         GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
         GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
         GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
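The new GGML_SYCL_DISABLE_GRAPH variable is read with a default of 1, so judging from this hunk the SYCL graph path (the exec_graph member added earlier) stays disabled unless the environment explicitly sets GGML_SYCL_DISABLE_GRAPH=0; the value is echoed alongside the other environment flags at startup.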
@@ -333,10 +336,11 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
+    if (tensor->type == GGML_TYPE_Q4_0) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
+    }
 
     if (ggml_is_quantized(tensor->type)) {
         // initialize padding to 0 to avoid possible NaN values
@@ -400,7 +404,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
                     const void *ptr_src, size_t size) {
     char *host_buf = (char *)malloc(size);
     q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
@@ -486,6 +490,22 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
+static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
+    if (buffer == nullptr) {
+        return;
+    }
+
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+
+    if (ctx != nullptr) {
+        for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
+            release_extra_gpu(extra);
+        }
+        ctx->tensor_extras.clear(); // reset the tensor_extras vector
+    }
+}
+
 static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
@@ -495,7 +515,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
     /* .clear           = */ ggml_backend_sycl_buffer_clear,
-    /* .reset           = */ NULL,
+    /* .reset           = */ ggml_backend_sycl_buffer_reset,
 };
 
 // sycl buffer type
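With this hunk the SYCL buffer interface gains a real .reset callback: ggml_backend_sycl_buffer_reset walks the tensor_extras recorded by init_tensor (now only created for GGML_TYPE_Q4_0 tensors), releases each one via release_extra_gpu, and clears the vector so the buffer can be reused.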
@ -576,7 +596,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
|
||||||
static std::mutex mutex;
|
static std::mutex mutex;
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
|
||||||
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
|
|
||||||
|
|
||||||
auto dev_count = ggml_backend_sycl_get_device_count();
|
auto dev_count = ggml_backend_sycl_get_device_count();
|
||||||
|
|
||||||
|
@ -604,7 +623,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
|
||||||
return &ggml_backend_sycl_buffer_types[device];
|
return &ggml_backend_sycl_buffer_types[device];
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
|
static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
|
||||||
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
|
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
|
||||||
|
|
||||||
int device = ctx->device;
|
int device = ctx->device;
|
@@ -1666,7 +1685,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
 
     stream->parallel_for(
         sycl::nd_range<3>(num_blocks * block_size, block_size),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
         });
 }
@@ -1687,7 +1706,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
                                  nchannels_y, item_ct1);
         });
@@ -1707,7 +1726,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
                                    row_stride_x, channel_stride_x,
                                    nchannels_y / nchannels_x, item_ct1);
@@ -1748,7 +1767,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                          [=](sycl::nd_item<3> item_ct1)
-                             [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                             [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                              k_sum_rows_f32(x, dst, ncols, item_ct1);
                          });
 }
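The recurring one-line change in these kernel-launch hunks replaces the DPC++ vendor spelling [[intel::reqd_sub_group_size(N)]] with the SYCL 2020 standard spelling [[sycl::reqd_sub_group_size(N)]]; both pin the kernel's sub-group size so the warp-style reductions in these kernels see a known lane count. A minimal stand-alone kernel using the standard attribute, for illustration only (the sub-group size of 32 is an assumption here and must be one the device actually supports; the patched kernels use WARP_SIZE instead):

#include <sycl/sycl.hpp>

int main() {
    sycl::queue q;

    float * data = sycl::malloc_shared<float>(1024, q);
    for (int i = 0; i < 1024; ++i) data[i] = float(i);

    q.parallel_for(
        sycl::nd_range<1>(1024, 128),
        [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(32)]] {
            // every work-item now executes in sub-groups of exactly 32 lanes
            data[it.get_global_id(0)] *= 2.0f;
        }).wait();

    sycl::free(data, q);
    return 0;
}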
@@ -2680,6 +2699,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
@@ -2898,7 +2923,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     return false;
 }
 
-bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -3271,7 +3296,7 @@ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
 }
 
 
-void ggml_sycl_set_main_device(const int main_device) try {
+static void ggml_sycl_set_main_device(const int main_device) try {
     if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
         return;
     }
@@ -3292,7 +3317,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
+static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
     if (!g_sycl_loaded) return false;
 
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@@ -3394,6 +3419,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RMS_NORM:
             ggml_sycl_rms_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_sycl_l2_norm(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
@@ -3471,6 +3499,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RWKV_WKV6:
             ggml_sycl_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_sycl_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
@@ -3610,7 +3641,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-void reorder_qw(char *data_device, const int ncols, const int nrows,
+static void reorder_qw(char *data_device, const int ncols, const int nrows,
                 size_t size, size_t offset, dpct::queue_ptr stream) {
     auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
     SYCL_CHECK(
@@ -3624,7 +3655,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
 
     stream->parallel_for(
         size / sizeof(block_q4_0),
-        [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             const block_q4_0* x = (const block_q4_0*)tmp_buf;
             const int ib = i;
 
@@ -3638,7 +3669,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
     sycl::free(tmp_buf, *stream);
 }
 
-void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
+static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
     char*data_device = (char*)src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
@@ -3647,7 +3678,7 @@ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
     reorder_qw(data_device, ncols, nrows, size, 0, stream);
 }
 
-void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
+static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
     ggml_tensor *src0 = dst->src[0];
     ggml_tensor *src1 = dst->src[1];
 
@@ -3660,7 +3691,7 @@ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
     }
 }
 
-void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
+static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
    dpct::queue_ptr stream = ctx->stream();
    if (ctx->optimized_graph) {
        return;
@@ -3671,10 +3702,9 @@ void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx)
         if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
     }
 }
-static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    ggml_sycl_set_main_device(sycl_ctx->device);
 
+static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
+    ggml_sycl_set_main_device(sycl_ctx->device);
     if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3696,7 +3726,46 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         }
         GGML_ASSERT(ok);
     }
+}
 
+static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
+
+#ifdef GGML_SYCL_GRAPH
+    if (!g_ggml_sycl_disable_graph) {
+        if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
+            GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
+            ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+            return GGML_STATUS_SUCCESS;
+        }
+
+        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+        model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
+        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+        model_sycl_graph.end_recording();
+
+        if (!sycl_ctx->exec_graph) {
+            auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
+            sycl_ctx->exec_graph = std::make_unique<
+                sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+        } else {
+            try {
+                sycl_ctx->exec_graph->update(model_sycl_graph);
+                GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
+            } catch (sycl::exception const & e) {
+                GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
+                auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
+                sycl_ctx->exec_graph = std::make_unique<
+                    sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+            }
+        }
+
+        sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
+    } else
+#endif
+    {
+        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+    }
     return GGML_STATUS_SUCCESS;
 }
 
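The hunk above is the core of the change: eager per-node execution moves into ggml_backend_sycl_graph_compute_impl, and when GGML_SYCL_GRAPH is available the wrapper records those submissions into a oneAPI command_graph, finalizes (or, on later calls, update()s) an executable graph, and replays it with ext_oneapi_graph. A stripped-down sketch of that record/replay flow, using only the extension calls visible in the hunk (the header path below is an assumption; error handling and the ggml plumbing are omitted):

#include <sycl/sycl.hpp>
#include <sycl/ext/oneapi/experimental/graph.hpp>  // assumed location of the oneAPI graph extension

namespace sycl_ex = sycl::ext::oneapi::experimental;

// Record whatever `work(q)` submits into a command graph, then replay it once.
template <typename Work>
void run_recorded(sycl::queue & q, Work && work) {
    sycl_ex::command_graph graph(q);   // modifiable graph tied to the queue's context and device

    graph.begin_recording(q);          // submissions on q are now captured instead of executed
    work(q);
    graph.end_recording();

    // finalize() produces an executable graph; the updatable property allows a later update().
    auto exec = graph.finalize({sycl_ex::property::graph::updatable{}});

    q.ext_oneapi_graph(exec);          // replay all recorded work as a single submission
    q.wait();
}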
@@ -3761,7 +3830,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
 }
 
 int ggml_backend_sycl_get_device_count() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
     return ggml_sycl_info().device_count;
 }
 
||||||
|
@ -3851,7 +3919,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
} break;
|
}
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(op)) {
|
switch (ggml_get_unary_op(op)) {
|
||||||
case GGML_UNARY_OP_NEG:
|
case GGML_UNARY_OP_NEG:
|
||||||
|
@ -3869,7 +3937,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
|
@ -3900,7 +3967,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} break;
|
}
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
|
@ -3917,7 +3984,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} break;
|
}
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
{
|
{
|
||||||
ggml_type src0_type = op->src[0]->type;
|
ggml_type src0_type = op->src[0]->type;
|
||||||
|
@ -3968,12 +4035,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
} break;
|
}
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
{
|
{
|
||||||
ggml_type src0_type = op->src[0]->type;
|
ggml_type src0_type = op->src[0]->type;
|
||||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||||
} break;
|
}
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_ARGMAX:
|
case GGML_OP_ARGMAX:
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
@@ -3997,6 +4064,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
@@ -4030,6 +4098,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
            return true;
        default:
||||||
|
|
|
@ -3017,7 +3017,6 @@ void ggml_sycl_op_mul_mat_q(
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_UNUSED(src1);
|
GGML_UNUSED(src1);
|
||||||
|
|
|
@ -495,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
||||||
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -519,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
||||||
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -543,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
||||||
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -567,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
||||||
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -591,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
||||||
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -615,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
||||||
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -639,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
||||||
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -663,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
||||||
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -687,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
|
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
|
||||||
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -711,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
|
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
|
||||||
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
|
@ -734,7 +734,7 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
|
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -755,7 +755,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
|
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -777,7 +777,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
|
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -799,7 +799,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
|
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -821,7 +821,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
|
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -843,7 +843,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -864,7 +864,7 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
|
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -886,7 +886,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
|
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -908,7 +908,7 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
|
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
|
@ -1003,7 +1003,6 @@ void ggml_sycl_op_mul_mat_vec_q(
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GGML_UNUSED(src1);
|
GGML_UNUSED(src1);
|
||||||
|
|
|
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     }
 }
 
+static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
+                        const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+    const int tid = item_ct1.get_local_id(2);
+    const int nthreads = item_ct1.get_local_range(2);
+    const int nwarps = nthreads / WARP_SIZE;
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row * ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+
+        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        /*
+        DPCT1118:3: SYCL group functions and algorithms must be encountered in
+        converged control flow. You may need to adjust the code.
+        */
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        size_t nreduce = nwarps / WARP_SIZE;
+        tmp = 0.f;
+        for (size_t i = 0; i < nreduce; i += 1)
+        {
+            tmp += s_sum[lane_id + i * WARP_SIZE];
+        }
+        tmp = warp_reduce_sum(tmp, item_ct1);
+    }
+
+    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row * ncols + col] = scale * x[row * ncols + col];
+    }
+}
+
 static void norm_f32_sycl(const float* x, float* dst, const int ncols,
                           const int nrows, const float eps,
                           queue_ptr stream, int device) {
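In equation form, the l2_norm_f32 kernel added above scales each row of length ncols by the reciprocal of its Euclidean norm, with eps only acting as a lower clamp inside the square root (matching the sycl::rsqrt(sycl::max(tmp, eps * eps)) line):

\[
\mathrm{dst}_{r,c} \;=\; \frac{x_{r,c}}{\sqrt{\max\!\Bigl(\sum_{k=0}^{\mathrm{ncols}-1} x_{r,k}^{2},\; \varepsilon^{2}\Bigr)}}
\]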
||||||
|
@ -191,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||||
block_dims),
|
block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||||
nullptr, WARP_SIZE);
|
nullptr, WARP_SIZE);
|
||||||
});
|
});
|
||||||
|
@ -214,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||||
block_dims),
|
block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||||
});
|
});
|
||||||
|
@ -233,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||||
block_dims),
|
block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
group_norm_f32(
|
group_norm_f32(
|
||||||
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
||||||
nullptr, WARP_SIZE);
|
nullptr, WARP_SIZE);
|
||||||
|
@ -260,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||||
block_dims),
|
block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
group_norm_f32(x, dst, group_size, ne_elements,
|
group_norm_f32(x, dst, group_size, ne_elements,
|
||||||
eps_ct4, item_ct1,
|
eps_ct4, item_ct1,
|
||||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||||
|
@ -281,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||||
block_dims),
|
block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||||
nullptr, WARP_SIZE);
|
nullptr, WARP_SIZE);
|
||||||
});
|
});
|
||||||
|
@ -303,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||||
block_dims),
|
block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1)
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||||
});
|
});
|
||||||
|
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
 }
 
+static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
+                             const int nrows, const float eps,
+                             queue_ptr stream, int device) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+    if (ncols < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                                    nullptr, WARP_SIZE);
+                    });
+        });
+    }
+    else {
+        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+        const sycl::range<3> block_dims(1, 1, work_group_size);
+        /*
+        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+                                                         cgh);
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                                    get_pointer(s_sum_acc_ct1), work_group_size);
+                    });
+        });
+    }
+}
+
 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
                        ggml_tensor* dst, const float* src0_dd,
                        const float* src1_dd, float* dst_dd,
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
     (void)dst;
     (void)src1_dd;
 }
+
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                          const ggml_tensor* src1, ggml_tensor* dst,
+                          const float* src0_dd, const float* src1_dd,
+                          float* dst_dd,
+                          const queue_ptr& main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+
+    (void)src1;
+    (void)dst;
+    (void)src1_dd;
+}
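One small detail in ggml_sycl_op_l2_norm above: the epsilon arrives packed into the destination tensor's op_params as raw bytes and is recovered with memcpy, which keeps the float/int32 reinterpretation well-defined. A tiny stand-alone illustration of that round-trip (the op_params array below is a stand-in for the real ggml_tensor field):

#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
    int32_t op_params[4] = {0};                        // stand-in for ggml_tensor::op_params
    const float eps_in = 1e-6f;
    std::memcpy(op_params, &eps_in, sizeof(float));    // pack: store the float's bytes in slot 0

    float eps = 0.0f;
    std::memcpy(&eps, op_params, sizeof(float));       // unpack: same pattern as the op above
    std::printf("eps = %g\n", eps);
    return 0;
}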
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
                              float* dst_dd,
                              const queue_ptr& main_stream);
 
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                          const ggml_tensor* src1, ggml_tensor* dst,
+                          const float* src0_dd, const float* src1_dd,
+                          float* dst_dd,
+                          const queue_ptr& main_stream);
+
 #endif // GGML_SYCL_NORM_HPP
@@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
 
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
                                                                              m1, n_head_log2, item_ct1,
||||||
|
|
305
ggml/src/ggml-sycl/wkv.cpp
Normal file
305
ggml/src/ggml-sycl/wkv.cpp
Normal file
|
@ -0,0 +1,305 @@
|
||||||
|
#include <sycl/sycl.hpp>
|
||||||
|
#include "wkv.hpp"
|
||||||
|
|
||||||
|
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||||
|
|
||||||
|
// Helper function for the main kernel
|
||||||
|
template <int block_size>
|
||||||
|
static void rwkv_wkv6_f32_kernel(
|
||||||
|
const int B, const int T, const int C, const int H,
|
||||||
|
const float* k, const float* v, const float* r,
|
||||||
|
const float* tf, const float* td, const float* s,
|
||||||
|
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||||
|
|
||||||
|
const int tid = item_ct1.get_local_id(2);
|
||||||
|
const int bid = item_ct1.get_group(2);
|
||||||
|
|
||||||
|
const int head_size = block_size;
|
||||||
|
const int batch_i = bid / H;
|
||||||
|
const int head_i = bid % H;
|
||||||
|
const int state_size = C * head_size;
|
||||||
|
const int n_seq_tokens = T / B;
|
||||||
|
|
||||||
|
// Set up shared memory pointers
|
||||||
|
float* _k = shared_mem;
|
||||||
|
float* _r = _k + head_size;
|
||||||
|
float* _tf = _r + head_size;
|
||||||
|
float* _td = _tf + head_size;
|
||||||
|
|
||||||
|
// Local state array
|
||||||
|
float state[block_size];
|
||||||
|
|
||||||
|
// Load initial state
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync threads before shared memory operations
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
// Load time-mixing parameters
|
||||||
|
_tf[tid] = tf[head_i * head_size + tid];
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
// Main sequence processing loop
|
||||||
|
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||||
|
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||||
|
t += C) {
|
||||||
|
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
// Load current timestep data to shared memory
|
||||||
|
_k[tid] = k[t];
|
||||||
|
_r[tid] = r[t];
|
||||||
|
_td[tid] = td[t];
|
||||||
|
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
const float _v = v[t];
|
||||||
|
float y = 0;
|
||||||
|
|
||||||
|
// Process in chunks of 4 for better vectorization
|
||||||
|
sycl::float4 k4, r4, tf4, td4, s4;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < head_size; j += 4) {
|
||||||
|
// Load data in vec4 chunks
|
||||||
|
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||||
|
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||||
|
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||||
|
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||||
|
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||||
|
|
||||||
|
// Compute key-value product
|
||||||
|
sycl::float4 kv4 = k4 * _v;
|
||||||
|
|
||||||
|
// Accumulate weighted sum
|
||||||
|
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||||
|
|
||||||
|
// Update state
|
||||||
|
s4 = s4 * td4 + kv4;
|
||||||
|
|
||||||
|
// Store updated state
|
||||||
|
state[j] = s4.x();
|
||||||
|
state[j+1] = s4.y();
|
||||||
|
state[j+2] = s4.z();
|
||||||
|
state[j+3] = s4.w();
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[t] = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save final state
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int block_size>
|
||||||
|
static void rwkv_wkv7_f32_kernel(
|
||||||
|
const int B, const int T, const int C, const int H,
|
||||||
|
const float* r, const float* w, const float* k, const float* v,
|
||||||
|
const float* a, const float* b, const float* s,
|
||||||
|
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||||
|
|
||||||
|
const int tid = item_ct1.get_local_id(2);
|
||||||
|
const int bid = item_ct1.get_group(2);
|
||||||
|
|
||||||
|
const int head_size = block_size;
|
||||||
|
const int batch_i = bid / H;
|
||||||
|
const int head_i = bid % H;
|
||||||
|
const int state_size = C * head_size;
|
||||||
|
const int n_seq_tokens = T / B;
|
||||||
|
|
||||||
|
float* _r = shared_mem;
|
||||||
|
float* _w = _r + head_size;
|
||||||
|
float* _k = _w + head_size;
|
||||||
|
float* _a = _k + head_size;
|
||||||
|
float* _b = _a + head_size;
|
||||||
|
|
||||||
|
float state[block_size];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||||
|
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||||
|
t += C) {
|
||||||
|
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
_r[tid] = r[t];
|
||||||
|
_w[tid] = w[t];
|
||||||
|
_k[tid] = k[t];
|
||||||
|
_a[tid] = a[t];
|
||||||
|
_b[tid] = b[t];
|
||||||
|
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
const float _v = v[t];
|
||||||
|
float y = 0, sa = 0;
|
||||||
|
sycl::float4 a4, s4;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < head_size; j += 4) {
|
||||||
|
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||||
|
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||||
|
sa += sycl::dot(a4, s4);
|
||||||
|
}
|
||||||
|
|
||||||
|
sycl::float4 r4, w4, k4, b4;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < head_size; j += 4) {
|
||||||
|
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||||
|
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||||
|
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||||
|
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||||
|
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||||
|
|
||||||
|
sycl::float4 kv4 = k4 * _v;
|
||||||
|
|
||||||
|
s4 = s4 * w4 + kv4 + sa * b4;
|
||||||
|
y += sycl::dot(r4, s4);
|
||||||
|
|
||||||
|
state[j] = s4.x();
|
||||||
|
state[j+1] = s4.y();
|
||||||
|
state[j+2] = s4.z();
|
||||||
|
state[j+3] = s4.w();
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[t] = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||||
|
|
||||||
|
const ggml_tensor *src0 = dst->src[0];
|
||||||
|
const ggml_tensor *src1 = dst->src[1];
|
||||||
|
|
||||||
|
const float* k_d = (const float*)dst->src[0]->data;
|
||||||
|
const float* v_d = (const float*)dst->src[1]->data;
|
||||||
|
const float* r_d = (const float*)dst->src[2]->data;
|
||||||
|
const float* tf_d = (const float*)dst->src[3]->data;
|
||||||
|
const float* td_d = (const float*)dst->src[4]->data;
|
||||||
|
const float* s_d = (const float*)dst->src[5]->data;
|
||||||
|
float* dst_d = (float*)dst->data;
|
||||||
|
|
||||||
|
const int64_t B = dst->src[5]->ne[1];
|
||||||
|
const int64_t T = dst->src[0]->ne[2];
|
||||||
|
const int64_t C = dst->ne[0];
|
||||||
|
const int64_t H = dst->src[0]->ne[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(C % H == 0);
|
||||||
|
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||||
|
|
||||||
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
|
||||||
|
// Calculate execution configuration
|
||||||
|
const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
|
||||||
|
sycl::range<3> block_dims(1, 1, C / H);
|
||||||
|
sycl::range<3> grid_dims(1, 1, B * H);
|
||||||
|
|
||||||
|
// Submit kernel
|
||||||
|
if (C / H == WKV_BLOCK_SIZE) {
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||||
|
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
|
||||||
|
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||||
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||||
|
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||||
|
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||||
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_UNUSED(src0);
|
||||||
|
GGML_UNUSED(src1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||||
|
|
||||||
|
const ggml_tensor *src0 = dst->src[0];
|
||||||
|
const ggml_tensor *src1 = dst->src[1];
|
||||||
|
|
||||||
|
const float* r_d = (const float*)dst->src[0]->data;
|
||||||
|
const float* w_d = (const float*)dst->src[1]->data;
|
||||||
|
const float* k_d = (const float*)dst->src[2]->data;
|
||||||
|
const float* v_d = (const float*)dst->src[3]->data;
|
||||||
|
const float* a_d = (const float*)dst->src[4]->data;
|
||||||
|
const float* b_d = (const float*)dst->src[5]->data;
|
||||||
|
const float* s_d = (const float*)dst->src[6]->data;
|
||||||
|
float* dst_d = (float*)dst->data;
|
||||||
|
|
||||||
|
const int64_t B = dst->src[6]->ne[1];
|
||||||
|
const int64_t T = dst->src[0]->ne[2];
|
||||||
|
const int64_t C = dst->ne[0];
|
||||||
|
const int64_t H = dst->src[0]->ne[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(C % H == 0);
|
||||||
|
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
|
||||||
|
|
||||||
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
|
||||||
|
// Calculate execution configuration
|
||||||
|
const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
|
||||||
|
sycl::range<3> block_dims(1, 1, C / H);
|
||||||
|
sycl::range<3> grid_dims(1, 1, B * H);
|
||||||
|
|
||||||
|
// Submit kernel
|
||||||
|
if (C / H == WKV_BLOCK_SIZE) {
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||||
|
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
|
||||||
|
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||||
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||||
|
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||||
|
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||||
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_UNUSED(src0);
|
||||||
|
GGML_UNUSED(src1);
|
||||||
|
}
|
ggml/src/ggml-sycl/wkv.hpp (new file, 10 lines)
@@ -0,0 +1,10 @@
+#ifndef GGML_SYCL_WKV_HPP
+#define GGML_SYCL_WKV_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_WKV_HPP
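For reference, the per-channel recurrence that the WKV6 op declared here implements can be read from the rwkv_wkv6_f32_kernel body elsewhere in this patch (kv = k*v, y += r*(tf*kv + s), s = s*td + kv); written out per head dimension j and timestep t:

\[
y_t \;=\; \sum_{j} r_{t,j}\,\bigl(tf_{j}\,k_{t,j}\,v_t + s_{t-1,j}\bigr),
\qquad
s_{t,j} \;=\; s_{t-1,j}\,td_{t,j} + k_{t,j}\,v_t
\]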
@ -1,143 +0,0 @@
|
||||||
#include <sycl/sycl.hpp>
|
|
||||||
#include "wkv6.hpp"
|
|
||||||
|
|
||||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
|
||||||
|
|
||||||
// Helper function for the main kernel
|
|
||||||
static void rwkv_wkv_f32_kernel(
|
|
||||||
const int B, const int T, const int C, const int H,
|
|
||||||
const float* k, const float* v, const float* r,
|
|
||||||
const float* tf, const float* td, const float* s,
|
|
||||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
|
||||||
const int bid = item_ct1.get_group(2);
|
|
||||||
|
|
||||||
const int head_size = WKV_BLOCK_SIZE;
|
|
||||||
const int batch_i = bid / H;
|
|
||||||
const int head_i = bid % H;
|
|
||||||
const int state_size = C * head_size;
|
|
||||||
const int n_seq_tokens = T / B;
|
|
||||||
|
|
||||||
// Set up shared memory pointers
|
|
||||||
float* _k = shared_mem;
|
|
||||||
float* _r = _k + head_size;
|
|
||||||
float* _tf = _r + head_size;
|
|
||||||
float* _td = _tf + head_size;
|
|
||||||
|
|
||||||
// Local state array
|
|
||||||
float state[WKV_BLOCK_SIZE];
|
|
||||||
|
|
||||||
// Load initial state
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < head_size; i++) {
|
|
||||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync threads before shared memory operations
|
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
||||||
|
|
||||||
// Load time-mixing parameters
|
|
||||||
_tf[tid] = tf[head_i * head_size + tid];
|
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
||||||
|
|
||||||
// Main sequence processing loop
|
|
||||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
|
||||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
|
||||||
t += C) {
|
|
||||||
|
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
||||||
|
|
||||||
// Load current timestep data to shared memory
|
|
||||||
_k[tid] = k[t];
|
|
||||||
_r[tid] = r[t];
|
|
||||||
_td[tid] = td[t];
|
|
||||||
|
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
||||||
|
|
||||||
const float _v = v[t];
|
|
||||||
float y = 0;
|
|
||||||
|
|
||||||
// Process in chunks of 4 for better vectorization
|
|
||||||
sycl::float4 k4, r4, tf4, td4, s4;
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < head_size; j += 4) {
|
|
||||||
// Load data in vec4 chunks
|
|
||||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
|
||||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
|
||||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
|
||||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
|
||||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
||||||
|
|
||||||
// Compute key-value product
|
|
||||||
sycl::float4 kv4 = k4 * _v;
|
|
||||||
|
|
||||||
// Accumulate weighted sum
|
|
||||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
|
||||||
|
|
||||||
// Update state
|
|
||||||
s4 = s4 * td4 + kv4;
|
|
||||||
|
|
||||||
// Store updated state
|
|
||||||
state[j] = s4.x();
|
|
||||||
state[j+1] = s4.y();
|
|
||||||
state[j+2] = s4.z();
|
|
||||||
state[j+3] = s4.w();
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[t] = y;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save final state
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < head_size; i++) {
|
|
||||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
|
||||||
}
|
|
||||||
}

void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

    const ggml_tensor *src0 = dst->src[0];
    const ggml_tensor *src1 = dst->src[1];

    const float* k_d  = (const float*)dst->src[0]->data;
    const float* v_d  = (const float*)dst->src[1]->data;
    const float* r_d  = (const float*)dst->src[2]->data;
    const float* tf_d = (const float*)dst->src[3]->data;
    const float* td_d = (const float*)dst->src[4]->data;
    const float* s_d  = (const float*)dst->src[5]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64

    dpct::queue_ptr stream = ctx.stream();

    // Calculate execution configuration
    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
    sycl::range<3> block_dims(1, 1, C / H);
    sycl::range<3> grid_dims(1, 1, B * H);

    // Submit kernel
    stream->submit([&](sycl::handler& cgh) {
        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);

        cgh.parallel_for(
            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1) {
                rwkv_wkv_f32_kernel(
                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                );
            });
    });

    GGML_UNUSED(src0);
    GGML_UNUSED(src1);
}
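
// Illustrative only (not part of the patch): how the launch shape above maps onto the tensor
// dimensions, using hypothetical example sizes. One work-group is launched per (sequence, head)
// pair, each with one work-item per channel of a head, and local memory holds the k, r, tf and
// td rows for the current timestep.
#include <cstdio>
int main() {
    const long B = 2, H = 32, C = 2048;        // example: 2 sequences, 32 heads, 2048 channels
    const long work_group_size = C / H;        // 64 work-items, one per channel in a head
    const long num_work_groups = B * H;        // one work-group per (sequence, head) pair
    const long shared_floats   = (C / H) * 4;  // room for the k, r, tf and td rows
    std::printf("%ld work-groups of %ld threads, %ld shared floats\n",
                num_work_groups, work_group_size, shared_floats);
    return 0;
}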

@@ -1,9 +0,0 @@
#ifndef GGML_SYCL_WKV6_HPP
#define GGML_SYCL_WKV6_HPP

#include "common.hpp"

void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);


#endif // GGML_SYCL_WKV6_HPP
@@ -33,6 +33,7 @@
#include "ggml-vulkan-shaders.cpp"

+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

#define VK_VENDOR_ID_AMD 0x1002
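
// Quick sanity check of the two helper macros above (illustrative sketch, not part of the patch):
// ROUNDUP_POW2 rounds up to a multiple of a power-of-two N, CEIL_DIV is a ceiling division.
#include <cassert>
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
int main() {
    assert(ROUNDUP_POW2(13, 8) == 16);   // 13 padded up to the next multiple of 8
    assert(ROUNDUP_POW2(16, 8) == 16);   // already-aligned values are unchanged
    assert(CEIL_DIV(13, 8) == 2);        // 13 elements need 2 blocks of 8
    return 0;
}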
@@ -153,6 +154,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);

static constexpr uint32_t mul_mat_vec_max_cols = 8;

+enum vk_device_architecture {
+    OTHER,
+    AMD_GCN,
+    AMD_RDNA1,
+    AMD_RDNA2,
+    AMD_RDNA3,
+};
+
+static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
+    vk::PhysicalDeviceProperties props = device.getProperties();
+
+    if (props.vendorID == VK_VENDOR_ID_AMD) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool amd_shader_core_properties = false;
+        bool integer_dot_product = false;
+        bool subgroup_size_control = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
+                amd_shader_core_properties = true;
+            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
+                integer_dot_product = true;
+            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                subgroup_size_control = true;
+            }
+        }
+
+        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
+            return vk_device_architecture::OTHER;
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
+        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &shader_core_props_amd;
+        shader_core_props_amd.pNext = &integer_dot_props;
+        integer_dot_props.pNext = &subgroup_size_control_props;
+
+        device.getProperties2(&props2);
+
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
+            return vk_device_architecture::AMD_GCN;
+        }
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
+            // RDNA
+            if (shader_core_props_amd.wavefrontsPerSimd == 20) {
+                return vk_device_architecture::AMD_RDNA1;
+            }
+            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
+                return vk_device_architecture::AMD_RDNA3;
+            }
+            return vk_device_architecture::AMD_RDNA2;
+        }
+    }
+    return vk_device_architecture::OTHER;
+}
+
struct vk_device_struct {
    std::mutex mutex;

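
// Condensed view of the detection rules above (illustrative sketch, not part of the patch):
// GCN reports a fixed wave64 subgroup range, RDNA exposes wave32..wave64; RDNA1 is singled out
// by its 20 wavefronts per SIMD, RDNA3 by packed-int8 mixed-signedness dot product support,
// and everything else in that range is treated as RDNA2.
enum class amd_arch { other, gcn, rdna1, rdna2, rdna3 };
static amd_arch classify_amd(unsigned min_subgroup, unsigned max_subgroup,
                             unsigned wavefronts_per_simd, bool int8_dot_mixed) {
    if (min_subgroup == 64 && max_subgroup == 64) return amd_arch::gcn;
    if (min_subgroup == 32 && max_subgroup == 64) {
        if (wavefronts_per_simd == 20) return amd_arch::rdna1;
        if (int8_dot_mixed)            return amd_arch::rdna3;
        return amd_arch::rdna2;
    }
    return amd_arch::other;
}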
@@ -165,6 +226,7 @@ struct vk_device_struct {
    bool pipeline_robustness;
    vk::Device device;
    uint32_t vendor_id;
+    vk_device_architecture architecture;
    vk_queue compute_queue;
    vk_queue transfer_queue;
    bool single_queue;

@@ -246,6 +308,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_group_norm_f32;
    vk_pipeline pipeline_rms_norm_f32;
    vk_pipeline pipeline_rms_norm_back_f32;
+    vk_pipeline pipeline_l2_norm_f32;
    vk_pipeline pipeline_gelu_f32;
    vk_pipeline pipeline_gelu_quick_f32;
    vk_pipeline pipeline_silu_f32;

@@ -270,6 +333,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_timestep_embedding_f32;
    vk_pipeline pipeline_pool2d_f32;
    vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_rwkv_wkv7_f32;
    vk_pipeline pipeline_opt_step_adamw_f32;

    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}

@@ -372,6 +436,7 @@ struct vk_mat_mat_push_constants {
    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
    uint32_t k_split;
    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+    uint32_t padded_N;
};
struct vk_mat_vec_push_constants {
    uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;

@@ -384,6 +449,7 @@ struct vk_mat_mat_id_push_constants {
    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
    uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+    uint32_t padded_N;
};
struct vk_mat_vec_id_push_constants {
    uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;

@@ -569,6 +635,13 @@ struct vk_op_rwkv_wkv6_push_constants {
    uint32_t H;
};

+struct vk_op_rwkv_wkv7_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1449,6 +1522,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
    return supported;
}

+struct GpuPipelineConfig {
+    // GPU architecture identifier.
+    // Example: vk_device_architecture::AMD_GCN
+    vk_device_architecture arch;
+
+    // Mapping of pipeline names to their specific subgroup sizes.
+    // Example: {"soft_max_f32", 64}
+    std::unordered_map<std::string, uint32_t> pipelines;
+
+    // Default subgroup size for this GPU.
+    // Defaults to 0 if not explicitly provided.
+    uint32_t default_subgroup_size = 0;
+};
+
+// Pipeline configuration for RDNA1 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+    {"argmax", 64}, {"mul_mat_vec", 64},
+    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
+};
+
+// Pipeline configuration for RDNA2 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+};
+
+static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
+
+// Define configurations for different GPUs.
+static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
+    {
+        vk_device_architecture::AMD_RDNA1,
+        {
+            rdna1_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+    {
+        vk_device_architecture::AMD_RDNA2,
+        {
+            rdna2_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+};
+
+static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
+    for (const auto &config : gpu_pipeline_configs) {
+        if (config.arch == arch) {
+            auto pipIt = config.pipelines.find(pipeline_name);
+            if (pipIt != config.pipelines.end()) {
+                return pipIt->second;
+            }
+            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
+            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
+                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
+            for (const auto &entry : sorted_pipelines) {
+                if (pipeline_name.find(entry.first) != std::string::npos) {
+                    return entry.second;
+                }
+            }
+            return config.default_subgroup_size;
+        }
+    }
+    return 0; // If no matching configuration is found
+}
+
static void ggml_vk_load_shaders(vk_device& device) {
    VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");

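
// How the lookup above resolves a pipeline name (illustrative, not part of the patch; the
// pipeline names used here are hypothetical examples): an exact match wins, then the longest
// substring match, then the per-architecture default, and 0 (meaning "no override") when the
// architecture has no configuration at all.
//
//   get_subgroup_size("soft_max_f32",        AMD_RDNA1) -> 64  (substring match on "soft_max")
//   get_subgroup_size("mul_mat_vec_f16_f32", AMD_RDNA1) -> 32  ("mul_mat_vec_f16" is the longest match)
//   get_subgroup_size("argsort_f32",         AMD_RDNA2) -> 32  (falls back to RDNA_DEFAULT_SUBGROUP_SIZE)
//   get_subgroup_size("argsort_f32",         AMD_GCN)   -> 0   (no entry for GCN, keep the driver default)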
@@ -1470,36 +1610,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
    uint32_t l_align, m_align, s_align;
    if (device->coopmat2) {
        // spec constants and tile sizes for non-quant matmul/matmul_id
-        l_warptile = { 256, 128, 256, 64 };
-        m_warptile = { 256, 128, 128, 64 };
-        s_warptile = { 128, 64, 64, 64 };
+        l_warptile = { 256, 128, 256, 64, 1 };
+        m_warptile = { 256, 128, 128, 64, 0 };
+        s_warptile = { 128, 64, 64, 64, 0 };
        l_wg_denoms = {128, 256, 1 };
        m_wg_denoms = {128, 128, 1 };
        s_wg_denoms = { 64, 64, 1 };

        // spec constants and tile sizes for quant matmul (non-Qi_K)
-        l_warptile_mmq = { 256, 128, 256, 64 };
-        m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256, 128, 128, 64 };
+        l_warptile_mmq = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq = { 256, 32, 64, 128, 0 };
        l_mmq_wg_denoms = { 128, 256, 1 };
        m_mmq_wg_denoms = { 128, 128, 1 };
-        s_mmq_wg_denoms = { 128, 128, 1 };
+        s_mmq_wg_denoms = { 32, 64, 1 };

        // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 128, 512, 16 };
-        m_warptile_mmq_k = { 256, 128, 256, 16 };
-        s_warptile_mmq_k = { 256, 32, 128, 64 };
-        l_mmq_wg_denoms_k = { 128, 512, 1 };
-        m_mmq_wg_denoms_k = { 128, 256, 1 };
-        s_mmq_wg_denoms_k = { 32, 128, 1 };
+        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
+        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
+        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
+        l_mmq_wg_denoms_k = { 64, 128, 1 };
+        m_mmq_wg_denoms_k = { 32, 64, 1 };
+        s_mmq_wg_denoms_k = { 32, 32, 1 };

        // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16 };
-        m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256, 64, 64, 16 };
-        l_mmqid_wg_denoms = { 128, 128, 1 };
+        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_mmqid_wg_denoms = { 128, 64, 1 };
        m_mmqid_wg_denoms = { 128, 64, 1 };
-        s_mmqid_wg_denoms = { 64, 64, 1 };
+        s_mmqid_wg_denoms = { 128, 64, 1 };

        l_align = 128;
        m_align = 64;

@@ -1575,6 +1715,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
        uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
        uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {

+        if (!require_full_subgroups && required_subgroup_size == 0) {
+            required_subgroup_size = get_subgroup_size(name, device->architecture);
+        }
+
        if (!pipeline) {
            pipeline = std::make_shared<vk_pipeline_struct>();
            pipeline->name = name;

@@ -2132,6 +2276,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
    ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

@@ -2243,6 +2388,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    for (auto &c : compiles) {

@@ -2251,7 +2398,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
    device->need_compiles = false;
}

-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);

static vk_device ggml_vk_get_device(size_t idx) {
    VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");

@@ -2280,6 +2427,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
        device->physical_device = physical_devices[dev_num];
        const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();

+        device->architecture = get_device_architecture(device->physical_device);
+
        const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
        device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;

@@ -2292,7 +2441,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
        bool coopmat2_support = false;
        device->coopmat_support = false;

-        // Check if maintenance4 is supported
        for (const auto& properties : ext_props) {
            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                maintenance4_support = true;

@@ -2380,13 +2528,9 @@ static vk_device ggml_vk_get_device(size_t idx) {

        if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
            device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
-#if defined(_WIN32)
-        } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
+        } else {
            // Limit batching of allocations to 1GB by default to avoid fragmentation issues
            device->suballocation_block_size = 1024*1024*1024;
-#endif
-        } else {
-            device->suballocation_block_size = device->max_memory_allocation_size;
        }
        device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);

@@ -2405,7 +2549,7 @@ static vk_device ggml_vk_get_device(size_t idx) {

        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;

-        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
+        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
            device->coopmat_support = false;
        }

@@ -2787,7 +2931,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
    subgroup_props.pNext = &driver_props;
    physical_device.getProperties2(&props2);

-    const size_t subgroup_size = subgroup_props.subgroupSize;
+    vk_device_architecture arch = get_device_architecture(physical_device);
+    uint32_t default_subgroup_size = get_subgroup_size("", arch);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+
    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;

    bool fp16_storage = false;

@@ -2813,7 +2960,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
        }
    }

-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
+    const vk_device_architecture device_architecture = get_device_architecture(physical_device);
+
+    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
        coopmat_support = false;
    }

@@ -3858,10 +4007,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

    if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
            return aligned ? mmp->a_l : mmp->l;
        }
-        if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
            return aligned ? mmp->a_m : mmp->m;
        }
        return aligned ? mmp->a_s : mmp->s;
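
// The shader-size selection above in one place (illustrative sketch, not part of the patch):
// with the coopmat2 tiles shown earlier (medium wg_denoms[1] == 128, small wg_denoms[1] == 64),
// n == 512 picks the large shader, n == 96 the medium one and n == 48 the small one, provided
// the corresponding mul_mat_l/m/s pipeline exists for the source type.
static const char * pick_tile(unsigned n, unsigned crossover_large, unsigned crossover_medium) {
    if (n > crossover_large)  return "large";
    if (n > crossover_medium) return "medium";
    return "small";
}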
@@ -3886,18 +4039,19 @@ static void ggml_vk_matmul(
        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+        uint32_t padded_n) {
    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
    ggml_vk_sync_buffers(subctx);
    if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
        return;
    }

    GGML_ASSERT(batch_stride_d == m * n);

-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
    // Make sure enough workgroups get assigned for split k to work
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
    ggml_vk_sync_buffers(subctx);

@@ -3906,13 +4060,17 @@ static void ggml_vk_matmul(
}

static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

    if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
            return aligned ? mmp->a_l : mmp->l;
        }
-        if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
            return aligned ? mmp->a_m : mmp->m;
        }
        return aligned ? mmp->a_s : mmp->s;

@@ -3937,14 +4095,15 @@ static void ggml_vk_matmul_id(
        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
+        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
+        uint32_t padded_n) {
    VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
        "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
        "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
        "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
    ggml_vk_sync_buffers(subctx);
    const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
-                                              nei0, nei1, nbi1, ne11 };
+                                              nei0, nei1, nbi1, ne11, padded_n };
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
}

@@ -4106,15 +4265,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
    // Not implemented
    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-
    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
    const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;

    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = padded_n * ne10;
+    const int d_ne = ne11 * ne01;
+
    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);

    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);

@@ -4237,7 +4398,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
            { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
            ne01, ne11, ne10,
            ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-            split_k, ne12*ne13, ne02, ne12, r2, r3
+            split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
        ); // NOLINT
    }

@@ -4688,15 +4849,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
    // Not implemented
    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
    const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;

    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = padded_n * ne10;
+    const uint64_t d_ne = ne21 * ne20;
+
    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;

@@ -4815,7 +4978,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
            { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
            ne01, ne21, ne10, ne10, ne10, ne01,
            stride_batch_x, stride_batch_y, ne20*ne21,
-            n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+            n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
        ); // NOLINT
    }

@@ -5326,6 +5489,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rms_norm_back_f32;
        }
        return nullptr;
+    case GGML_OP_L2_NORM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_l2_norm_f32;
+        }
+        return nullptr;
    case GGML_OP_UNARY:
        switch (ggml_get_unary_op(dst)) {
            case GGML_UNARY_OP_SILU:

@@ -5465,6 +5633,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rwkv_wkv6_f32;
        }
        return nullptr;
+    case GGML_OP_RWKV_WKV7:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv7_f32;
+        }
+        return nullptr;
    case GGML_OP_OPT_STEP_ADAMW:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_opt_step_adamw_f32;

@@ -5712,6 +5885,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
    case GGML_OP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_SUM_ROWS:
@@ -5961,23 +6135,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
    }, dryrun);
}

-static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
-    const ggml_tensor * k = dst->src[0];
-    const ggml_tensor * v = dst->src[1];
-    const ggml_tensor * r = dst->src[2];
-    const ggml_tensor * tf = dst->src[3];
-    const ggml_tensor * td = dst->src[4];
-    const ggml_tensor * state = dst->src[5];
-
-    GGML_ASSERT(!ggml_is_quantized(k->type));
-    GGML_ASSERT(!ggml_is_quantized(v->type));
-    GGML_ASSERT(!ggml_is_quantized(r->type));
-    GGML_ASSERT(!ggml_is_quantized(tf->type));
-    GGML_ASSERT(!ggml_is_quantized(td->type));
-    GGML_ASSERT(!ggml_is_quantized(state->type));
+static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
+    GGML_ASSERT(version == 6 || version == 7);
+    int num_srcs = version == 6 ? 6 : 7;
+
+    for (int i = 0; i < num_srcs; i++) {
+        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+    }
+
    GGML_ASSERT(dst->buffer != nullptr);

-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
    GGML_ASSERT(pipeline != nullptr);

    if (dryrun) {

@@ -5986,89 +6154,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
    }

    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
-    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
-    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
-    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
-    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
-    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+    ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    for (int i = 0; i < num_srcs; i++) {
+        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+    }

    ggml_vk_sync_buffers(subctx);

-    vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
-    size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
-    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+    vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };

    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
-        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
-        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
-        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
-        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        for (int i = 0; i < num_srcs; i++) {
+            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+            srcs_uma[i] = d_srcs[i] != nullptr;
+        }
+
        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
-
-        K_uma = d_K != nullptr;
-        V_uma = d_V != nullptr;
-        R_uma = d_R != nullptr;
-        TF_uma = d_TF != nullptr;
-        TD_uma = d_TD != nullptr;
-        STATE_uma = d_State != nullptr;
-        DST_uma = d_D != nullptr;
+        dst_uma = d_D != nullptr;
    }

-    if (!K_uma) {
-        d_K = k_buf_ctx->dev_buffer;
-        k_offset = vk_tensor_offset(k) + k->view_offs;
-    }
-    if (!V_uma) {
-        d_V = v_buf_ctx->dev_buffer;
-        v_offset = vk_tensor_offset(v) + v->view_offs;
-    }
-    if (!R_uma) {
-        d_R = r_buf_ctx->dev_buffer;
-        r_offset = vk_tensor_offset(r) + r->view_offs;
-    }
-    if (!TF_uma) {
-        d_TF = tf_buf_ctx->dev_buffer;
-        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
-    }
-    if (!TD_uma) {
-        d_TD = td_buf_ctx->dev_buffer;
-        td_offset = vk_tensor_offset(td) + td->view_offs;
-    }
-    if (!STATE_uma) {
-        d_State = state_buf_ctx->dev_buffer;
-        state_offset = vk_tensor_offset(state) + state->view_offs;
-    }
-    if (!DST_uma) {
+    uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    for (int i = 0; i < num_srcs; i++) {
+        src_sizes[i] = ggml_nbytes(dst->src[i]);
+        if (!srcs_uma[i]) {
+            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+        }
+    }
+
+    const uint64_t dst_size = ggml_nbytes(dst);
+    if (!dst_uma) {
        d_D = dst_buf_ctx->dev_buffer;
        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
    }

-    const uint64_t k_size = ggml_nbytes(k);
-    const uint64_t v_size = ggml_nbytes(v);
-    const uint64_t r_size = ggml_nbytes(r);
-    const uint64_t tf_size = ggml_nbytes(tf);
-    const uint64_t td_size = ggml_nbytes(td);
-    const uint64_t state_size = ggml_nbytes(state);
-    const uint64_t dst_size = ggml_nbytes(dst);
-
    std::array<uint32_t, 3> elements = {
        (uint32_t)(pc.B * pc.H),
        1,
        1
    };

-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_K, k_offset, k_size },
-        vk_subbuffer{ d_V, v_offset, v_size },
-        vk_subbuffer{ d_R, r_offset, r_size },
-        vk_subbuffer{ d_TF, tf_offset, tf_size },
-        vk_subbuffer{ d_TD, td_offset, td_size },
-        vk_subbuffer{ d_State, state_offset, state_size },
-        vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    if (version == 6) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    } else if (version == 7) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+    } else {
+        // shouldn't happen
+        GGML_ASSERT(false);
+    }
}

static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
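
// Summary of the refactor above (illustrative, not part of the patch): both WKV variants now
// share one dispatcher, and only the source count and push-constant struct differ.
//
//   version 6: sources k, v, r, tf, td, state  -> 6 buffers + dst, vk_op_rwkv_wkv6_push_constants
//   version 7: one additional source tensor    -> 7 buffers + dst, vk_op_rwkv_wkv7_push_constants
//
// A call site therefore only picks the version, e.g.
//   ggml_vk_op_f32_wkv(ctx, subctx, dst, { B, T, C, H }, 6, dryrun);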
@ -6077,7 +6229,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||||
const size_t n_heads = dst->src[0]->ne[1];
|
const size_t n_heads = dst->src[0]->ne[1];
|
||||||
const size_t n_seqs = dst->src[5]->ne[1];
|
const size_t n_seqs = dst->src[5]->ne[1];
|
||||||
|
|
||||||
ggml_vk_op_f32_rwkv6(
|
ggml_vk_op_f32_wkv(
|
||||||
ctx, subctx, dst,
|
ctx, subctx, dst,
|
||||||
{
|
{
|
||||||
(uint32_t)n_seqs,
|
(uint32_t)n_seqs,
|
||||||
|
@ -6085,6 +6237,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||||
(uint32_t)n_embed,
|
(uint32_t)n_embed,
|
||||||
(uint32_t)n_heads,
|
(uint32_t)n_heads,
|
||||||
},
|
},
|
||||||
|
6,
|
||||||
|
dryrun
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||||
|
const size_t seq_length = dst->src[0]->ne[2];
|
||||||
|
const size_t n_embed = dst->ne[0];
|
||||||
|
const size_t n_heads = dst->src[0]->ne[1];
|
||||||
|
const size_t n_seqs = dst->src[6]->ne[1];
|
||||||
|
|
||||||
|
ggml_vk_op_f32_wkv(
|
||||||
|
ctx, subctx, dst,
|
||||||
|
{
|
||||||
|
(uint32_t)n_seqs,
|
||||||
|
(uint32_t)seq_length,
|
||||||
|
(uint32_t)n_embed,
|
||||||
|
(uint32_t)n_heads,
|
||||||
|
},
|
||||||
|
7,
|
||||||
dryrun
|
dryrun
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6386,6 +6558,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
||||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||||
|
float * op_params = (float *)dst->op_params;
|
||||||
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
||||||
}
|
}
|
||||||
|
@ -6775,7 +6952,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
||||||
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
||||||
m, n, k,
|
m, n, k,
|
||||||
k, k, m, k*m, k*n, m*n,
|
k, k, m, k*m, k*n, m*n,
|
||||||
split_k, batch, batch, batch, 1, 1
|
split_k, batch, batch, batch, 1, 1, n
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
ggml_vk_ctx_end(subctx);
|
ggml_vk_ctx_end(subctx);
|
||||||
|
@ -7120,7 +7297,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
||||||
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
||||||
m, n, k,
|
m, n, k,
|
||||||
k, k, m, k*m, k*n, m*n,
|
k, k, m, k*m, k*n, m*n,
|
||||||
split_k, batch, batch, batch, 1, 1
|
split_k, batch, batch, batch, 1, 1, n
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
ggml_vk_ctx_end(subctx);
|
ggml_vk_ctx_end(subctx);
|
||||||
|
@@ -7381,6 +7558,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:

@@ -7397,6 +7575,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
+    case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:

@@ -7443,6 +7622,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -7560,6 +7740,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_L2_NORM:
+        ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {

@@ -7650,6 +7834,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
 
+    case GGML_OP_RWKV_WKV7:
+        ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
+
+        break;
+
     case GGML_OP_OPT_STEP_ADAMW:
         ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
 
@@ -7723,6 +7912,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:

@@ -7742,6 +7932,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
+    case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
@@ -8253,8 +8444,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
+    uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
     }
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);
@@ -8275,17 +8470,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     bool first_node_in_batch = true; // true if next node will be first node in a batch
     int submit_node_idx = 0; // index to first node in a batch
 
-    // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
-    // Start with a smaller count to get work submitted right away, and increase it after each submit.
-    int nodes_per_submit = 20;
+    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
+    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
+    // (and scaled down based on model size, so smaller models submit earlier).
+    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
+    int nodes_per_submit = 100;
     int submitted_nodes = 0;
     int submit_count = 0;
+    uint64_t mul_mat_bytes = 0;
+    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (first_node_in_batch) {
             submit_node_idx = i;
         }
 
-        bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
+
+        bool submit = (submitted_nodes >= nodes_per_submit) ||
+                      (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
+                      (i == last_node);
 
         bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
 
@@ -8302,13 +8507,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         if (submit) {
             first_node_in_batch = true;
             submitted_nodes = 0;
-            switch (submit_count) {
-            case 0:
-                nodes_per_submit = 50;
-                break;
-            default:
-                nodes_per_submit = 100;
-                break;
+            mul_mat_bytes = 0;
+            if (submit_count < 3) {
+                mul_mat_bytes_per_submit *= 2;
             }
             submit_count++;
         }
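A quick worked example of the new threshold (hypothetical numbers, not part of the change): for a graph whose matmul weights total about 2 GB, mul_mat_bytes_per_submit starts at min(100*1000*1000, 2000000000/40) = 50 MB, then doubles after each of the first three submits (100 MB, 200 MB, 400 MB); the 100-node fallback still forces a submit on graphs with little matmul work.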
@@ -8659,6 +8860,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_SUB:

@@ -8688,6 +8890,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_ADAMW:
             return true;
@@ -8834,7 +9037,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
     UNUSED(instance_extensions);
 }
 
-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
     switch (props.vendorID) {
     case VK_VENDOR_ID_INTEL:
         // Intel drivers don't support coopmat properly yet

@@ -8842,10 +9045,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
     case VK_VENDOR_ID_AMD:
         if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
             // Workaround for AMD proprietary driver reporting support on all GPUs
-            const std::string name = props.deviceName;
-            return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
-                   name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
-                   name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
+            return arch == vk_device_architecture::AMD_RDNA3;
         }
         return true;
     default:
@@ -9075,6 +9275,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
     } else if (tensor->op == GGML_OP_SILU_BACK) {
         tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
+    } else if (tensor->op == GGML_OP_L2_NORM) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);

@@ -9194,6 +9397,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_RWKV_WKV6) {
         tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
                                       src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+    } else if (tensor->op == GGML_OP_RWKV_WKV7) {
+        tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
+                                      src_clone[4], src_clone[5], src_clone[6]);
     } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
         src_clone[0]->flags = src0->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
@@ -178,7 +178,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
 
     uvec4 v = bl128.block.q4k[0];
 
-    const f16vec2 loadd = unpackFloat2x16(v.x);
+    const vec2 loadd = vec2(unpackFloat2x16(v.x));
 
     uint32_t sc;
     uint32_t mbyte;

@@ -199,15 +199,15 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
     sc &= 0x3F;
     mbyte &= 0x3F;
 
-    const float16_t d = loadd.x * float16_t(sc);
-    const float16_t m = loadd.y * float16_t(mbyte);
+    const float d = loadd.x * float(sc);
+    const float m = loadd.y * float(mbyte);
 
     uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
     qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
 
-    float16_t ret = d * float16_t(qs) - m;
+    float ret = d * float(qs) - m;
 
-    return ret;
+    return float16_t(ret);
 }
 
 layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
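Reader note: the value this decode path produces, now accumulated in 32-bit floats and converted back to float16_t only on return, is $\mathrm{ret} = (\mathrm{loadd.x}\cdot sc)\,q - (\mathrm{loadd.y}\cdot mb)$, where loadd holds the super-block scale and min of the Q4_K block.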
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp (new file, 41 lines)
@@ -0,0 +1,41 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+        sum[tid] += xi * xi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum[tid] += sum[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    }
+}
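In terms of the bindings above, each row handled by the new shader is scaled as $y_i = x_i \cdot \operatorname{inversesqrt}\big(\max(\textstyle\sum_j x_j^2,\ \varepsilon)\big)$, where $\varepsilon$ is the eps value that ggml_vk_l2_norm passes through the push constants (p.param1).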
@@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
 layout (constant_id = 2) const uint BN = 64;
 layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
 
+layout (constant_id = 4) const bool enable_smaller_matrices = false;
+const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
+const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
+
 layout (push_constant) uniform parameter
 {
     uint M;

@@ -48,6 +52,8 @@ layout (push_constant) uniform parameter
     uint broadcast2;
     uint broadcast3;
 #endif
+    // N dimension for the B matrix can be >= p.N
+    uint padded_N;
 } p;
 
 
@@ -166,15 +172,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
-    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
-
 #ifdef MUL_MAT_ID
     uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
     uint pos_b = 0;
 #else
     uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
     uint pos_b = batch_idx * p.batch_stride_b;
+    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
     uint stride_a = p.stride_a / QUANT_K;
@@ -195,6 +199,7 @@ void main() {
     tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
 
 #if QUANT_K > 1
     tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);

@@ -202,18 +207,19 @@ void main() {
 #endif
 
     // Use end_k rather than p.K as the dimension because that's what
-    // we need to bound check against when using split_k
+    // we need to bound check against when using split_k.
+    // Bounds check B against padded_N, but bounds check D against N.
     tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
-    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
+    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
     tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
     tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
-    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
+    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
 
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
 #if !defined(MUL_MAT_ID)
     // Detect a fast path where all loads are entirely in bounds and no clamping is required
-    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
+    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
 #if QUANT_K == 1
         (stride_a % 8) == 0 &&
 #endif
@@ -229,16 +235,54 @@ void main() {
         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
 
         uint k_iters = (end_k - start_k + BK - 1) / BK;
-        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
-
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
-
-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
-
-            sum = coopMatMulAdd(mat_a, mat_b, sum);
+        if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
+            return;
+        } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
+            return;
+        } else {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+            return;
         }
     } else
 #endif // !defined(MUL_MAT_ID)
@@ -251,6 +295,9 @@ void main() {
 
         tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
 
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
         [[dont_unroll]]
         for (uint block_k = start_k; block_k < end_k; block_k += BK) {
 
@@ -263,7 +310,7 @@ void main() {
 #ifdef MUL_MAT_ID
             bool unclampedB = true;
 #else
-            bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
+            bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
 #endif
             if (unclampedA && unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
@@ -293,19 +340,16 @@ void main() {
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
             }
         }
-    }
 
     // Convert from ACC_TYPE to D_TYPE
     coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
     mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
 
 #ifdef MUL_MAT_ID
     // Call callback to store each element, remapping row through shared memory
     coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
 #else
-    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
-
-    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
-
-    coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+    coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
 #endif
+    }
 }
@@ -434,6 +434,7 @@ void process_shaders() {
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});

@@ -528,6 +529,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     for (auto &c : compiles) {
ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp (new file, 91 lines)
@@ -0,0 +1,91 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint B;
+    uint T;
+    uint C;
+    uint H;
+};
+
+layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
+layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
+layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
+layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
+
+void main() {
+    const uint head_size = BLOCK_SIZE;
+    const uint batch_id = gl_WorkGroupID.x / H;
+    const uint head_id = gl_WorkGroupID.x % H;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    A_TYPE state[BLOCK_SIZE];
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                            + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        barrier();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        barrier();
+
+        A_TYPE sa = 0.0;
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+            vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            sa += dot(s_vec, a_vec);
+        }
+
+        const A_TYPE v_val = v[t];
+        A_TYPE y = 0.0;
+
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            vec4 kv = k_vec * v_val;
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(r_vec, s_vec);
+
+            state[j] = s_vec.x;
+            state[j+1] = s_vec.y;
+            state[j+2] = s_vec.z;
+            state[j+3] = s_vec.w;
+        }
+
+        dst[t] = y;
+    }
+
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
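For each token and each output channel tid of a head, the shader above applies the wkv7 recurrence $sa = \sum_j S_j a_j$, $\; S_j \leftarrow S_j w_j + k_j v + sa\,b_j$, $\; y = \sum_j r_j S_j$, writing y to dst and appending the final per-sequence state after the T*C outputs.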
@@ -942,6 +942,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM",
     "RMS_NORM_BACK",
     "GROUP_NORM",
+    "L2_NORM",
 
     "MUL_MAT",
     "MUL_MAT_ID",

@@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ADD_REL_POS",
     "RWKV_WKV6",
     "GATED_LINEAR_ATTN",
+    "RWKV_WKV7",
 
     "UNARY",
 
@@ -1009,7 +1011,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -1039,6 +1041,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
     "group_norm(x)",
+    "l2_norm(x)",
 
     "X*Y",
     "X[i]*Y",

@@ -1087,6 +1090,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
     "gated_linear_attn(k, v, q, gate, s)",
+    "rwkv_wkv7(r, w, k, v, a, b, s)",
 
     "unary(x)",
 
@@ -1106,7 +1110,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2699,6 +2703,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
     return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
+// ggml_l2_norm
+
+static struct ggml_tensor * ggml_l2_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_f32(result, 0, eps);
+
+    result->op     = GGML_OP_L2_NORM;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_l2_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_l2_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, true);
+}
+
 // ggml_mul_mat
 
 static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
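A minimal usage sketch of the new public API (not part of this commit; the context and input tensor are assumed to already exist as F32):

    #include "ggml.h"

    // Hypothetical helper: add an L2-normalization node over the rows of x.
    static struct ggml_tensor * add_l2_norm_node(struct ggml_context * ctx, struct ggml_tensor * x) {
        const float eps = 1e-12f;          // floor for the squared norm
        return ggml_l2_norm(ctx, x, eps);  // each row becomes x / sqrt(max(sum(x^2), eps))
    }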
@@ -4733,6 +4768,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
     return result;
 }
 
+// ggml_rwkv_wkv7
+
+struct ggml_tensor * ggml_rwkv_wkv7(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * w,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(w));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
+        GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
+        GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV7;
+    result->src[0] = r;
+    result->src[1] = w;
+    result->src[2] = k;
+    result->src[3] = v;
+    result->src[4] = a;
+    result->src[5] = b;
+    result->src[6] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
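A shape-level sketch of how the new op is meant to be called (illustrative only; the helper name is hypothetical and the tensors are assumed to already exist as F32):

    #include "ggml.h"

    // r, w, k, v, a, b: [S, H, n_tokens]; state: [S*S*H, n_seqs] (flattened per-sequence state).
    // The result is [S*H, n_tokens + S*n_seqs]: the first n_tokens rows hold the wkv output and
    // the remaining S*n_seqs rows carry the updated state, matching the asserts above.
    static struct ggml_tensor * add_wkv7_node(struct ggml_context * ctx,
                                              struct ggml_tensor * r, struct ggml_tensor * w,
                                              struct ggml_tensor * k, struct ggml_tensor * v,
                                              struct ggml_tensor * a, struct ggml_tensor * b,
                                              struct ggml_tensor * state) {
        return ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state);
    }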
@@ -118,22 +118,26 @@ class Keys:
         TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
         HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
         MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
         CLAMP_KQV = "{arch}.attention.clamp_kqv"
         KEY_LENGTH = "{arch}.attention.key_length"
         VALUE_LENGTH = "{arch}.attention.value_length"
         LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
         GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
         GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
         CAUSAL = "{arch}.attention.causal"
         Q_LORA_RANK = "{arch}.attention.q_lora_rank"
         KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
-        REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
-        SLIDING_WINDOW = "{arch}.attention.sliding_window"
-        SCALE = "{arch}.attention.scale"
+        DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
+        ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
+        VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
+        GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
+        REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
+        SLIDING_WINDOW = "{arch}.attention.sliding_window"
+        SCALE = "{arch}.attention.scale"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -257,6 +261,8 @@ class MODEL_ARCH(IntEnum):
     STARCODER2 = auto()
     RWKV6 = auto()
     RWKV6QWEN2 = auto()
+    RWKV7 = auto()
+    ARWKV7 = auto()
     MAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()

@@ -329,8 +335,20 @@ class MODEL_TENSOR(IntEnum):
     SSM_A = auto()
     SSM_D = auto()
     SSM_OUT = auto()
+    TIME_MIX_W0 = auto()
     TIME_MIX_W1 = auto()
     TIME_MIX_W2 = auto()
+    TIME_MIX_A0 = auto()
+    TIME_MIX_A1 = auto()
+    TIME_MIX_A2 = auto()
+    TIME_MIX_V0 = auto()
+    TIME_MIX_V1 = auto()
+    TIME_MIX_V2 = auto()
+    TIME_MIX_G1 = auto()
+    TIME_MIX_G2 = auto()
+    TIME_MIX_K_K = auto()
+    TIME_MIX_K_A = auto()
+    TIME_MIX_R_K = auto()
     TIME_MIX_LERP_X = auto()
     TIME_MIX_LERP_K = auto()
     TIME_MIX_LERP_V = auto()
@@ -445,6 +463,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.RWKV6: "rwkv6",
    MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
+    MODEL_ARCH.RWKV7: "rwkv7",
+    MODEL_ARCH.ARWKV7: "arwkv7",
    MODEL_ARCH.MAMBA: "mamba",
    MODEL_ARCH.XVERSE: "xverse",
    MODEL_ARCH.COMMAND_R: "command-r",

@@ -517,8 +537,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
    MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
+    MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
    MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
    MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
+    MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
+    MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
+    MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
+    MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
+    MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
+    MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
+    MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
+    MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
+    MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
+    MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
+    MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
    MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
    MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
    MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
@@ -1172,6 +1204,68 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.RWKV7: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+        MODEL_TENSOR.TIME_MIX_W0,
+        MODEL_TENSOR.TIME_MIX_W1,
+        MODEL_TENSOR.TIME_MIX_W2,
+        MODEL_TENSOR.TIME_MIX_A0,
+        MODEL_TENSOR.TIME_MIX_A1,
+        MODEL_TENSOR.TIME_MIX_A2,
+        MODEL_TENSOR.TIME_MIX_V0,
+        MODEL_TENSOR.TIME_MIX_V1,
+        MODEL_TENSOR.TIME_MIX_V2,
+        MODEL_TENSOR.TIME_MIX_G1,
+        MODEL_TENSOR.TIME_MIX_G2,
+        MODEL_TENSOR.TIME_MIX_K_K,
+        MODEL_TENSOR.TIME_MIX_K_A,
+        MODEL_TENSOR.TIME_MIX_R_K,
+        MODEL_TENSOR.TIME_MIX_KEY,
+        MODEL_TENSOR.TIME_MIX_VALUE,
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+        MODEL_TENSOR.TIME_MIX_LN,
+        MODEL_TENSOR.TIME_MIX_OUTPUT,
+        MODEL_TENSOR.CHANNEL_MIX_LERP_K,
+        MODEL_TENSOR.CHANNEL_MIX_KEY,
+        MODEL_TENSOR.CHANNEL_MIX_VALUE,
+    ],
+    MODEL_ARCH.ARWKV7: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+        MODEL_TENSOR.TIME_MIX_W0,
+        MODEL_TENSOR.TIME_MIX_W1,
+        MODEL_TENSOR.TIME_MIX_W2,
+        MODEL_TENSOR.TIME_MIX_A0,
+        MODEL_TENSOR.TIME_MIX_A1,
+        MODEL_TENSOR.TIME_MIX_A2,
+        MODEL_TENSOR.TIME_MIX_V0,
+        MODEL_TENSOR.TIME_MIX_V1,
+        MODEL_TENSOR.TIME_MIX_V2,
+        MODEL_TENSOR.TIME_MIX_G1,
+        MODEL_TENSOR.TIME_MIX_G2,
+        MODEL_TENSOR.TIME_MIX_K_K,
+        MODEL_TENSOR.TIME_MIX_K_A,
+        MODEL_TENSOR.TIME_MIX_R_K,
+        MODEL_TENSOR.TIME_MIX_KEY,
+        MODEL_TENSOR.TIME_MIX_VALUE,
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+        MODEL_TENSOR.TIME_MIX_LN,
+        MODEL_TENSOR.TIME_MIX_OUTPUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.MAMBA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -767,6 +767,18 @@ class GGUFWriter:
     def add_kv_lora_rank(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
 
+    def add_decay_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
+
+    def add_iclr_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
+
+    def add_value_residual_mix_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
+
+    def add_gate_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
+
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
 
@@ -27,7 +27,8 @@ class TensorNameMap:
             "embedding.word_embeddings",    # chatglm
             "transformer.token_embeddings", # openelm
             "shared",                       # t5
-            "rwkv.embeddings",              # rwkv
+            "rwkv.embeddings",              # rwkv6
+            "model.embeddings",             # rwkv7
         ),
 
         # Token type embeddings

@@ -42,6 +43,9 @@ class TensorNameMap:
             "emb_ln",                  # nomic-bert
             "transformer.norm",        # openelm
-            "rwkv.blocks.0.pre_ln",    # rwkv
+            "rwkv.blocks.0.pre_ln",    # rwkv6
+            "model.pre_ln",            # rwkv7
+            "model.layers.0.pre_norm", # rwkv7
             "backbone.norm",           # wavtokenizer
         ),

@@ -81,7 +85,8 @@ class TensorNameMap:
             "encoder.final_layernorm",   # chatglm
             "transformer.norm",          # openelm
             "model.norm",                # nemotron
-            "rwkv.ln_out",               # rwkv
+            "rwkv.ln_out",               # rwkv6
+            "model.ln_out",              # rwkv7
             "backbone.final_layer_norm", # wavtokenizer
         ),

@@ -122,14 +127,16 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
             "encoder.layers.{bid}.input_layernorm",           # chatglm
             "transformer.layers.{bid}.attn_norm",             # openelm
-            "rwkv.blocks.{bid}.ln1",                          # rwkv
+            "rwkv.blocks.{bid}.ln1",                          # rwkv6
+            "model.layers.{bid}.ln1",                         # rwkv7
         ),
 
         # Attention norm 2
         MODEL_TENSOR.ATTN_NORM_2: (
             "transformer.h.{bid}.ln_attn",      # falcon40b
             "encoder.layer.{bid}.layer_norm_1", # jina-v2-code
-            "rwkv.blocks.{bid}.ln2",            # rwkv
+            "rwkv.blocks.{bid}.ln2",            # rwkv6
+            "model.layers.{bid}.ln2",           # rwkv7
         ),
 
         # Attention query-key-value
@ -462,112 +469,174 @@ class TensorNameMap:
|
||||||
"backbone.layers.{bid}.mixer.out_proj",
|
"backbone.layers.{bid}.mixer.out_proj",
|
||||||
),
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.TIME_MIX_W0: (
|
||||||
|
"model.layers.{bid}.attention.w0", # rwkv7
|
||||||
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.TIME_MIX_W1: (
|
MODEL_TENSOR.TIME_MIX_W1: (
|
||||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
|
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6
|
||||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||||
|
"model.layers.{bid}.attention.w1", # rwkv7
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.TIME_MIX_W2: (
|
MODEL_TENSOR.TIME_MIX_W2: (
|
||||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
|
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6
|
||||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||||
|
"model.layers.{bid}.attention.w2", # rwkv7
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.TIME_MIX_A0: (
|
||||||
|
"model.layers.{bid}.attention.a0", # rwkv7
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.TIME_MIX_A1: (
|
||||||
|
"model.layers.{bid}.attention.a1", # rwkv7
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.TIME_MIX_A2: (
|
||||||
|
"model.layers.{bid}.attention.a2", # rwkv7
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.TIME_MIX_V0: (
|
||||||
|
"model.layers.{bid}.attention.v0", # rwkv7
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.TIME_MIX_V1: (
|
||||||
|
"model.layers.{bid}.attention.v1", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_V2: (
+            "model.layers.{bid}.attention.v2", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_G1: (
+            "model.layers.{bid}.attention.g1", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_G2: (
+            "model.layers.{bid}.attention.g2", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_K_K: (
+            "model.layers.{bid}.attention.k_k", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_K_A: (
+            "model.layers.{bid}.attention.k_a", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_R_K: (
+            "model.layers.{bid}.attention.r_k", # rwkv7
        ),

        MODEL_TENSOR.TIME_MIX_LERP_X: (
-            "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6
            "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_LERP_K: (
-            "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6
            "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_LERP_V: (
-            "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6
            "model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_LERP_R: (
-            "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6
            "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_LERP_G: (
-            "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6
            "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_LERP_W: (
-            "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6
            "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_FIRST: (
-            "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6
        ),

        MODEL_TENSOR.TIME_MIX_DECAY: (
-            "rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_decay", # rwkv6
            "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_DECAY_W1: (
-            "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6
            "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_DECAY_W2: (
-            "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6
            "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_KEY: (
-            "rwkv.blocks.{bid}.attention.key", # rwkv
+            "rwkv.blocks.{bid}.attention.key", # rwkv6
            "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
+            "model.layers.{bid}.attention.key", # rwkv7
+            "model.layers.{bid}.attention.k_proj", # rwkv7
        ),

        MODEL_TENSOR.TIME_MIX_VALUE: (
-            "rwkv.blocks.{bid}.attention.value", # rwkv
+            "rwkv.blocks.{bid}.attention.value", # rwkv6
            "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
+            "model.layers.{bid}.attention.value", # rwkv7
+            "model.layers.{bid}.attention.v_proj", # rwkv7
        ),

        MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
-            "rwkv.blocks.{bid}.attention.receptance", # rwkv
+            "rwkv.blocks.{bid}.attention.receptance", # rwkv6
            "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
+            "model.layers.{bid}.attention.receptance", # rwkv7
+            "model.layers.{bid}.attention.r_proj", # rwkv7
        ),

        MODEL_TENSOR.TIME_MIX_GATE: (
-            "rwkv.blocks.{bid}.attention.gate", # rwkv
+            "rwkv.blocks.{bid}.attention.gate", # rwkv6
            "model.layers.{bid}.self_attn.gate", # rwkv6qwen2
        ),

        MODEL_TENSOR.TIME_MIX_LN: (
-            "rwkv.blocks.{bid}.attention.ln_x", # rwkv
+            "rwkv.blocks.{bid}.attention.ln_x", # rwkv6
+            "model.layers.{bid}.attention.ln_x" # rwkv7
        ),

        MODEL_TENSOR.TIME_MIX_OUTPUT: (
-            "rwkv.blocks.{bid}.attention.output", # rwkv
+            "rwkv.blocks.{bid}.attention.output", # rwkv6
            "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
+            "model.layers.{bid}.attention.output", # rwkv7
+            "model.layers.{bid}.attention.o_proj", # rwkv7
        ),

        MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
-            "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
+            "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
+            "model.layers.{bid}.feed_forward.x_k", # rwkv7
        ),

        MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
-            "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
+            "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
        ),

        MODEL_TENSOR.CHANNEL_MIX_KEY: (
-            "rwkv.blocks.{bid}.feed_forward.key", # rwkv
+            "rwkv.blocks.{bid}.feed_forward.key", # rwkv6
+            "model.layers.{bid}.feed_forward.key", # rwkv7
        ),

        MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
-            "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
+            "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6
        ),

        MODEL_TENSOR.CHANNEL_MIX_VALUE: (
-            "rwkv.blocks.{bid}.feed_forward.value", # rwkv
+            "rwkv.blocks.{bid}.feed_forward.value", # rwkv6
+            "model.layers.{bid}.feed_forward.value", # rwkv7
        ),

        MODEL_TENSOR.ATTN_Q_A: (
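For orientation: the "{bid}" placeholder in each pattern is expanded with the block index during conversion, so a source checkpoint tensor name can be resolved back to its MODEL_TENSOR key. The sketch below illustrates that resolution step only; the dictionary contents are copied from the mapping above, but the helper name and structure are illustrative, not the actual gguf-py API.

# Minimal sketch (not the gguf-py implementation): expand "{bid}" per block
# and look up which MODEL_TENSOR a source tensor name belongs to.
TENSOR_NAME_PATTERNS = {
    "TIME_MIX_K_K": ("model.layers.{bid}.attention.k_k",),   # rwkv7
    "TIME_MIX_KEY": ("rwkv.blocks.{bid}.attention.key",      # rwkv6
                     "model.layers.{bid}.attention.key"),    # rwkv7
}

def resolve(src_name: str, n_blocks: int) -> str | None:
    """Return the MODEL_TENSOR key matching a source tensor name, if any."""
    for tensor, patterns in TENSOR_NAME_PATTERNS.items():
        for pattern in patterns:
            for bid in range(n_blocks):
                if src_name == pattern.format(bid=bid):
                    return tensor
    return None

print(resolve("model.layers.3.attention.k_k", n_blocks=24))  # -> "TIME_MIX_K_K"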
@@ -154,7 +154,12 @@ class SpecialVocab:
             return True
         with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
-        chat_template = tokenizer_config.get('chat_template')
+        chat_template_alt = None
+        chat_template_file = path / 'chat_template.json'
+        if chat_template_file.is_file():
+            with open(chat_template_file, encoding = 'utf-8') as f:
+                chat_template_alt = json.load(f).get('chat_template')
+        chat_template = tokenizer_config.get('chat_template', chat_template_alt)
         if chat_template is None or isinstance(chat_template, (str, list)):
             self.chat_template = chat_template
         else:
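The effect of this hunk is that a model whose chat template lives only in a separate chat_template.json (as some newer Hugging Face exports do) still gets its template picked up, while tokenizer_config.json keeps precedence when both define one. A small self-contained illustration of that precedence; the file names follow the real HF conventions, everything else is just a demo:

import json, tempfile
from pathlib import Path

# Demo of the fallback precedence introduced above: tokenizer_config.json
# wins, chat_template.json is only consulted when the former has no template.
with tempfile.TemporaryDirectory() as d:
    path = Path(d)
    (path / 'tokenizer_config.json').write_text(json.dumps({}))  # no template here
    (path / 'chat_template.json').write_text(json.dumps({'chat_template': '{{ messages }}'}))

    tokenizer_config = json.loads((path / 'tokenizer_config.json').read_text())
    chat_template_alt = None
    chat_template_file = path / 'chat_template.json'
    if chat_template_file.is_file():
        chat_template_alt = json.loads(chat_template_file.read_text()).get('chat_template')
    chat_template = tokenizer_config.get('chat_template', chat_template_alt)
    print(chat_template)  # -> "{{ messages }}"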
klite.embd (94 changed lines)
@@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
-->

<script>
-const LITEVER = 224;
+const LITEVER = 225;
const urlParams = new URLSearchParams(window.location.search);
var localflag = true;
const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";

@@ -2941,6 +2941,7 @@ Current version indicated by LITEVER below.
const OAI_TTS_ID = 1002;
const KCPP_TTS_ID = 1003;
const HD_RES_PX = 768;
+const VHD_RES_PX = 960;
const NO_HD_RES_PX = 512;
const AVATAR_PX = 384;
const SAVE_SLOTS = 8;

@@ -6271,7 +6272,7 @@ Current version indicated by LITEVER below.
    document.getElementById("enhancedchatinterface").classList.add("transparentbg");
    document.getElementById("enhancedchatinterface_inner").classList.add("transparentbg");
    indexeddb_save("bgimg", compressedImageURI);
-}, true, false, 1024, 0.5);
+}, false, 1024, 0.5);
}
};
function clear_bg_img()

@@ -6564,7 +6565,7 @@ Current version indicated by LITEVER below.
    document.getElementById('portrait_ratio_AI').value = aspectratio.toFixed(2);
    refreshAestheticPreview(true);
    render_gametext();
-}, true, true, AVATAR_PX);
+}, true, AVATAR_PX);
}
else {
//attempt to read as WEBP

@@ -7838,7 +7839,7 @@ Current version indicated by LITEVER below.
    temp_scenario.image = compressedImageURI;
    temp_scenario.image_aspect = aspectratio;
    preview_temp_scenario();
-}, true, true, AVATAR_PX);
+}, true, AVATAR_PX);
})
.catch(error => {
console.error(error);

@@ -7907,7 +7908,7 @@ Current version indicated by LITEVER below.
    temp_scenario.image = compressedImageURI;
    temp_scenario.image_aspect = aspectratio;
    preview_temp_scenario();
-}, true, true, AVATAR_PX);
+}, true, AVATAR_PX);
}
}else{
throw new Error("Selected scenario is invalid.");

@@ -7986,7 +7987,7 @@ Current version indicated by LITEVER below.
    temp_scenario.image = compressedImageURI;
    temp_scenario.image_aspect = aspectratio;
    preview_temp_scenario();
-}, true, true, AVATAR_PX);
+}, true, AVATAR_PX);
})
.catch(error => {
throw new Error("Selected scenario is invalid.");

@@ -10045,6 +10046,7 @@ Current version indicated by LITEVER below.
}
}

+var adminrebootflag = 0;
function trigger_admin_reload()
{
document.getElementById("admincontainer").classList.add("hidden");

@@ -10070,9 +10072,32 @@ Current version indicated by LITEVER below.
.then(values => {
let success = (values && values.success);
if (success) {
-msgbox("KoboldCpp is now restarting!\n\nIt may take some time before the new instance is ready to use. Please wait a moment, then press OK to refresh the page.", "KoboldCpp Reload Started", false,false,()=>{
+msgbox("KoboldCpp is now restarting!\n\nIt may take some time before the new instance is ready to use.\n\nYour browser should automatically refresh after a few moments...", "KoboldCpp Reload Started", false,true);
-    location.reload(true);
-});
+setInterval(function () {
+    ++adminrebootflag;
+    if(adminrebootflag>1 && adminrebootflag<20)
+    {
+        fetch(apply_proxy_url(custom_kobold_endpoint + kobold_custom_version_endpoint),
+        {
+            method: 'GET',
+            headers: get_kobold_header(),
+        })
+        .then(response => response.json())
+        .then((data) => {
+            adminrebootflag = 999;
+            location.reload(true);
+        })
+        .catch((error) => {
+            if(adminrebootflag>2)
+            {
+                adminrebootflag -= 1;
+            }
+            console.error('Not Ready to Restart:', error);
+        });
+    }
+}, 3000);

} else {
msgbox("The request to reload KoboldCpp with a new configuration failed!\n\nPlease check if the feature is enabled, the admin directory is set, and selected config and password are correct.", "KoboldCpp Reload Failed");
}
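The new restart flow replaces the "press OK to refresh" prompt with a timer that polls the version endpoint every 3 seconds and reloads the page once the server answers again. The same poll-until-ready idea, sketched in Python purely for illustration; the endpoint URL, interval and timeout below are placeholder values, not KoboldCpp's actual constants:

import time
import urllib.request

def wait_until_ready(url: str, interval: float = 3.0, timeout: float = 60.0) -> bool:
    """Poll `url` until it responds or the deadline passes."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                if resp.status == 200:
                    return True   # server is back up; the UI would reload here
        except OSError:
            pass                  # still restarting; try again shortly
        time.sleep(interval)
    return False

# Example usage (hypothetical endpoint):
# wait_until_ready("http://localhost:5001/api/v1/model")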
@@ -12764,7 +12789,7 @@ Current version indicated by LITEVER below.
image_db[imgid] = { done: false, queue: "Generating", result: "", prompt:"", poll_category:0 };
image_db[imgid].aspect = 0;
image_db[imgid].imsource = 1; //0=generated,1=uploaded
-let imgres = localsettings.img_allowhd?HD_RES_PX:NO_HD_RES_PX;
+let imgres = localsettings.img_allowhd?VHD_RES_PX:NO_HD_RES_PX;
compressImage(origImg, (newDataUri, outAspect) => {
    image_db[imgid].done = true;
    image_db[imgid].result = newDataUri;

@@ -12784,7 +12809,7 @@ Current version indicated by LITEVER below.
    {
        image_db[imgid].aspect = 2; //landscape
    }
-}, true, false, imgres,0.35,true);
+}, false, imgres,0.35,true);
}

function clear_paste_window()

@@ -15375,7 +15400,7 @@ Current version indicated by LITEVER below.
compressImage(origImg, (newDataUri) => {
    image_db[imgid].done = true;
    image_db[imgid].result = newDataUri;
-}, true, false, imgres,0.35,false);
+}, false, imgres);
}else{
image_db[imgid].queue = "Failed";
msgbox("Image Generation Failed!\n\nPlease make sure KoboldCpp / Forge / A1111 is running and properly configured!\nIn your local install of Automatic1111 WebUi, modify webui-user.bat and add these flags to enable API access:\n\nset COMMANDLINE_ARGS= --api --listen --cors-allow-origins=*\n");

@@ -15415,7 +15440,7 @@ Current version indicated by LITEVER below.
compressImage(origImg, (newDataUri) => {
    image_db[imgid].done = true;
    image_db[imgid].result = newDataUri;
-}, true, true, imgres,0.35,false);
+}, true, imgres);
}else{
image_db[imgid].queue = "Failed";
msgbox("Image Generation Failed!\n\nPlease make sure your OpenAI key is set correctly and you are allowed to use DALL-E.\n");

@@ -16131,7 +16156,7 @@ Current version indicated by LITEVER below.
let imgres = localsettings.img_allowhd?(localsettings.img_aspect==0?NO_HD_RES_PX:HD_RES_PX):NO_HD_RES_PX;
compressImage(origImg, (newDataUri) => {
    img.result = newDataUri;
-}, true, false, imgres,0.35,false);
+}, false, imgres);
}
})
.catch((error) => {

@@ -16181,7 +16206,7 @@ Current version indicated by LITEVER below.
let imgres = localsettings.img_allowhd?(localsettings.img_aspect==0?NO_HD_RES_PX:HD_RES_PX):NO_HD_RES_PX;
compressImage(origImg, (newDataUri) => {
    img.result = newDataUri;
-}, true, false, imgres,0.35,false);
+}, false, imgres);
};
reader.readAsDataURL(finalimg);
})

@@ -16235,10 +16260,11 @@ Current version indicated by LITEVER below.
}
}

-function compressImage(inputDataUri, onDone, isJpeg=true, fixedSize=true, maxSize=NO_HD_RES_PX, quality = 0.35, forceAspect=false) {
+function compressImage(inputDataUri, onDone, fixedSize=true, maxSize=NO_HD_RES_PX, quality = 0.35, letterboxAspect=false) {
let img = document.createElement('img');
let wantedWidth = maxSize;
let wantedHeight = maxSize;
+const isJpeg = true;

// When the event "onload" is triggered we can resize the image.
img.onload = function () {

@@ -16269,44 +16295,54 @@ Current version indicated by LITEVER below.
canvas.height = wantedHeight;

// We resize the image with the canvas method
-if(forceAspect)
+if(letterboxAspect)
{
    let minsizeW = Math.min(origW, origH);
    let minsizeH = Math.min(origW, origH);
+    let targetMaxSize = maxSize;
+    //a bit of a hack, but if the input image is much smaller than the target canvas, we can use a smaller canvas
+    if(targetMaxSize>=VHD_RES_PX && origW<=HD_RES_PX && origH<=HD_RES_PX)
+    {
+        targetMaxSize = HD_RES_PX;
+    }
+    if(targetMaxSize>=HD_RES_PX && origW<=NO_HD_RES_PX && origH<=NO_HD_RES_PX)
+    {
+        targetMaxSize = NO_HD_RES_PX;
+    }

    if(aspectratio<=0.5)
    {
        //portrait
        minsizeH *= 2;
-        canvas.width = wantedWidth = maxSize/2;
+        canvas.width = wantedWidth = targetMaxSize/2;
-        canvas.height = wantedHeight = maxSize;
+        canvas.height = wantedHeight = targetMaxSize;
    }
    else if(aspectratio<0.7)
    {
        //portrait
        minsizeH *= 1.5;
-        canvas.width = wantedWidth = maxSize/1.5;
+        canvas.width = wantedWidth = targetMaxSize/1.5;
-        canvas.height = wantedHeight = maxSize;
+        canvas.height = wantedHeight = targetMaxSize;
    }
    else if(aspectratio>=2)
    {
        //landscape
        minsizeW *= 2;
-        canvas.width = wantedWidth = maxSize;
+        canvas.width = wantedWidth = targetMaxSize;
-        canvas.height = wantedHeight = maxSize/2;
+        canvas.height = wantedHeight = targetMaxSize/2;
    }
    else if(aspectratio>1.4)
    {
        //landscape
        minsizeW *= 1.5;
-        canvas.width = wantedWidth = maxSize;
+        canvas.width = wantedWidth = targetMaxSize;
-        canvas.height = wantedHeight = maxSize/1.5;
+        canvas.height = wantedHeight = targetMaxSize/1.5;
    }
    else
    {
        //square
-        canvas.width = wantedWidth = maxSize;
+        canvas.width = wantedWidth = targetMaxSize;
-        canvas.height = wantedHeight = maxSize;
+        canvas.height = wantedHeight = targetMaxSize;
    }

    let newWidth, newHeight, mx, my;
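The aspect handling above buckets the source image into square, 1.5:1, or 2:1 canvases, and the new targetMaxSize clamp avoids blowing a small source up to the 960px VHD canvas. A compact Python rendering of the same bucketing rule; the constants mirror the ones defined earlier in the file, while the function name and return convention are just for this illustration:

VHD_RES_PX, HD_RES_PX, NO_HD_RES_PX = 960, 768, 512

def letterbox_canvas(orig_w: int, orig_h: int, max_size: int = VHD_RES_PX) -> tuple[int, int]:
    """Pick the output canvas size the way the letterboxAspect branch does."""
    target = max_size
    # do not upscale a small source onto a huge canvas
    if target >= VHD_RES_PX and orig_w <= HD_RES_PX and orig_h <= HD_RES_PX:
        target = HD_RES_PX
    if target >= HD_RES_PX and orig_w <= NO_HD_RES_PX and orig_h <= NO_HD_RES_PX:
        target = NO_HD_RES_PX

    aspect = orig_w / orig_h
    if aspect <= 0.5:          # tall portrait -> 1:2 canvas
        return int(target / 2), target
    if aspect < 0.7:           # portrait -> 1:1.5 canvas
        return int(target / 1.5), target
    if aspect >= 2:            # wide landscape -> 2:1 canvas
        return target, int(target / 2)
    if aspect > 1.4:           # landscape -> 1.5:1 canvas
        return target, int(target / 1.5)
    return target, target      # roughly square

print(letterbox_canvas(640, 480))   # -> (768, 768): clamped to the HD canvas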
@@ -19639,7 +19675,7 @@ Current version indicated by LITEVER below.
const file = event.target.files[0];
const reader = new FileReader();
reader.onload = function(img) {
-    compressImage(img.target.result, loadCompressedImage, true, true, AVATAR_PX);
+    compressImage(img.target.result, loadCompressedImage, true, AVATAR_PX);
    function loadCompressedImage(compressedImageURI, aspectratio) {

        if(isSelfPortrait)

@@ -22412,6 +22448,8 @@ Current version indicated by LITEVER below.
<option value="command-r-plus">command-r-plus</option>
<option value="command-r-08-2024">command-r-08-2024</option>
<option value="command-r-plus-08-2024">command-r-plus-08-2024</option>
+<option value="command-r7b-12-2024">command-r7b-12-2024</option>
+<option value="command-a-03-2025">command-a-03-2025</option>
</select>
<span class="color_green" style="font-weight: bold;">Please input Cohere API Key.</span><br><br>
<input class="form-control" type="password" id="custom_cohere_key" placeholder="Cohere API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br>
@@ -49,7 +49,7 @@ logit_bias_max = 512
dry_seq_break_max = 128

# global vars
-KcppVersion = "1.86.2"
+KcppVersion = "1.87"
showdebug = True
kcpp_instance = None #global running instance
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False}
@@ -59,6 +59,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_EXAONE,           "exaone"       },
    { LLM_ARCH_RWKV6,            "rwkv6"        },
    { LLM_ARCH_RWKV6QWEN2,       "rwkv6qwen2"   },
+   { LLM_ARCH_RWKV7,            "rwkv7"        },
+   { LLM_ARCH_ARWKV7,           "arwkv7"       },
    { LLM_ARCH_GRANITE,          "granite"      },
    { LLM_ARCH_GRANITE_MOE,      "granitemoe"   },
    { LLM_ARCH_CHAMELEON,        "chameleon"    },

@@ -110,22 +112,26 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EMBEDDING_SCALE,    "%s.embedding_scale"   },
    { LLM_KV_TOKEN_SHIFT_COUNT,  "%s.token_shift_count" },

    { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,         "%s.attention.max_alibi_bias"         },
    { LLM_KV_ATTENTION_CLAMP_KQV,              "%s.attention.clamp_kqv"              },
    { LLM_KV_ATTENTION_KEY_LENGTH,             "%s.attention.key_length"             },
    { LLM_KV_ATTENTION_VALUE_LENGTH,           "%s.attention.value_length"           },
    { LLM_KV_ATTENTION_LAYERNORM_EPS,          "%s.attention.layer_norm_epsilon"     },
    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      "%s.attention.layer_norm_rms_epsilon" },
    { LLM_KV_ATTENTION_GROUPNORM_EPS,          "%s.attention.group_norm_epsilon"     },
    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,       "%s.attention.group_norm_groups"      },
    { LLM_KV_ATTENTION_CAUSAL,                 "%s.attention.causal"                 },
    { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
    { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
-   { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+   { LLM_KV_ATTENTION_DECAY_LORA_RANK,        "%s.attention.decay_lora_rank"        },
-   { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
+   { LLM_KV_ATTENTION_ICLR_LORA_RANK,         "%s.attention.iclr_lora_rank"         },
-   { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
+   { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
+   { LLM_KV_ATTENTION_GATE_LORA_RANK,         "%s.attention.gate_lora_rank"         },
+   { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+   { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
+   { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },

    { LLM_KV_ROPE_DIMENSION_COUNT,    "%s.rope.dimension_count"    },
    { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },

@@ -1238,6 +1244,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
+   {
+       LLM_ARCH_RWKV7,
+       {
+           { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+           { LLM_TENSOR_TOKEN_EMBD_NORM,      "token_embd_norm" },
+           { LLM_TENSOR_OUTPUT_NORM,          "output_norm" },
+           { LLM_TENSOR_OUTPUT,               "output" },
+           { LLM_TENSOR_ATTN_NORM,            "blk.%d.attn_norm" },
+           { LLM_TENSOR_ATTN_NORM_2,          "blk.%d.attn_norm_2" },
+           { LLM_TENSOR_TIME_MIX_W0,          "blk.%d.time_mix_w0" },
+           { LLM_TENSOR_TIME_MIX_W1,          "blk.%d.time_mix_w1" },
+           { LLM_TENSOR_TIME_MIX_W2,          "blk.%d.time_mix_w2" },
+           { LLM_TENSOR_TIME_MIX_A0,          "blk.%d.time_mix_a0" },
+           { LLM_TENSOR_TIME_MIX_A1,          "blk.%d.time_mix_a1" },
+           { LLM_TENSOR_TIME_MIX_A2,          "blk.%d.time_mix_a2" },
+           { LLM_TENSOR_TIME_MIX_V0,          "blk.%d.time_mix_v0" },
+           { LLM_TENSOR_TIME_MIX_V1,          "blk.%d.time_mix_v1" },
+           { LLM_TENSOR_TIME_MIX_V2,          "blk.%d.time_mix_v2" },
+           { LLM_TENSOR_TIME_MIX_G1,          "blk.%d.time_mix_g1" },
+           { LLM_TENSOR_TIME_MIX_G2,          "blk.%d.time_mix_g2" },
+           { LLM_TENSOR_TIME_MIX_K_K,         "blk.%d.time_mix_k_k" },
+           { LLM_TENSOR_TIME_MIX_K_A,         "blk.%d.time_mix_k_a" },
+           { LLM_TENSOR_TIME_MIX_R_K,         "blk.%d.time_mix_r_k" },
+           { LLM_TENSOR_TIME_MIX_LERP_FUSED,  "blk.%d.time_mix_lerp_fused" },
+           { LLM_TENSOR_TIME_MIX_KEY,         "blk.%d.time_mix_key" },
+           { LLM_TENSOR_TIME_MIX_VALUE,       "blk.%d.time_mix_value" },
+           { LLM_TENSOR_TIME_MIX_RECEPTANCE,  "blk.%d.time_mix_receptance" },
+           { LLM_TENSOR_TIME_MIX_LN,          "blk.%d.time_mix_ln" },
+           { LLM_TENSOR_TIME_MIX_OUTPUT,      "blk.%d.time_mix_output" },
+           { LLM_TENSOR_CHANNEL_MIX_LERP_K,   "blk.%d.channel_mix_lerp_k" },
+           { LLM_TENSOR_CHANNEL_MIX_KEY,      "blk.%d.channel_mix_key" },
+           { LLM_TENSOR_CHANNEL_MIX_VALUE,    "blk.%d.channel_mix_value" },
+       },
+   },
+   {
+       LLM_ARCH_ARWKV7,
+       {
+           { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+           { LLM_TENSOR_TOKEN_EMBD_NORM,      "token_embd_norm" },
+           { LLM_TENSOR_OUTPUT_NORM,          "output_norm" },
+           { LLM_TENSOR_OUTPUT,               "output" },
+           { LLM_TENSOR_ATTN_NORM,            "blk.%d.attn_norm" },
+           { LLM_TENSOR_TIME_MIX_W0,          "blk.%d.time_mix_w0" },
+           { LLM_TENSOR_TIME_MIX_W1,          "blk.%d.time_mix_w1" },
+           { LLM_TENSOR_TIME_MIX_W2,          "blk.%d.time_mix_w2" },
+           { LLM_TENSOR_TIME_MIX_A0,          "blk.%d.time_mix_a0" },
+           { LLM_TENSOR_TIME_MIX_A1,          "blk.%d.time_mix_a1" },
+           { LLM_TENSOR_TIME_MIX_A2,          "blk.%d.time_mix_a2" },
+           { LLM_TENSOR_TIME_MIX_V0,          "blk.%d.time_mix_v0" },
+           { LLM_TENSOR_TIME_MIX_V1,          "blk.%d.time_mix_v1" },
+           { LLM_TENSOR_TIME_MIX_V2,          "blk.%d.time_mix_v2" },
+           { LLM_TENSOR_TIME_MIX_G1,          "blk.%d.time_mix_g1" },
+           { LLM_TENSOR_TIME_MIX_G2,          "blk.%d.time_mix_g2" },
+           { LLM_TENSOR_TIME_MIX_K_K,         "blk.%d.time_mix_k_k" },
+           { LLM_TENSOR_TIME_MIX_K_A,         "blk.%d.time_mix_k_a" },
+           { LLM_TENSOR_TIME_MIX_R_K,         "blk.%d.time_mix_r_k" },
+           { LLM_TENSOR_TIME_MIX_LERP_FUSED,  "blk.%d.time_mix_lerp_fused" },
+           { LLM_TENSOR_TIME_MIX_KEY,         "blk.%d.time_mix_key" },
+           { LLM_TENSOR_TIME_MIX_VALUE,       "blk.%d.time_mix_value" },
+           { LLM_TENSOR_TIME_MIX_RECEPTANCE,  "blk.%d.time_mix_receptance" },
+           { LLM_TENSOR_TIME_MIX_LN,          "blk.%d.time_mix_ln" },
+           { LLM_TENSOR_TIME_MIX_OUTPUT,      "blk.%d.time_mix_output" },
+           { LLM_TENSOR_FFN_NORM,             "blk.%d.ffn_norm" },
+           { LLM_TENSOR_FFN_GATE,             "blk.%d.ffn_gate" },
+           { LLM_TENSOR_FFN_DOWN,             "blk.%d.ffn_down" },
+           { LLM_TENSOR_FFN_UP,               "blk.%d.ffn_up" },
+       },
+   },
    {
        LLM_ARCH_GRANITE,
        {
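Each entry in these tables maps an abstract tensor id to a printf-style name pattern, and the %d is filled with the layer index at load time to produce the concrete GGUF tensor name. A tiny illustration of that expansion, written in Python for demonstration only; the table contents below are copied from the rwkv7 block above, while the helper itself is not part of llama.cpp:

# Expand the printf-style patterns above into per-layer GGUF tensor names.
RWKV7_TENSOR_NAMES = {
    "TIME_MIX_W0":     "blk.%d.time_mix_w0",
    "TIME_MIX_K_K":    "blk.%d.time_mix_k_k",
    "CHANNEL_MIX_KEY": "blk.%d.channel_mix_key",
}

def layer_tensor_names(layer: int) -> dict[str, str]:
    return {tensor: pattern % layer for tensor, pattern in RWKV7_TENSOR_NAMES.items()}

print(layer_tensor_names(0)["TIME_MIX_K_K"])  # -> "blk.0.time_mix_k_k"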
@@ -1397,6 +1471,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_SSM_OUT,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_W1,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_W2,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+   {LLM_TENSOR_TIME_MIX_A1,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+   {LLM_TENSOR_TIME_MIX_A2,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+   {LLM_TENSOR_TIME_MIX_V1,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+   {LLM_TENSOR_TIME_MIX_V2,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+   {LLM_TENSOR_TIME_MIX_G1,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+   {LLM_TENSOR_TIME_MIX_G2,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_DECAY_W1,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_DECAY_W2,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_KEY,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

@@ -1415,6 +1495,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_TIME_MIX_LN,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_TIME_MIX_K_K,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_TIME_MIX_K_A,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_TIME_MIX_R_K,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_TIME_MIX_LERP_W,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_LERP_K,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_LERP_V,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},

@@ -1422,6 +1505,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_TIME_MIX_LERP_G,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_DECAY,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+   {LLM_TENSOR_TIME_MIX_W0,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+   {LLM_TENSOR_TIME_MIX_A0,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+   {LLM_TENSOR_TIME_MIX_V0,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_FIRST,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
    {LLM_TENSOR_ATTN_NORM,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_NORM_2,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -63,6 +63,8 @@ enum llm_arch {
    LLM_ARCH_EXAONE,
    LLM_ARCH_RWKV6,
    LLM_ARCH_RWKV6QWEN2,
+   LLM_ARCH_RWKV7,
+   LLM_ARCH_ARWKV7,
    LLM_ARCH_GRANITE,
    LLM_ARCH_GRANITE_MOE,
    LLM_ARCH_CHAMELEON,

@@ -127,6 +129,10 @@ enum llm_kv {
    LLM_KV_ATTENTION_CAUSAL,
    LLM_KV_ATTENTION_Q_LORA_RANK,
    LLM_KV_ATTENTION_KV_LORA_RANK,
+   LLM_KV_ATTENTION_DECAY_LORA_RANK,
+   LLM_KV_ATTENTION_ICLR_LORA_RANK,
+   LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
+   LLM_KV_ATTENTION_GATE_LORA_RANK,
    LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
    LLM_KV_ATTENTION_SLIDING_WINDOW,
    LLM_KV_ATTENTION_SCALE,

@@ -250,8 +256,20 @@ enum llm_tensor {
    LLM_TENSOR_SSM_A,
    LLM_TENSOR_SSM_D,
    LLM_TENSOR_SSM_OUT,
+   LLM_TENSOR_TIME_MIX_W0,
    LLM_TENSOR_TIME_MIX_W1,
    LLM_TENSOR_TIME_MIX_W2,
+   LLM_TENSOR_TIME_MIX_A0,
+   LLM_TENSOR_TIME_MIX_A1,
+   LLM_TENSOR_TIME_MIX_A2,
+   LLM_TENSOR_TIME_MIX_V0,
+   LLM_TENSOR_TIME_MIX_V1,
+   LLM_TENSOR_TIME_MIX_V2,
+   LLM_TENSOR_TIME_MIX_G1,
+   LLM_TENSOR_TIME_MIX_G2,
+   LLM_TENSOR_TIME_MIX_K_K,
+   LLM_TENSOR_TIME_MIX_K_A,
+   LLM_TENSOR_TIME_MIX_R_K,
    LLM_TENSOR_TIME_MIX_LERP_X,
    LLM_TENSOR_TIME_MIX_LERP_W,
    LLM_TENSOR_TIME_MIX_LERP_K,
@@ -285,11 +285,15 @@ llama_context::llama_context(

    // reserve worst-case graph
    if (!hparams.vocab_only) {
-       uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+       const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-       uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+       const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph

+       // restore later
+       // TODO: something cleaner
+       const auto n_outputs_save = n_outputs;

        // max number of outputs
        n_outputs = n_tokens;

@@ -341,6 +345,8 @@ llama_context::llama_context(
        }
    }

+   n_outputs = n_outputs_save;

    for (size_t i = 0; i < backend_ptrs.size(); ++i) {
        ggml_backend_t backend = backend_ptrs[i];
        ggml_backend_buffer_type_t buft = backend_buft[i];

@@ -1052,6 +1058,13 @@ int llama_context::encode(llama_batch & inp_batch) {
    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

+   const auto causal_attn_org = cparams.causal_attn;
+
+   // always use non-causal attention for encoder graphs
+   // TODO: this is a tmp solution until we have a proper way to support enc-dec models
+   // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
+   cparams.causal_attn = false;
+
    auto * gf = graph_init();
    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);

@@ -1059,6 +1072,8 @@ int llama_context::encode(llama_batch & inp_batch) {

    res->set_inputs(&ubatch);

+   cparams.causal_attn = causal_attn_org;
+
    const auto compute_status = graph_compute(gf, n_tokens > 1);
    switch (compute_status) {
        case GGML_STATUS_SUCCESS:

@@ -1129,6 +1144,8 @@ int llama_context::encode(llama_batch & inp_batch) {
    if (model.arch == LLM_ARCH_T5 && t_embd) {
        //cross.t_embd = t_embd;

+       synchronize();

        cross.n_embd = t_embd->ne[0];
        cross.n_enc = t_embd->ne[1];
        cross.v_embd.resize(cross.n_embd*cross.n_enc);
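The encoder path now saves cparams.causal_attn, forces it off while the encoder graph is built and its inputs are set, then restores the caller's value before computing, so later decode calls are unaffected. The same save/override/restore discipline, sketched in Python purely as an illustration of the pattern (all names below are invented for the example):

from contextlib import contextmanager

@contextmanager
def override_attr(obj, name, value):
    """Temporarily override obj.<name>, restoring the original on exit."""
    saved = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, saved)

class CParams:
    causal_attn = True

cparams = CParams()
with override_attr(cparams, "causal_attn", False):
    pass  # build the encoder graph with non-causal attention here
assert cparams.causal_attn is True  # restored for later decode calls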
@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
    // note: storing RoPE-ed version of K in the KV cache
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

-   assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+   v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);

    ggml_tensor * v_cache_view = nullptr;

@@ -487,9 +487,9 @@ struct llm_graph_context {

    ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
-            ggml_tensor * q,
+            ggml_tensor * q,       // [n_embd_head_q, n_tokens, n_head_q]
-            ggml_tensor * k,
+            ggml_tensor * k,       // [n_embd_head_k, n_tokens, n_head_k]
-            ggml_tensor * v,
+            ggml_tensor * v,       // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
                    bool   v_trans,

@@ -502,9 +502,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
+            ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur,
+            ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur,
+            ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;

@@ -516,9 +516,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
+            ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur,
+            ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur,
+            ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;

@@ -530,9 +530,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
+            ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur,
+            ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur,
+            ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;
@@ -76,6 +76,10 @@ struct llama_hparams {
    uint32_t time_decay_extra_dim = 0;
    uint32_t wkv_head_size        = 0;
    uint32_t token_shift_count    = 2;
+   uint32_t n_lora_decay         = 0;
+   uint32_t n_lora_iclr          = 0;
+   uint32_t n_lora_value_res_mix = 0;
+   uint32_t n_lora_gate          = 0;

    float rope_attn_factor = 1.0f;
    float rope_freq_base_train;
src/llama-model.cpp (1296 changed lines): file diff suppressed because it is too large.
@@ -29,6 +29,7 @@ enum llm_type {
    LLM_TYPE_109M,
    LLM_TYPE_137M,
    LLM_TYPE_160M,
+   LLM_TYPE_190M,
    LLM_TYPE_220M,
    LLM_TYPE_250M,
    LLM_TYPE_270M,

@@ -45,6 +46,7 @@ enum llm_type {
    LLM_TYPE_1_6B,
    LLM_TYPE_2B,
    LLM_TYPE_2_8B,
+   LLM_TYPE_2_9B,
    LLM_TYPE_3B,
    LLM_TYPE_4B,
    LLM_TYPE_6B,

@@ -260,6 +262,20 @@ struct llama_layer {
    struct ggml_tensor * time_mix_receptance_b = nullptr;
    struct ggml_tensor * time_mix_gate         = nullptr;

+   // rwkv7
+   struct ggml_tensor * time_mix_w0  = nullptr;
+   struct ggml_tensor * time_mix_a0  = nullptr;
+   struct ggml_tensor * time_mix_a1  = nullptr;
+   struct ggml_tensor * time_mix_a2  = nullptr;
+   struct ggml_tensor * time_mix_v0  = nullptr;
+   struct ggml_tensor * time_mix_v1  = nullptr;
+   struct ggml_tensor * time_mix_v2  = nullptr;
+   struct ggml_tensor * time_mix_g1  = nullptr;
+   struct ggml_tensor * time_mix_g2  = nullptr;
+   struct ggml_tensor * time_mix_k_k = nullptr;
+   struct ggml_tensor * time_mix_k_a = nullptr;
+   struct ggml_tensor * time_mix_r_k = nullptr;
+
    struct ggml_tensor * time_mix_ln     = nullptr;
    struct ggml_tensor * time_mix_ln_b   = nullptr;
    struct ggml_tensor * time_mix_output = nullptr;
@@ -759,10 +759,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
    // NOTE: can't use LLM_TN here because the layer number is not known
    quantize &= name.find("ssm_conv1d.weight") == std::string::npos;

-   // do not quantize RWKV's time_mix_first tensors
+   // do not quantize RWKV's small yet 2D weights
    quantize &= name.find("time_mix_first.weight") == std::string::npos;
+   quantize &= name.find("time_mix_w0.weight") == std::string::npos;
    quantize &= name.find("time_mix_w1.weight") == std::string::npos;
    quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+   quantize &= name.find("time_mix_v0.weight") == std::string::npos;
+   quantize &= name.find("time_mix_v1.weight") == std::string::npos;
+   quantize &= name.find("time_mix_v2.weight") == std::string::npos;
+   quantize &= name.find("time_mix_a0.weight") == std::string::npos;
+   quantize &= name.find("time_mix_a1.weight") == std::string::npos;
+   quantize &= name.find("time_mix_a2.weight") == std::string::npos;
+   quantize &= name.find("time_mix_g1.weight") == std::string::npos;
+   quantize &= name.find("time_mix_g2.weight") == std::string::npos;
    quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
    quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
    quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
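These checks form a simple denylist: any tensor whose name contains one of the listed substrings keeps its original precision instead of being quantized. The same filter expressed as a small Python helper; the substrings come from the diff above, while the function itself is only an illustration, not part of llama.cpp:

# Name-substring denylist in the spirit of the quantization filter above.
SKIP_SUBSTRINGS = (
    "time_mix_first.weight",
    "time_mix_w0.weight", "time_mix_w1.weight", "time_mix_w2.weight",
    "time_mix_v0.weight", "time_mix_v1.weight", "time_mix_v2.weight",
    "time_mix_a0.weight", "time_mix_a1.weight", "time_mix_a2.weight",
    "time_mix_g1.weight", "time_mix_g2.weight",
    "time_mix_decay_w1.weight", "time_mix_decay_w2.weight",
    "time_mix_lerp_fused.weight",
)

def should_quantize(tensor_name: str) -> bool:
    """False for RWKV's small 2D weights that stay in their original precision."""
    return not any(s in tensor_name for s in SKIP_SUBSTRINGS)

print(should_quantize("blk.0.time_mix_w0.weight"))  # -> False
print(should_quantize("blk.0.ffn_down.weight"))     # -> True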