Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	CMakeLists.txt
#	cmake/common.cmake
#	docs/backend/SYCL.md
#	examples/main/README.md
#	examples/speculative/speculative.cpp
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-musa/CMakeLists.txt
#	ggml/src/ggml-sycl/CMakeLists.txt
#	ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2025-03-19 19:27:11 +08:00
commit 0c90d2ebcf
58 changed files with 4222 additions and 1537 deletions

View file

@ -1,519 +0,0 @@
#!/bin/bash
#
# Options
IOS_MIN_OS_VERSION=16.4
MACOS_MIN_OS_VERSION=13.3
VISIONOS_MIN_OS_VERSION=1.0
TVOS_MIN_OS_VERSION=16.4
BUILD_SHARED_LIBS=OFF
LLAMA_BUILD_EXAMPLES=OFF
LLAMA_BUILD_TESTS=OFF
LLAMA_BUILD_SERVER=OFF
GGML_METAL=ON
GGML_METAL_EMBED_LIBRARY=ON
GGML_BLAS_DEFAULT=ON
GGML_METAL_USE_BF16=ON
GGML_OPENMP=OFF
COMMON_C_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
COMMON_CXX_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
# Common options for all builds
COMMON_CMAKE_ARGS=(
-DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_REQUIRED=NO
-DCMAKE_XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY=""
-DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED=NO
-DCMAKE_XCODE_ATTRIBUTE_DEBUG_INFORMATION_FORMAT="dwarf-with-dsym"
-DCMAKE_XCODE_ATTRIBUTE_GCC_GENERATE_DEBUGGING_SYMBOLS=YES
-DCMAKE_XCODE_ATTRIBUTE_COPY_PHASE_STRIP=NO
-DCMAKE_XCODE_ATTRIBUTE_STRIP_INSTALLED_PRODUCT=NO
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
-DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS}
-DLLAMA_BUILD_EXAMPLES=${LLAMA_BUILD_EXAMPLES}
-DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS}
-DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER}
-DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY}
-DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT}
-DGGML_METAL=${GGML_METAL}
-DGGML_METAL_USE_BF16=${GGML_METAL_USE_BF16}
-DGGML_NATIVE=OFF
-DGGML_OPENMP=${GGML_OPENMP}
)
check_required_tool() {
local tool=$1
local install_message=$2
if ! command -v $tool &> /dev/null; then
echo "Error: $tool is required but not found."
echo "$install_message"
exit 1
fi
}
echo "Checking for required tools..."
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
set -e
## Clean up previous builds
rm -rf build-apple
rm -rf build-ios-sim
rm -rf build-ios-device
rm -rf build-macos
rm -rf build-visionos
rm -rf build-visionos-sim
rm -rf build-tvos-sim
rm -rf build-tvos-device
# Setup the xcframework build directory structure
setup_framework_structure() {
local build_dir=$1
local min_os_version=$2
local platform=$3 # "ios", "macos", "visionos", or "tvos"
local framework_name="llama"
echo "Creating ${platform}-style framework structure for ${build_dir}"
if [[ "$platform" == "macos" ]]; then
# macOS versioned structure uses versioned directories
mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Headers
mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Modules
mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Resources
# Create symbolic links
ln -sf A ${build_dir}/framework/${framework_name}.framework/Versions/Current
ln -sf Versions/Current/Headers ${build_dir}/framework/${framework_name}.framework/Headers
ln -sf Versions/Current/Modules ${build_dir}/framework/${framework_name}.framework/Modules
ln -sf Versions/Current/Resources ${build_dir}/framework/${framework_name}.framework/Resources
ln -sf Versions/Current/${framework_name} ${build_dir}/framework/${framework_name}.framework/${framework_name}
# Set header and module paths
local header_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Headers/
local module_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Modules/
else
# iOS/VisionOS/tvOS use a flat structure
mkdir -p ${build_dir}/framework/${framework_name}.framework/Headers
mkdir -p ${build_dir}/framework/${framework_name}.framework/Modules
# Remove any existing structure to ensure clean build
rm -rf ${build_dir}/framework/${framework_name}.framework/Versions
# Set header and module paths
local header_path=${build_dir}/framework/${framework_name}.framework/Headers/
local module_path=${build_dir}/framework/${framework_name}.framework/Modules/
fi
# Copy all required headers (common for all platforms)
cp include/llama.h ${header_path}
cp ggml/include/ggml.h ${header_path}
cp ggml/include/ggml-alloc.h ${header_path}
cp ggml/include/ggml-backend.h ${header_path}
cp ggml/include/ggml-metal.h ${header_path}
cp ggml/include/ggml-cpu.h ${header_path}
cp ggml/include/ggml-blas.h ${header_path}
cp ggml/include/gguf.h ${header_path}
# Create module map (common for all platforms)
cat > ${module_path}module.modulemap << EOF
framework module llama {
header "llama.h"
header "ggml.h"
header "ggml-alloc.h"
header "ggml-backend.h"
header "ggml-metal.h"
header "ggml-cpu.h"
header "ggml-blas.h"
header "gguf.h"
link "c++"
link framework "Accelerate"
link framework "Metal"
link framework "Foundation"
export *
}
EOF
# Platform-specific settings for Info.plist
local platform_name=""
local sdk_name=""
local supported_platform=""
case "$platform" in
"ios")
platform_name="iphoneos"
sdk_name="iphoneos${min_os_version}"
supported_platform="iPhoneOS"
local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
local device_family=' <key>UIDeviceFamily</key>
<array>
<integer>1</integer>
<integer>2</integer>
</array>'
;;
"macos")
platform_name="macosx"
sdk_name="macosx${min_os_version}"
supported_platform="MacOSX"
local plist_path="${build_dir}/framework/${framework_name}.framework/Versions/A/Resources/Info.plist"
local device_family=""
;;
"visionos")
platform_name="xros"
sdk_name="xros${min_os_version}"
supported_platform="XRPlatform"
local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
local device_family=""
;;
"tvos")
platform_name="appletvos"
sdk_name="appletvos${min_os_version}"
supported_platform="AppleTVOS"
local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
local device_family=' <key>UIDeviceFamily</key>
<array>
<integer>3</integer>
</array>'
;;
esac
# Create Info.plist
cat > ${plist_path} << EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleExecutable</key>
<string>llama</string>
<key>CFBundleIdentifier</key>
<string>org.ggml.llama</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
<string>llama</string>
<key>CFBundlePackageType</key>
<string>FMWK</string>
<key>CFBundleShortVersionString</key>
<string>1.0</string>
<key>CFBundleVersion</key>
<string>1</string>
<key>MinimumOSVersion</key>
<string>${min_os_version}</string>
<key>CFBundleSupportedPlatforms</key>
<array>
<string>${supported_platform}</string>
</array>${device_family}
<key>DTPlatformName</key>
<string>${platform_name}</string>
<key>DTSDKName</key>
<string>${sdk_name}</string>
</dict>
</plist>
EOF
}
# Create dynamic libraries from static libraries.
combine_static_libraries() {
local build_dir="$1"
local release_dir="$2"
local platform="$3" # "ios", "macos", "visionos", or "tvos"
local is_simulator="$4"
local base_dir="$(pwd)"
local framework_name="llama"
# Determine output path based on platform
local output_lib=""
if [[ "$platform" == "macos" ]]; then
# macOS uses versioned structure
output_lib="${build_dir}/framework/${framework_name}.framework/Versions/A/${framework_name}"
else
# iOS, visionOS, and tvOS use a directory flat structure
output_lib="${build_dir}/framework/${framework_name}.framework/${framework_name}"
fi
local libs=(
"${base_dir}/${build_dir}/src/${release_dir}/libllama.a"
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml.a"
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-base.a"
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a"
"${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
"${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
)
# Create temporary directory for processing
local temp_dir="${base_dir}/${build_dir}/temp"
mkdir -p "${temp_dir}"
# Since we have multiple architectures libtool will find object files that do not
# match the target architecture. We suppress these warnings.
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
# Determine SDK, architectures, and install_name based on platform and simulator flag.
local sdk=""
local archs=""
local min_version_flag=""
local install_name=""
case "$platform" in
"ios")
if [[ "$is_simulator" == "true" ]]; then
sdk="iphonesimulator"
archs="arm64 x86_64"
min_version_flag="-mios-simulator-version-min=${IOS_MIN_OS_VERSION}"
else
sdk="iphoneos"
archs="arm64"
min_version_flag="-mios-version-min=${IOS_MIN_OS_VERSION}"
fi
install_name="@rpath/llama.framework/llama"
;;
"macos")
sdk="macosx"
archs="arm64 x86_64"
min_version_flag="-mmacosx-version-min=${MACOS_MIN_OS_VERSION}"
install_name="@rpath/llama.framework/Versions/Current/llama"
;;
"visionos")
if [[ "$is_simulator" == "true" ]]; then
sdk="xrsimulator"
archs="arm64 x86_64"
min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}-simulator"
else
sdk="xros"
archs="arm64"
min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}"
fi
# Use flat structure for visionOS, same as iOS
install_name="@rpath/llama.framework/llama"
;;
"tvos")
if [[ "$is_simulator" == "true" ]]; then
sdk="appletvsimulator"
archs="arm64 x86_64"
min_version_flag="-mtvos-simulator-version-min=${TVOS_MIN_OS_VERSION}"
else
sdk="appletvos"
archs="arm64"
min_version_flag="-mtvos-version-min=${TVOS_MIN_OS_VERSION}"
fi
install_name="@rpath/llama.framework/llama"
;;
esac
# Build architecture flags
local arch_flags=""
for arch in $archs; do
arch_flags+=" -arch $arch"
done
# Create dynamic library
echo "Creating dynamic library for ${platform}."
xcrun -sdk $sdk clang++ -dynamiclib \
-isysroot $(xcrun --sdk $sdk --show-sdk-path) \
$arch_flags \
$min_version_flag \
-Wl,-force_load,"${temp_dir}/combined.a" \
-framework Foundation -framework Metal -framework Accelerate \
-install_name "$install_name" \
-o "${base_dir}/${output_lib}"
# Platform-specific post-processing for device builds
if [[ "$is_simulator" == "false" ]]; then
if command -v vtool &>/dev/null; then
case "$platform" in
"ios")
echo "Marking binary as a framework binary for iOS..."
vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;;
"visionos")
echo "Marking binary as a framework binary for visionOS..."
vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;;
"tvos")
echo "Marking binary as a framework binary for tvOS..."
vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;;
esac
else
echo "Warning: vtool not found. Binary may not pass App Store validation."
fi
fi
echo "Creating properly formatted dSYM..."
# Create a separate directory for dSYMs for all platforms
mkdir -p "${base_dir}/${build_dir}/dSYMs"
# iOS and visionOS style dSYM (flat structure)
if [[ "$platform" == "ios" || "$platform" == "visionos" || "$platform" == "tvos" ]]; then
# Generate dSYM in the dSYMs directory
xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
# Create a copy of the binary that will be stripped
cp "${base_dir}/${output_lib}" "${temp_dir}/binary_to_strip"
# Strip debug symbols from the copy
xcrun strip -S "${temp_dir}/binary_to_strip" -o "${temp_dir}/stripped_lib"
# Replace the original with the stripped version
mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
else
# macOS style dSYM
# First strip debug info to a separate file
xcrun strip -S "${base_dir}/${output_lib}" -o "${temp_dir}/stripped_lib"
# Generate dSYM in the dSYMs directory
xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
# Replace original binary with stripped version
mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
fi
# Remove any automatically generated dSYM files in the framework structure as they will
# otherwise case Invalid Bundle Structure validation errors.
if [ -d "${base_dir}/${output_lib}.dSYM" ]; then
echo "Removing generated dSYM file in framework structure: ${base_dir}/${output_lib}.dSYM"
rm -rf "${base_dir}/${output_lib}.dSYM"
fi
# Clean up
rm -rf "${temp_dir}"
}
echo "Building for iOS simulator..."
cmake -B build-ios-sim -G Xcode \
"${COMMON_CMAKE_ARGS[@]}" \
-DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
-DIOS=ON \
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_SYSROOT=iphonesimulator \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphonesimulator \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-ios-sim --config Release -- -quiet
echo "Building for iOS devices..."
cmake -B build-ios-device -G Xcode \
"${COMMON_CMAKE_ARGS[@]}" \
-DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
-DCMAKE_OSX_SYSROOT=iphoneos \
-DCMAKE_OSX_ARCHITECTURES="arm64" \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-ios-device --config Release -- -quiet
echo "Building for macOS..."
cmake -B build-macos -G Xcode \
"${COMMON_CMAKE_ARGS[@]}" \
-DCMAKE_OSX_DEPLOYMENT_TARGET=${MACOS_MIN_OS_VERSION} \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-macos --config Release -- -quiet
echo "Building for visionOS..."
cmake -B build-visionos -G Xcode \
"${COMMON_CMAKE_ARGS[@]}" \
-DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
-DCMAKE_OSX_ARCHITECTURES="arm64" \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xros \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-visionos --config Release -- -quiet
echo "Building for visionOS simulator..."
cmake -B build-visionos-sim -G Xcode \
"${COMMON_CMAKE_ARGS[@]}" \
-DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xrsimulator \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-visionos-sim --config Release -- -quiet
# Add tvOS builds (might need the same u_int definitions as watchOS and visionOS)
echo "Building for tvOS simulator..."
cmake -B build-tvos-sim -G Xcode \
"${COMMON_CMAKE_ARGS[@]}" \
-DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
-DCMAKE_SYSTEM_NAME=tvOS \
-DCMAKE_OSX_SYSROOT=appletvsimulator \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
-DGGML_METAL=ON \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvsimulator \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-tvos-sim --config Release -- -quiet
echo "Building for tvOS devices..."
cmake -B build-tvos-device -G Xcode \
"${COMMON_CMAKE_ARGS[@]}" \
-DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
-DCMAKE_SYSTEM_NAME=tvOS \
-DCMAKE_OSX_SYSROOT=appletvos \
-DCMAKE_OSX_ARCHITECTURES="arm64" \
-DGGML_METAL=ON \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvos \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-tvos-device --config Release -- -quiet
# Setup frameworks and copy binaries and headers
echo "Setting up framework structures..."
setup_framework_structure "build-ios-sim" ${IOS_MIN_OS_VERSION} "ios"
setup_framework_structure "build-ios-device" ${IOS_MIN_OS_VERSION} "ios"
setup_framework_structure "build-macos" ${MACOS_MIN_OS_VERSION} "macos"
setup_framework_structure "build-visionos" ${VISIONOS_MIN_OS_VERSION} "visionos"
setup_framework_structure "build-visionos-sim" ${VISIONOS_MIN_OS_VERSION} "visionos"
setup_framework_structure "build-tvos-sim" ${TVOS_MIN_OS_VERSION} "tvos"
setup_framework_structure "build-tvos-device" ${TVOS_MIN_OS_VERSION} "tvos"
# Create dynamic libraries from static libraries
echo "Creating dynamic libraries from static libraries..."
combine_static_libraries "build-ios-sim" "Release-iphonesimulator" "ios" "true"
combine_static_libraries "build-ios-device" "Release-iphoneos" "ios" "false"
combine_static_libraries "build-macos" "Release" "macos" "false"
combine_static_libraries "build-visionos" "Release-xros" "visionos" "false"
combine_static_libraries "build-visionos-sim" "Release-xrsimulator" "visionos" "true"
combine_static_libraries "build-tvos-sim" "Release-appletvsimulator" "tvos" "true"
combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
# Create XCFramework with correct debug symbols paths
echo "Creating XCFramework..."
xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-sim/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
-framework $(pwd)/build-ios-device/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
-framework $(pwd)/build-macos/framework/llama.framework \
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
-framework $(pwd)/build-visionos/framework/llama.framework \
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
-debug-symbols $(pwd)/build-visionos-sim/dSYMs/llama.dSYM \
-framework $(pwd)/build-tvos-device/framework/llama.framework \
-debug-symbols $(pwd)/build-tvos-device/dSYMs/llama.dSYM \
-framework $(pwd)/build-tvos-sim/framework/llama.framework \
-debug-symbols $(pwd)/build-tvos-sim/dSYMs/llama.dSYM \
-output $(pwd)/build-apple/llama.xcframework

View file

@ -180,7 +180,8 @@ class Model:
extra = sorted(tensor_names_from_parts.difference(self.tensor_names)) extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
if len(extra) == 0 and len(missing_files) > 0: if len(extra) == 0 and len(missing_files) > 0:
raise ValueError(f"Missing or incomplete model files: {missing_files}") raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
f"Missing tensors: {missing}")
else: else:
raise ValueError("Mismatch between weight map and model parts for tensor names:\n" raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
f"Missing tensors: {missing}\n" f"Missing tensors: {missing}\n"
@ -908,6 +909,40 @@ class Model:
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_rwkv_world(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)
tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)
self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int): def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf" tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'") logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
@ -1065,13 +1100,6 @@ class BloomModel(Model):
tensors.append((self.map_tensor_name(name), data_torch)) tensors.append((self.map_tensor_name(name), data_torch))
if name == "word_embeddings.weight":
assert self.tensor_names is not None
# TODO: tie them at runtime, don't duplicate in the model file
if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
return tensors return tensors
@ -1713,6 +1741,25 @@ class LlamaModel(Model):
raise ValueError(f"Unprocessed experts: {experts}") raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("Mistral3ForConditionalGeneration")
class Mistral3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA
# we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = Model.load_hparams(kwargs["dir_model"])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
name = name.replace("language_model.", "")
if "multi_modal_projector" in name or "vision_tower" in name:
return []
return super().modify_tensors(data_torch, name, bid)
@Model.register("DeciLMForCausalLM") @Model.register("DeciLMForCausalLM")
class DeciModel(Model): class DeciModel(Model):
model_arch = gguf.MODEL_ARCH.DECI model_arch = gguf.MODEL_ARCH.DECI
@ -2370,10 +2417,6 @@ class GPT2Model(Model):
tensors.append((new_name, data_torch)) tensors.append((new_name, data_torch))
# note: GPT2 output is tied to (same as) wte in original model
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
return tensors return tensors
@ -2703,21 +2746,26 @@ class CodeShellModel(Model):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0) self.gguf_writer.add_rope_scaling_factor(1.0)
_has_tok_embd = False
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused del bid # unused
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
new_name = self.map_tensor_name(name) new_name = self.map_tensor_name(name)
tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)] # assuming token_embd.weight is seen before output.weight
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
self.tensor_names.remove("transformer.wte.weight")
elif new_name == tok_embd_name:
self._has_tok_embd = True
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD): return [(new_name, data_torch)]
assert self.tensor_names is not None
if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
# copy tok_embd.weight to output.weight
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
return tensors
@Model.register("InternLM2ForCausalLM") @Model.register("InternLM2ForCausalLM")
@ -3412,38 +3460,7 @@ class Rwkv6Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV6 model_arch = gguf.MODEL_ARCH.RWKV6
def set_vocab(self): def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file() self._set_vocab_rwkv_world()
vocab_size = self.hparams.get("vocab_size", 65536)
tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)
self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"] block_count = self.hparams["num_hidden_layers"]
@ -3565,6 +3582,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
yield (new_name, data) yield (new_name, data)
@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
class Rwkv7Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV7
def set_vocab(self):
self._set_vocab_rwkv_world()
def calc_lora_rank(self, hidden_size, exponent, multiplier):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
try:
head_size = self.hparams["head_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
except KeyError:
head_size = self.hparams["head_dim"]
layer_norm_eps = self.hparams["norm_eps"]
hidden_size = self.hparams["hidden_size"]
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
# ICLR: In-Context-Learning-Rate
try:
lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
except KeyError:
lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)
# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)
lerp_weights: dict[int, dict[str, Tensor]] = {}
lora_needs_transpose: bool = True
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# unify tensor names here to make life easier
name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
name = name.replace("self_attn", "attention").replace("attn", "attention")
name = name.replace("time_mixer.", "")
# lora layer names in fla-hub's impl
if "_lora.lora" in name:
self.lora_needs_transpose = False
name = name.replace("_lora.lora.0.weight", "1.weight")
name = name.replace("_lora.lora.2.weight", "2.weight")
name = name.replace("_lora.lora.2.bias", "0.weight")
name = name.replace("feed_forward_norm", "ln2")
name = name.replace("g_norm", "ln_x")
if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
# some models have dummy v0/v1/v2 on first layer while others don't
# ignore them all since they are not used
return
wkv_has_gate = self.hparams.get("wkv_has_gate", True)
lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
if bid is not None and "attention.x_" in name:
if "attention.x_x" in name:
# already concatenated
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = data_torch.reshape(len(lerp_list), 1, 1, -1)
yield (new_name, data)
else:
try:
self.lerp_weights[bid][name] = data_torch
except KeyError:
self.lerp_weights[bid] = {name: data_torch}
if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
yield (new_name, data)
return
else:
data_torch = data_torch.squeeze()
new_name = self.map_tensor_name(name)
if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
new_name += ".weight"
if self.lora_needs_transpose and any(
new_name.endswith(t) for t in [
"time_mix_w1.weight", "time_mix_w2.weight",
"time_mix_a1.weight", "time_mix_a2.weight",
"time_mix_v1.weight", "time_mix_v2.weight",
"time_mix_g1.weight", "time_mix_g2.weight",
]
):
data_torch = data_torch.transpose(0, 1)
if 'r_k' in new_name:
data_torch = data_torch.flatten()
if bid == 0 and "time_mix_a" in new_name:
# dummy v0/v1/v2 on first layer
# easist way to make llama happy
yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
yield (new_name, data_torch)
@Model.register("RwkvHybridForCausalLM")
class ARwkv7Model(Rwkv7Model):
model_arch = gguf.MODEL_ARCH.ARWKV7
def set_vocab(self):
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
hidden_size = self.hparams["hidden_size"]
head_size = self.hparams["head_size"]
rms_norm_eps = self.hparams["rms_norm_eps"]
intermediate_size = self.hparams["intermediate_size"]
wkv_has_gate = self.hparams["wkv_has_gate"]
assert self.hparams["wkv_version"] == 7
# ICLR: In-Context-Learning-Rate
lora_rank_decay = 64
lora_rank_iclr = 64
lora_rank_value_residual_mix = 32
lora_rank_gate = 128 if wkv_has_gate else 0
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_token_shift_count(1)
# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model): class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA model_arch = gguf.MODEL_ARCH.MAMBA

View file

@ -1872,6 +1872,10 @@ struct server_context {
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1; params_dft.n_parallel = 1;
// force F16 KV cache for the draft model for extra performance
params_dft.cache_type_k = GGML_TYPE_F16;
params_dft.cache_type_v = GGML_TYPE_F16;
llama_init_dft = common_init_from_params(params_dft); llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model.get(); model_dft = llama_init_dft.model.get();
@ -1892,10 +1896,6 @@ struct server_context {
cparams_dft = common_context_params_to_llama(params_dft); cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft; cparams_dft.n_batch = n_ctx_dft;
// force F16 KV cache for the draft model for extra performance
cparams_dft.type_k = GGML_TYPE_F16;
cparams_dft.type_v = GGML_TYPE_F16;
// the context is not needed - we will create one for each slot // the context is not needed - we will create one for each slot
llama_init_dft.context.reset(); llama_init_dft.context.reset();
} }

26
ggml/cmake/common.cmake Normal file
View file

@ -0,0 +1,26 @@
function(ggml_get_flags CCID CCVER)
set(C_FLAGS "")
set(CXX_FLAGS "")
if (CCID MATCHES "Clang")
set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
if (
(CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
(CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
)
list(APPEND C_FLAGS -Wdouble-promotion)
endif()
elseif (CCID STREQUAL "GNU")
set(C_FLAGS -Wdouble-promotion)
set(CXX_FLAGS -Wno-array-bounds)
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
list(APPEND CXX_FLAGS -Wextra-semi)
endif()
endif()
set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
endfunction()

View file

@ -460,6 +460,7 @@ extern "C" {
GGML_OP_RMS_NORM, GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK, GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM, GGML_OP_GROUP_NORM,
GGML_OP_L2_NORM,
GGML_OP_MUL_MAT, GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID,
@ -508,6 +509,7 @@ extern "C" {
GGML_OP_ADD_REL_POS, GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6, GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTN, GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_UNARY, GGML_OP_UNARY,
@ -1108,6 +1110,18 @@ extern "C" {
int n_groups, int n_groups,
float eps); float eps);
// l2 normalize along rows
// used in rwkv v7
GGML_API struct ggml_tensor * ggml_l2_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
// a - x // a - x
// b - dy // b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back( GGML_API struct ggml_tensor * ggml_rms_norm_back(
@ -1903,6 +1917,16 @@ extern "C" {
struct ggml_tensor * state, struct ggml_tensor * state,
float scale); float scale);
GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
struct ggml_context * ctx,
struct ggml_tensor * r,
struct ggml_tensor * w,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state);
// custom operators // custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

View file

@ -8159,7 +8159,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K; const int nb = n / QK_K;
#ifdef __ARM_NEON #ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
svuint8_t m4b = svdup_n_u8(0xf);
svint32_t vzero = svdup_n_s32(0);
svuint8_t mone = svdup_n_u8(0x30);
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
for (int i = 0; i < nb; ++i) {
const float d_all = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
const uint8_t * GGML_RESTRICT qh = x[i].qh;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const int8_t * GGML_RESTRICT scale = x[i].scales;
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
const svint64_t prod = svdup_n_s64(0);
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
svdot_s64(prod, q8sums_2, q6scales_2)));
int32_t isum = 0;
switch (vector_length) {
case 128:
{
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; ++j) {
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
qh += 32;
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
q6 += 64;
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
q8 += 64;
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
scale += 4;
q8bytes_1 = svld1_s8(pg8_16, q8);
q8bytes_2 = svld1_s8(pg8_16, q8+16);
q8bytes_3 = svld1_s8(pg8_16, q8+32);
q8bytes_4 = svld1_s8(pg8_16, q8+48);
q8 += 64;
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
scale += 4;
}
isum += svaddv_s32(pg32_4, isum_tmp);
sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
break;
case 256:
case 512:
{
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; j++) {
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
qh += 32;
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
q6 += 64;
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
q8 += 128;
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
scale += 8;
}
isum += svaddv_s32(pg32_8, isum_tmp);
sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sum;
#elif __ARM_NEON
float sum = 0; float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF); const uint8x16_t m4b = vdupq_n_u8(0xF);

View file

@ -8578,6 +8578,69 @@ static void ggml_compute_forward_group_norm(
} }
} }
// ggml_compute_forward_l2_norm
static void ggml_compute_forward_l2_norm_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
}
static void ggml_compute_forward_l2_norm(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_l2_norm_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_mul_mat // ggml_compute_forward_mul_mat
static void ggml_compute_forward_mul_mat_one_chunk( static void ggml_compute_forward_mul_mat_one_chunk(
@ -13643,6 +13706,184 @@ static void ggml_compute_forward_gla(
} }
} }
// ggml_compute_forward_rwkv_wkv7
static void ggml_compute_forward_rwkv_wkv7_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t n_seqs = dst->src[6]->ne[1];
const int64_t head_size = C / HEADS;
float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * r = (float *) dst->src[0]->data;
float * w = (float *) dst->src[1]->data;
float * k = (float *) dst->src[2]->data;
float * v = (float *) dst->src[3]->data;
float * a = (float *) dst->src[4]->data;
float * b = (float *) dst->src[5]->data;
int64_t t_stride = HEADS * head_size; // Same to C
int64_t h_stride = C / HEADS;
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
int64_t h_stride_2d = head_size * head_size;
#if defined(GGML_SIMD)
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
for (int64_t h = h_start; h < h_end; h++) {
int64_t h_offset = h * h_stride;
int64_t t_h_offset = t_offset + h_offset;
int64_t h_2d_offset = h * h_stride_2d;
for (int64_t ii = 0; ii < head_size; ii++) {
int64_t t_h_i_offset = t_h_offset + ii;
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
float sa = 0;
{
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
}
}
GGML_F32_VEC_REDUCE(sa, sum);
}
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
int64_t j = 0;
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
for (; j < head_size; j += GGML_F32_STEP) {
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
// kv + s * decay + sa * b
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
}
}
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
// There shouldn't be left-overs though.
for (; j < head_size; j++) {
int64_t t_h_j_offset = t_h_offset + j;
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v[t_h_i_offset] * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
}
}
}
}
#else
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
for (int64_t h = h_start; h < h_end; h++) {
int64_t h_offset = h * h_stride;
int64_t t_h_offset = t_offset + h_offset;
int64_t h_2d_offset = h * h_stride_2d;
for (int64_t i = 0; i < head_size; i++) {
int64_t t_h_i_offset = t_h_offset + i;
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
float v_val = v[t_h_i_offset];
float sa = 0, result = 0;
for (int64_t j = 0; j < head_size; j++) {
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
}
for (int64_t j = 0; j < head_size; j++) {
int64_t t_h_j_offset = t_h_offset + j;
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
result += state_cur[h_2d_i_j_offset] * r_val;
}
dst_data[t_h_i_offset] = result;
}
}
}
#endif
}
static void ggml_compute_forward_rwkv_wkv7(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_rwkv_wkv7_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_map_unary // ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32( static void ggml_compute_forward_map_unary_f32(
@ -14209,6 +14450,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_group_norm(params, tensor); ggml_compute_forward_group_norm(params, tensor);
} break; } break;
case GGML_OP_L2_NORM:
{
ggml_compute_forward_l2_norm(params, tensor);
} break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
ggml_compute_forward_mul_mat(params, tensor); ggml_compute_forward_mul_mat(params, tensor);
@ -14396,6 +14641,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_gla(params, tensor); ggml_compute_forward_gla(params, tensor);
} break; } break;
case GGML_OP_RWKV_WKV7:
{
ggml_compute_forward_rwkv_wkv7(params, tensor);
} break;
case GGML_OP_MAP_UNARY: case GGML_OP_MAP_UNARY:
{ {
ggml_unary_op_f32_t fun; ggml_unary_op_f32_t fun;
@ -14621,6 +14870,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
@ -14687,14 +14937,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN: case GGML_OP_SSM_SCAN:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
{ {
n_tasks = n_threads; n_tasks = n_threads;
} break; } break;
case GGML_OP_WIN_PART: case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART: case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS: case GGML_OP_GET_REL_POS:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_MAP_UNARY: case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY: case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32: case GGML_OP_MAP_CUSTOM1_F32:

View file

@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
}; };
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) #if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#define USE_CUDA_GRAPH #define USE_CUDA_GRAPH
#endif #endif

View file

@ -38,7 +38,7 @@ bool g_mul_mat_q = true;
#include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh" #include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh" #include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh" #include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh" #include "ggml-cuda/gla.cuh"
#include "ggml.h" #include "ggml.h"
@ -265,6 +265,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize); device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA) #elif defined(GGML_USE_MUSA)
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
info.devices[id].warp_size = 32;
// TODO: refine the .cc to reflect MUSA's actual CC capabilities // TODO: refine the .cc to reflect MUSA's actual CC capabilities
info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor; info.devices[id].cc = 100*prop.major + 10*prop.minor;
@ -2201,6 +2203,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst); ggml_cuda_op_group_norm(ctx, dst);
break; break;
case GGML_OP_L2_NORM:
ggml_cuda_op_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
ggml_cuda_op_concat(ctx, dst); ggml_cuda_op_concat(ctx, dst);
break; break;
@ -2309,6 +2314,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst); ggml_cuda_op_gated_linear_attn(ctx, dst);
break; break;
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst); ggml_cuda_cross_entropy_loss_back(ctx, dst);
break; break;
@ -2615,13 +2623,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info; cudaGraphExecUpdateResultInfo result_info;
#ifdef __HIP_PLATFORM_AMD__
hipGraphNode_t errorNode;
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#else
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
#endif #else
cudaGraphNode_t errorNode;
cudaGraphExecUpdateResult result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000
if (stat == cudaErrorGraphExecUpdateFailure) { if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG #ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
@ -3164,6 +3174,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
break; break;
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return true; return true;
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
@ -3218,6 +3229,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: { case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE #ifndef FLASH_ATTN_AVAILABLE

View file

@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
} }
} }
// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
// const int tid = threadIdx.x;
// float tmp = 0.0f; // partial sum for thread in warp
// for (int col = tid; col < ncols; col += block_size) {
// const float xi = x[row*ncols + col];
// tmp += xi * xi;
// }
// // sum up partial sums
// tmp = warp_reduce_sum(tmp);
// if (block_size > WARP_SIZE) {
// __shared__ float s_sum[32];
// int warp_id = threadIdx.x / WARP_SIZE;
// int lane_id = threadIdx.x % WARP_SIZE;
// if (lane_id == 0) {
// s_sum[warp_id] = tmp;
// }
// __syncthreads();
// tmp = s_sum[lane_id];
// tmp = warp_reduce_sum(tmp);
// }
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
// for (int col = tid; col < ncols; col += block_size) {
// dst[row*ncols + col] = scale * x[row*ncols + col];
// }
// }
template <int block_size>
static __global__ void l2_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
static void norm_f32_cuda( static void norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
} }
} }
static void l2_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
} }
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

View file

@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -112,7 +112,7 @@
#define cudaGraphExecDestroy hipGraphExecDestroy #define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch #define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult #define cudaGraphExecUpdateResult hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType #define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate #define cudaGraphInstantiate hipGraphInstantiate

View file

@ -119,7 +119,7 @@
#define cudaGraphExecDestroy musaGraphExecDestroy #define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t #define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate #define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult #define cudaGraphExecUpdateResult musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes #define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate #define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@ -132,6 +132,7 @@
#define cudaGraph_t musaGraph_t #define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams #define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamBeginCapture musaStreamBeginCapture
#define cudaStreamEndCapture musaStreamEndCapture #define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16; typedef mt_bfloat16 nv_bfloat16;

199
ggml/src/ggml-cuda/wkv.cu Normal file
View file

@ -0,0 +1,199 @@
#include "common.cuh"
#include "wkv.cuh"
template <int block_size>
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
template <int block_size>
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
#ifndef GGML_USE_MUSA
#pragma unroll
#endif
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
}
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
__syncthreads();
float sa = 0;
#pragma unroll
for (int j = 0; j < head_size; j += 4)
{
const float4& a = (float4&)(_a[j]);
const float4& s = (float4&)(state[j]);
sa += a.x * s.x;
sa += a.y * s.y;
sa += a.z * s.z;
sa += a.w * s.w;
}
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& r = (float4&)(_r[j]);
const float4& w = (float4&)(_w[j]);
const float4& k = (float4&)(_k[j]);
const float4& b = (float4&)(_b[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
s.x = s.x * w.x + kv.x + sa * b.x;
s.y = s.y * w.y + kv.y + sa * b.y;
s.z = s.z * w.z + kv.z + sa * b.z;
s.w = s.w * w.w + kv.w + sa * b.w;
y += s.x * r.x;
y += s.y * r.y;
y += s.z * r.z;
y += s.w * r.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
} else {
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}
}
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * r_d = (const float *)dst->src[0]->data;
const float * w_d = (const float *)dst->src[1]->data;
const float * k_d = (const float *)dst->src[2]->data;
const float * v_d = (const float *)dst->src[3]->data;
const float * a_d = (const float *)dst->src[4]->data;
const float * b_d = (const float *)dst->src[5]->data;
const float * s_d = (const float *)dst->src[6]->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
}
}

View file

@ -3,3 +3,5 @@
#define CUDA_WKV_BLOCK_SIZE 64 #define CUDA_WKV_BLOCK_SIZE 64
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -1,89 +0,0 @@
#include "common.cuh"
#include "wkv6.cuh"
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = CUDA_WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}

View file

@ -285,6 +285,13 @@ typedef struct {
float eps; float eps;
} ggml_metal_kargs_rms_norm; } ggml_metal_kargs_rms_norm;
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct { typedef struct {
int64_t ne00; int64_t ne00;
int64_t ne01; int64_t ne01;

View file

@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
return true; return true;
@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN: case GGML_OP_SSM_SCAN:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true; return true;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_RWKV_WKV6:
{
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&B length:sizeof(B) atIndex:7];
[encoder setBytes:&T length:sizeof(T) atIndex:8];
[encoder setBytes:&C length:sizeof(C) atIndex:9];
[encoder setBytes:&H length:sizeof(H) atIndex:10];
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
} break;
case GGML_OP_RWKV_WKV7:
{
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
size_t offs_src6 = 0;
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
[encoder setBytes:&B length:sizeof(B) atIndex:8];
[encoder setBytes:&T length:sizeof(T) atIndex:9];
[encoder setBytes:&C length:sizeof(C) atIndex:10];
[encoder setBytes:&H length:sizeof(H) atIndex:11];
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
} break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne00 == ne10);
@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_L2_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
int nth = 32; // SIMD width
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00/4);
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:

View file

@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
} }
} }
kernel void kernel_rwkv_wkv6_f32(
device const float * k,
device const float * v,
device const float * r,
device const float * tf,
device const float * td,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _k[head_size];
threadgroup float _r[head_size];
threadgroup float _tf[head_size];
threadgroup float _td[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
_tf[tid] = tf[head_id * head_size + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0;
for (uint j = 0; j < head_size; j += 4) {
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
float4 temp = tf_vec * kv + s_vec;
y += dot(r_vec, temp);
s_vec = s_vec * td_vec + kv;
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
}
kernel void kernel_rwkv_wkv7_f32(
device const float * r,
device const float * w,
device const float * k,
device const float * v,
device const float * a,
device const float * b,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _r[head_size];
threadgroup float _w[head_size];
threadgroup float _k[head_size];
threadgroup float _a[head_size];
threadgroup float _b[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0, sa = 0.0;
float4 sa_vec(0.0);
for (int j = 0; j < head_size; j += 4) {
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
sa_vec += a_vec * s_vec;
}
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
for (uint j = 0; j < head_size; j += 4) {
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(s_vec, r_vec);
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}
kernel void kernel_argmax( kernel void kernel_argmax(
device const void * x, device const void * x,
device int32_t * dst, device int32_t * dst,
@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
} }
} }
kernel void kernel_l2_norm(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = x[i00] * scale;
}
}
kernel void kernel_group_norm( kernel void kernel_group_norm(
device const float * src0, device const float * src0,
device float * dst, device float * dst,

View file

@ -297,8 +297,27 @@ static int ggml_backend_opencl_n_devices = 0;
struct ProfilingInfo { struct ProfilingInfo {
std::string op_name; std::string op_name;
std::string kernel_name; std::string kernel_name;
// Kernel execution time in nanoseconds.
cl_ulong duration_ns; cl_kernel kernel;
cl_event evt;
cl_ulong cmd_queued;
cl_ulong cmd_submit;
cl_ulong cmd_start;
cl_ulong cmd_end;
cl_ulong overhead_start;
cl_ulong overhead_end;
// For the times below, see spec for clGetEventProfilingInfo
// The time kernel spent in cmd queue - SUBMIT - QUEUED
cl_ulong cmd_queued_duration_ns;
// The time kernel spent for submission - START - SUBMIT
cl_ulong cmd_submit_duration_ns;
// Kernel execution time in nanoseconds - END - START
cl_ulong cmd_duration_ns;
// The time for the kernel to complete - COMPLETE - END
cl_ulong cmd_complete_duration_ns;
// Total time to finish the kernel - COMPELTE - QUEUED
cl_ulong cmd_total_duration_ns;
// Global and local work sizes. // Global and local work sizes.
size_t global_size[3]; size_t global_size[3];
size_t local_size[3]; size_t local_size[3];
@ -903,12 +922,56 @@ static void ggml_cl2_free(void) {
return; return;
} }
// Populate profiling info
for (ProfilingInfo & info : g_profiling_info) {
cl_ulong cmd_queued;
cl_ulong cmd_submit;
cl_ulong cmd_start;
cl_ulong cmd_end;
cl_ulong cmd_complete;
CL_CHECK(clWaitForEvents(1, &info.evt));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL));
CL_CHECK(clReleaseEvent(info.evt));
char kernel_name[512];
CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME,
sizeof(kernel_name), kernel_name, NULL));
info.kernel_name = kernel_name;
info.cmd_queued = cmd_queued;
info.cmd_submit = cmd_submit;
info.cmd_start = cmd_start;
info.cmd_end = cmd_end;
info.cmd_queued_duration_ns = cmd_submit - cmd_queued;
info.cmd_submit_duration_ns = cmd_start - cmd_submit;
info.cmd_duration_ns = cmd_end - cmd_start;
info.cmd_complete_duration_ns = cmd_complete - cmd_end;
info.cmd_total_duration_ns = cmd_complete - cmd_queued;
}
// Dump a csv
float total_kernel_time = 0; float total_kernel_time = 0;
fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n"); fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
for (const ProfilingInfo & info : g_profiling_info) { for (const ProfilingInfo & info : g_profiling_info) {
total_kernel_time += info.duration_ns/1.e6f; total_kernel_time += info.cmd_duration_ns/1.e6f;
fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f, info.op_name.c_str(), info.kernel_name.c_str(),
info.cmd_queued_duration_ns/1.e6f,
info.cmd_submit_duration_ns/1.e6f,
info.cmd_duration_ns/1.e6f,
info.cmd_complete_duration_ns/1.e6f,
info.cmd_total_duration_ns/1.e6f,
info.global_size[0], info.global_size[1], info.global_size[2], info.global_size[0], info.global_size[1], info.global_size[2],
info.local_size[0], info.local_size[2], info.local_size[2], info.local_size[0], info.local_size[2], info.local_size[2],
info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]); info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
@ -916,6 +979,27 @@ static void ggml_cl2_free(void) {
fclose(fperf); fclose(fperf);
GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time); GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
// Dump a simple chrome trace
FILE* ftrace = fopen("cl_trace.json", "w");
if (!ftrace) {
GGML_LOG_ERROR("Failed to open cl_trace.json\n");
return;
}
fprintf(ftrace, "[\n");
for (const ProfilingInfo & info : g_profiling_info) {
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_queued/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_submit/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_start/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_end/1000);
}
fclose(ftrace);
#endif #endif
} }
@ -2062,25 +2146,14 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
// Profiling utility // Profiling utility
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
#ifdef GGML_OPENCL_PROFILING #ifdef GGML_OPENCL_PROFILING
void populateProfilingInfo( static void populateProfilingInfo(
ProfilingInfo& info, cl_event evt, cl_kernel kernel, ProfilingInfo& info, cl_event evt, cl_kernel kernel,
size_t global_size[3], size_t local_size[3], size_t global_size[3], size_t local_size[3],
const ggml_tensor * tensor) { const ggml_tensor * tensor) {
cl_ulong start;
cl_ulong end;
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clGetEventProfilingInfo(
evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
CL_CHECK(clGetEventProfilingInfo(
evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
char kernel_name[512];
CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
sizeof(kernel_name), kernel_name, NULL));
info.duration_ns = end - start;
info.op_name = tensor->name; info.op_name = tensor->name;
info.kernel_name = kernel_name; info.kernel = kernel;
info.evt = evt;
info.local_size[0] = local_size[0]; info.local_size[0] = local_size[0];
info.local_size[1] = local_size[1]; info.local_size[1] = local_size[1];
info.local_size[2] = local_size[2]; info.local_size[2] = local_size[2];

View file

@ -26,7 +26,7 @@
#include "softmax.hpp" #include "softmax.hpp"
#include "tsembd.hpp" #include "tsembd.hpp"
#include "im2col.hpp" #include "im2col.hpp"
#include "wkv6.hpp" #include "wkv.hpp"
#include "outprod.hpp" #include "outprod.hpp"
#include "element_wise.hpp" #include "element_wise.hpp"
#include "cpy.hpp" #include "cpy.hpp"

View file

@ -301,6 +301,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
return opt; return opt;
} }
namespace sycl_ex = sycl::ext::oneapi::experimental;
struct ggml_backend_sycl_context { struct ggml_backend_sycl_context {
int device; int device;
std::string name; std::string name;
@ -392,6 +393,10 @@ struct ggml_backend_sycl_context {
return pool(device); return pool(device);
} }
#ifdef GGML_SYCL_GRAPH
std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
#endif
ggml_sycl_pool & host_pool(int device) { ggml_sycl_pool & host_pool(int device) {
if (host_pools[device] == nullptr) { if (host_pools[device] == nullptr) {
host_pools[device] = new_pool_for_host(stream(device, 0), device); host_pools[device] = new_pool_for_host(stream(device, 0), device);

View file

@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
sycl::range<3>(1, 1, WARP_SIZE), sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)), sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
dequantize_block_q4_0_reorder(vx, y, k, item_ct1); dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
}); });

View file

@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
nrows, item_ct1); nrows, item_ct1);
}); });
@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>( dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>( dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>( dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>( dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>( dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>( dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE); const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
}); });
} }
@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
default: default:
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
break;
} }
GGML_UNUSED(src1); GGML_UNUSED(src1);

View file

@ -1,7 +1,7 @@
#include "common.hpp" #include "common.hpp"
#include "element_wise.hpp" #include "element_wise.hpp"
void acc_f32(const float * x, const float * y, float * dst, const int ne, static void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12, const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
} }
} }
void gelu_f32(const float * x, float * dst, const int k, static void gelu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const float GELU_COEF_A = 0.044715f; const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi))); sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
} }
void silu_f32(const float * x, float * dst, const int k, static void silu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i])); dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
} }
void gelu_quick_f32(const float *x, float *dst, int k, static void gelu_quick_f32(const float *x, float *dst, int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const float GELU_QUICK_COEF = -1.702f; const float GELU_QUICK_COEF = -1.702f;
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i]))); dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
} }
void tanh_f32(const float *x, float *dst, int k, static void tanh_f32(const float *x, float *dst, int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
dst[i] = sycl::tanh((float)(x[i])); dst[i] = sycl::tanh((float)(x[i]));
} }
void relu_f32(const float * x, float * dst, const int k, static void relu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
dst[i] = sycl::fmax((float)(x[i]), (float)0); dst[i] = sycl::fmax((float)(x[i]), (float)0);
} }
void sigmoid_f32(const float * x, float * dst, const int k, static void sigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i])); dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
} }
void sqrt_f32(const float * x, float * dst, const int k, static void sqrt_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
dst[i] = sycl::sqrt(x[i]); dst[i] = sycl::sqrt(x[i]);
} }
void sin_f32(const float * x, float * dst, const int k, static void sin_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
dst[i] = sycl::sin(x[i]); dst[i] = sycl::sin(x[i]);
} }
void cos_f32(const float * x, float * dst, const int k, static void cos_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
dst[i] = sycl::cos(x[i]); dst[i] = sycl::cos(x[i]);
} }
void hardsigmoid_f32(const float * x, float * dst, const int k, static void hardsigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
} }
void hardswish_f32(const float * x, float * dst, const int k, static void hardswish_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
} }
void exp_f32(const float * x, float * dst, const int k, static void exp_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
dst[i] = sycl::exp(x[i]); dst[i] = sycl::exp(x[i]);
} }
void log_f32(const float * x, float * dst, const int k, static void log_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
} }
} }
void neg_f32(const float * x, float * dst, const int k, static void neg_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
dst[i] = -x[i]; dst[i] = -x[i];
} }
void step_f32(const float * x, float * dst, const int k, static void step_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
dst[i] = x[i] > 0.0f; dst[i] = x[i] > 0.0f;
} }
void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope, static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
sycl::fmin((float)(x[i]), 0.0f) * negative_slope; sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
} }
void sqr_f32(const float * x, float * dst, const int k, static void sqr_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * x[i]; dst[i] = x[i] * x[i];
} }
void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11, const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1, const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
} }
void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
int nidx = item_ct1.get_local_id(2) + int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2); item_ct1.get_group(2) * item_ct1.get_local_range(2);
@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
void acc_f32_sycl(const float *x, const float *y, float *dst, static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11, const int n_elements, const int ne10, const int ne11,
const int ne12, const int nb1, const int nb2, const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) { const int offset, queue_ptr stream) {
@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
}); });
} }
void gelu_f32_sycl(const float *x, float *dst, const int k, static void gelu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void silu_f32_sycl(const float *x, float *dst, const int k, static void silu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void gelu_quick_f32_sycl(const float *x, float *dst, const int k, static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void tanh_f32_sycl(const float *x, float *dst, const int k, static void tanh_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void relu_f32_sycl(const float *x, float *dst, const int k, static void relu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void hardswish_f32_sycl(const float *x, float *dst, const int k, static void hardswish_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void exp_f32_sycl(const float *x, float *dst, const int k, static void exp_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void log_f32_sycl(const float *x, float *dst, const int k, static void log_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void neg_f32_sycl(const float *x, float *dst, const int k, static void neg_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void step_f32_sycl(const float *x, float *dst, const int k, static void step_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void sigmoid_f32_sycl(const float *x, float *dst, const int k, static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void sqrt_f32_sycl(const float *x, float *dst, const int k, static void sqrt_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void sin_f32_sycl(const float *x, float *dst, const int k, static void sin_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void cos_f32_sycl(const float *x, float *dst, const int k, static void cos_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void leaky_relu_f32_sycl(const float *x, float *dst, const int k, static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
const float negative_slope, const float negative_slope,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void sqr_f32_sycl(const float *x, float *dst, const int k, static void sqr_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
}); });
} }
void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11, const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1, const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, queue_ptr stream) { const float sf2, const float sf3, queue_ptr stream) {
@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
}); });
} }
void pad_f32_sycl(const float *x, float *dst, const int ne00, static void pad_f32_sycl(const float *x, float *dst, const int ne00,
const int ne01, const int ne02, const int ne0, const int ne01, const int ne02, const int ne0,
const int ne1, const int ne2, queue_ptr stream) { const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;

View file

@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
const size_t nrows = ne01; const size_t nrows = ne01;
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2); const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
k_get_rows_reorder<qk, qr, dq_reorder>( k_get_rows_reorder<qk, qr, dq_reorder>(
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2, src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
// TODO: k-quants // TODO: k-quants
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
break;
} }
} }

View file

@ -46,6 +46,7 @@
static bool g_sycl_loaded = false; static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0; int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
static ggml_sycl_device_info ggml_sycl_init() { static ggml_sycl_device_info ggml_sycl_init() {
ggml_sycl_device_info info = {}; ggml_sycl_device_info info = {};
@ -95,7 +96,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
return info; return info;
} }
void print_device_detail(int id, sycl::device &device, std::string device_type) { static void print_device_detail(int id, sycl::device &device, std::string device_type) {
dpct::device_info prop; dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR( SYCL_CHECK(CHECK_TRY_ERROR(
@ -118,7 +119,7 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str()); global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
} }
void print_device_opt_feature(int device_count) { static void print_device_opt_feature(int device_count) {
GGML_LOG_INFO("SYCL Optimization Feature:\n"); GGML_LOG_INFO("SYCL Optimization Feature:\n");
GGML_LOG_INFO( GGML_LOG_INFO(
"|ID| Device Type|Reorder|\n"); "|ID| Device Type|Reorder|\n");
@ -191,10 +192,12 @@ static void ggml_check_sycl() try {
if (!initialized) { if (!initialized) {
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
GGML_LOG_INFO("Build with Macros:\n"); GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ) #if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
@ -333,10 +336,11 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
assert(tensor->view_src->buffer->buft == buffer->buft); assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
if (tensor->type == GGML_TYPE_Q4_0) {
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra; tensor->extra = extra;
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
}
if (ggml_is_quantized(tensor->type)) { if (ggml_is_quantized(tensor->type)) {
// initialize padding to 0 to avoid possible NaN values // initialize padding to 0 to avoid possible NaN values
@ -400,7 +404,7 @@ catch (sycl::exception const &exc) {
std::exit(1); std::exit(1);
} }
void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
const void *ptr_src, size_t size) { const void *ptr_src, size_t size) {
char *host_buf = (char *)malloc(size); char *host_buf = (char *)malloc(size);
q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
@ -486,6 +490,22 @@ catch (sycl::exception const &exc) {
std::exit(1); std::exit(1);
} }
static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
if (buffer == nullptr) {
return;
}
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
if (ctx != nullptr) {
for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
release_extra_gpu(extra);
}
ctx->tensor_extras.clear(); // reset the tensor_extras vector
}
}
static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
/* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
/* .get_base = */ ggml_backend_sycl_buffer_get_base, /* .get_base = */ ggml_backend_sycl_buffer_get_base,
@ -495,7 +515,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
/* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
/* .clear = */ ggml_backend_sycl_buffer_clear, /* .clear = */ ggml_backend_sycl_buffer_clear,
/* .reset = */ NULL, /* .reset = */ ggml_backend_sycl_buffer_reset,
}; };
// sycl buffer type // sycl buffer type
@ -576,7 +596,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
auto dev_count = ggml_backend_sycl_get_device_count(); auto dev_count = ggml_backend_sycl_get_device_count();
@ -604,7 +623,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
return &ggml_backend_sycl_buffer_types[device]; return &ggml_backend_sycl_buffer_types[device];
} }
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) { static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
int device = ctx->device; int device = ctx->device;
@ -1666,7 +1685,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(num_blocks * block_size, block_size), sycl::nd_range<3>(num_blocks * block_size, block_size),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1); quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
}); });
} }
@ -1687,7 +1706,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x, mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
nchannels_y, item_ct1); nchannels_y, item_ct1);
}); });
@ -1707,7 +1726,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x, mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
row_stride_x, channel_stride_x, row_stride_x, channel_stride_x,
nchannels_y / nchannels_x, item_ct1); nchannels_y / nchannels_x, item_ct1);
@ -1748,7 +1767,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
const sycl::range<3> block_nums(1, nrows, 1); const sycl::range<3> block_nums(1, nrows, 1);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
k_sum_rows_f32(x, dst, ncols, item_ct1); k_sum_rows_f32(x, dst, ncols, item_ct1);
}); });
} }
@ -2680,6 +2699,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm); ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
@ -2898,7 +2923,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
return false; return false;
} }
bool ggml_sycl_supports_dmmv(enum ggml_type type) { static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
@ -3271,7 +3296,7 @@ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
} }
void ggml_sycl_set_main_device(const int main_device) try { static void ggml_sycl_set_main_device(const int main_device) try {
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) { if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
return; return;
} }
@ -3292,7 +3317,7 @@ catch (sycl::exception const &exc) {
std::exit(1); std::exit(1);
} }
bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) { static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
if (!g_sycl_loaded) return false; if (!g_sycl_loaded) return false;
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) { if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@ -3394,6 +3419,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
ggml_sycl_rms_norm(ctx, dst); ggml_sycl_rms_norm(ctx, dst);
break; break;
case GGML_OP_L2_NORM:
ggml_sycl_l2_norm(ctx, dst);
break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false; return false;
@ -3471,6 +3499,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
ggml_sycl_op_rwkv_wkv6(ctx, dst); ggml_sycl_op_rwkv_wkv6(ctx, dst);
break; break;
case GGML_OP_RWKV_WKV7:
ggml_sycl_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst); ggml_sycl_op_gated_linear_attn(ctx, dst);
break; break;
@ -3610,7 +3641,7 @@ catch (sycl::exception const &exc) {
std::exit(1); std::exit(1);
} }
void reorder_qw(char *data_device, const int ncols, const int nrows, static void reorder_qw(char *data_device, const int ncols, const int nrows,
size_t size, size_t offset, dpct::queue_ptr stream) { size_t size, size_t offset, dpct::queue_ptr stream) {
auto tmp_buf = sycl::malloc_shared<char>(size, *stream); auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
SYCL_CHECK( SYCL_CHECK(
@ -3624,7 +3655,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
stream->parallel_for( stream->parallel_for(
size / sizeof(block_q4_0), size / sizeof(block_q4_0),
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
const block_q4_0* x = (const block_q4_0*)tmp_buf; const block_q4_0* x = (const block_q4_0*)tmp_buf;
const int ib = i; const int ib = i;
@ -3638,7 +3669,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
sycl::free(tmp_buf, *stream); sycl::free(tmp_buf, *stream);
} }
void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) { static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
char*data_device = (char*)src0->data; char*data_device = (char*)src0->data;
size_t ncols = src0->ne[0]; size_t ncols = src0->ne[0];
size_t nrows = src0->ne[1]; size_t nrows = src0->ne[1];
@ -3647,7 +3678,7 @@ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
reorder_qw(data_device, ncols, nrows, size, 0, stream); reorder_qw(data_device, ncols, nrows, size, 0, stream);
} }
void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) { static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
ggml_tensor *src0 = dst->src[0]; ggml_tensor *src0 = dst->src[0];
ggml_tensor *src1 = dst->src[1]; ggml_tensor *src1 = dst->src[1];
@ -3660,7 +3691,7 @@ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
} }
} }
void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) { static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
dpct::queue_ptr stream = ctx->stream(); dpct::queue_ptr stream = ctx->stream();
if (ctx->optimized_graph) { if (ctx->optimized_graph) {
return; return;
@ -3671,10 +3702,9 @@ void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx)
if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream); if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
} }
} }
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
ggml_sycl_set_main_device(sycl_ctx->device);
static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
ggml_sycl_set_main_device(sycl_ctx->device);
if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx); if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
@ -3696,7 +3726,46 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
} }
GGML_ASSERT(ok); GGML_ASSERT(ok);
} }
}
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
#ifdef GGML_SYCL_GRAPH
if (!g_ggml_sycl_disable_graph) {
if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
return GGML_STATUS_SUCCESS;
}
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
model_sycl_graph.end_recording();
if (!sycl_ctx->exec_graph) {
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
sycl_ctx->exec_graph = std::make_unique<
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
} else {
try {
sycl_ctx->exec_graph->update(model_sycl_graph);
GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
} catch (sycl::exception const & e) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
sycl_ctx->exec_graph = std::make_unique<
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
}
}
sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
} else
#endif
{
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
}
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
@ -3761,7 +3830,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
} }
int ggml_backend_sycl_get_device_count() { int ggml_backend_sycl_get_device_count() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
return ggml_sycl_info().device_count; return ggml_sycl_info().device_count;
} }
@ -3851,7 +3919,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true; return true;
} }
return false; return false;
} break; }
case GGML_OP_UNARY: case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) { switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_NEG:
@ -3869,7 +3937,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
default: default:
return false; return false;
} }
break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
{ {
@ -3900,7 +3967,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false; return false;
} }
return true; return true;
} break; }
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
@ -3917,7 +3984,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
default: default:
return false; return false;
} }
} break; }
case GGML_OP_CPY: case GGML_OP_CPY:
{ {
ggml_type src0_type = op->src[0]->type; ggml_type src0_type = op->src[0]->type;
@ -3968,12 +4035,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true; return true;
} }
return false; return false;
} break; }
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
{ {
ggml_type src0_type = op->src[0]->type; ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
} break; }
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_NONE: case GGML_OP_NONE:
@ -3997,6 +4064,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->src[0]->type == GGML_TYPE_F32); return (op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE: case GGML_OP_SCALE:
@ -4030,6 +4098,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_GATED_LINEAR_ATTN:
return true; return true;
default: default:

View file

@ -3017,7 +3017,6 @@ void ggml_sycl_op_mul_mat_q(
break; break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
break;
} }
GGML_UNUSED(src1); GGML_UNUSED(src1);

View file

@ -495,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>( VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -519,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>( VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -543,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>( VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -567,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>( VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -591,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>( VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -615,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>( VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -639,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>( VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -663,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>( VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -687,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>( VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -711,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>( VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -734,7 +734,7 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>( mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -755,7 +755,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>( mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -777,7 +777,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>( mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -799,7 +799,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>( mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -821,7 +821,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>( mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -843,7 +843,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>( mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -864,7 +864,7 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>( mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -886,7 +886,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>( mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -908,7 +908,7 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>( mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
@ -1003,7 +1003,6 @@ void ggml_sycl_op_mul_mat_vec_q(
break; break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
break;
} }
} }
GGML_UNUSED(src1); GGML_UNUSED(src1);

View file

@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
} }
} }
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row * ncols + col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp, item_ct1);
if (block_size > WARP_SIZE) {
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
/*
DPCT1118:3: SYCL group functions and algorithms must be encountered in
converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
size_t nreduce = nwarps / WARP_SIZE;
tmp = 0.f;
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += s_sum[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
}
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[row * ncols + col] = scale * x[row * ncols + col];
}
}
static void norm_f32_sycl(const float* x, float* dst, const int ncols, static void norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps, const int nrows, const float eps,
queue_ptr stream, int device) { queue_ptr stream, int device) {
@ -191,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1, norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE); nullptr, WARP_SIZE);
}); });
@ -214,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1, norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
}); });
@ -233,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32( group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1, x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE); nullptr, WARP_SIZE);
@ -260,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements, group_norm_f32(x, dst, group_size, ne_elements,
eps_ct4, item_ct1, eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
@ -281,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1, rms_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE); nullptr, WARP_SIZE);
}); });
@ -303,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1, rms_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
}); });
@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
} }
} }
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
ggml_tensor* dst, const float* src0_dd, ggml_tensor* dst, const float* src0_dd,
const float* src1_dd, float* dst_dd, const float* src1_dd, float* dst_dd,
@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
(void)dst; (void)dst;
(void)src1_dd; (void)src1_dd;
} }
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
}

View file

@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
float* dst_dd, float* dst_dd,
const queue_ptr& main_stream); const queue_ptr& main_stream);
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
#endif // GGML_SYCL_NORM_HPP #endif // GGML_SYCL_NORM_HPP

View file

@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par, soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0, nrows_y, scale, max_bias, m0,
m1, n_head_log2, item_ct1, m1, n_head_log2, item_ct1,

305
ggml/src/ggml-sycl/wkv.cpp Normal file
View file

@ -0,0 +1,305 @@
#include <sycl/sycl.hpp>
#include "wkv.hpp"
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
// Helper function for the main kernel
template <int block_size>
static void rwkv_wkv6_f32_kernel(
const int B, const int T, const int C, const int H,
const float* k, const float* v, const float* r,
const float* tf, const float* td, const float* s,
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
const int tid = item_ct1.get_local_id(2);
const int bid = item_ct1.get_group(2);
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
// Set up shared memory pointers
float* _k = shared_mem;
float* _r = _k + head_size;
float* _tf = _r + head_size;
float* _td = _tf + head_size;
// Local state array
float state[block_size];
// Load initial state
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
// Sync threads before shared memory operations
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load time-mixing parameters
_tf[tid] = tf[head_i * head_size + tid];
item_ct1.barrier(sycl::access::fence_space::local_space);
// Main sequence processing loop
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
t += C) {
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load current timestep data to shared memory
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
item_ct1.barrier(sycl::access::fence_space::local_space);
const float _v = v[t];
float y = 0;
// Process in chunks of 4 for better vectorization
sycl::float4 k4, r4, tf4, td4, s4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
// Load data in vec4 chunks
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
// Compute key-value product
sycl::float4 kv4 = k4 * _v;
// Accumulate weighted sum
y += sycl::dot(r4, tf4 * kv4 + s4);
// Update state
s4 = s4 * td4 + kv4;
// Store updated state
state[j] = s4.x();
state[j+1] = s4.y();
state[j+2] = s4.z();
state[j+3] = s4.w();
}
dst[t] = y;
}
// Save final state
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
template <int block_size>
static void rwkv_wkv7_f32_kernel(
const int B, const int T, const int C, const int H,
const float* r, const float* w, const float* k, const float* v,
const float* a, const float* b, const float* s,
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
const int tid = item_ct1.get_local_id(2);
const int bid = item_ct1.get_group(2);
const int head_size = block_size;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float* _r = shared_mem;
float* _w = _r + head_size;
float* _k = _w + head_size;
float* _a = _k + head_size;
float* _b = _a + head_size;
float state[block_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
}
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
t += C) {
item_ct1.barrier(sycl::access::fence_space::local_space);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
item_ct1.barrier(sycl::access::fence_space::local_space);
const float _v = v[t];
float y = 0, sa = 0;
sycl::float4 a4, s4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
sa += sycl::dot(a4, s4);
}
sycl::float4 r4, w4, k4, b4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
sycl::float4 kv4 = k4 * _v;
s4 = s4 * w4 + kv4 + sa * b4;
y += sycl::dot(r4, s4);
state[j] = s4.x();
state[j+1] = s4.y();
state[j+2] = s4.z();
state[j+3] = s4.w();
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
}
}
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float* k_d = (const float*)dst->src[0]->data;
const float* v_d = (const float*)dst->src[1]->data;
const float* r_d = (const float*)dst->src[2]->data;
const float* tf_d = (const float*)dst->src[3]->data;
const float* td_d = (const float*)dst->src[4]->data;
const float* s_d = (const float*)dst->src[5]->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
dpct::queue_ptr stream = ctx.stream();
// Calculate execution configuration
const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
sycl::range<3> block_dims(1, 1, C / H);
sycl::range<3> grid_dims(1, 1, B * H);
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
}
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float* r_d = (const float*)dst->src[0]->data;
const float* w_d = (const float*)dst->src[1]->data;
const float* k_d = (const float*)dst->src[2]->data;
const float* v_d = (const float*)dst->src[3]->data;
const float* a_d = (const float*)dst->src[4]->data;
const float* b_d = (const float*)dst->src[5]->data;
const float* s_d = (const float*)dst->src[6]->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
dpct::queue_ptr stream = ctx.stream();
// Calculate execution configuration
const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
sycl::range<3> block_dims(1, 1, C / H);
sycl::range<3> grid_dims(1, 1, B * H);
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
}
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}

View file

@ -0,0 +1,10 @@
#ifndef GGML_SYCL_WKV_HPP
#define GGML_SYCL_WKV_HPP
#include "common.hpp"
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_WKV_HPP

View file

@ -1,143 +0,0 @@
#include <sycl/sycl.hpp>
#include "wkv6.hpp"
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
// Helper function for the main kernel
static void rwkv_wkv_f32_kernel(
const int B, const int T, const int C, const int H,
const float* k, const float* v, const float* r,
const float* tf, const float* td, const float* s,
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
const int tid = item_ct1.get_local_id(2);
const int bid = item_ct1.get_group(2);
const int head_size = WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
// Set up shared memory pointers
float* _k = shared_mem;
float* _r = _k + head_size;
float* _tf = _r + head_size;
float* _td = _tf + head_size;
// Local state array
float state[WKV_BLOCK_SIZE];
// Load initial state
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
// Sync threads before shared memory operations
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load time-mixing parameters
_tf[tid] = tf[head_i * head_size + tid];
item_ct1.barrier(sycl::access::fence_space::local_space);
// Main sequence processing loop
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
t += C) {
item_ct1.barrier(sycl::access::fence_space::local_space);
// Load current timestep data to shared memory
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
item_ct1.barrier(sycl::access::fence_space::local_space);
const float _v = v[t];
float y = 0;
// Process in chunks of 4 for better vectorization
sycl::float4 k4, r4, tf4, td4, s4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
// Load data in vec4 chunks
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
// Compute key-value product
sycl::float4 kv4 = k4 * _v;
// Accumulate weighted sum
y += sycl::dot(r4, tf4 * kv4 + s4);
// Update state
s4 = s4 * td4 + kv4;
// Store updated state
state[j] = s4.x();
state[j+1] = s4.y();
state[j+2] = s4.z();
state[j+3] = s4.w();
}
dst[t] = y;
}
// Save final state
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float* k_d = (const float*)dst->src[0]->data;
const float* v_d = (const float*)dst->src[1]->data;
const float* r_d = (const float*)dst->src[2]->data;
const float* tf_d = (const float*)dst->src[3]->data;
const float* td_d = (const float*)dst->src[4]->data;
const float* s_d = (const float*)dst->src[5]->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
dpct::queue_ptr stream = ctx.stream();
// Calculate execution configuration
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
sycl::range<3> block_dims(1, 1, C / H);
sycl::range<3> grid_dims(1, 1, B * H);
// Submit kernel
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv_f32_kernel(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}

View file

@ -1,9 +0,0 @@
#ifndef GGML_SYCL_WKV6_HPP
#define GGML_SYCL_WKV6_HPP
#include "common.hpp"
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_WKV6_HPP

View file

@ -33,6 +33,7 @@
#include "ggml-vulkan-shaders.cpp" #include "ggml-vulkan-shaders.cpp"
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
#define VK_VENDOR_ID_AMD 0x1002 #define VK_VENDOR_ID_AMD 0x1002
@ -153,6 +154,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
static constexpr uint32_t mul_mat_vec_max_cols = 8; static constexpr uint32_t mul_mat_vec_max_cols = 8;
enum vk_device_architecture {
OTHER,
AMD_GCN,
AMD_RDNA1,
AMD_RDNA2,
AMD_RDNA3,
};
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties();
if (props.vendorID == VK_VENDOR_ID_AMD) {
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
bool amd_shader_core_properties = false;
bool integer_dot_product = false;
bool subgroup_size_control = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
amd_shader_core_properties = true;
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
integer_dot_product = true;
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
subgroup_size_control = true;
}
}
if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
return vk_device_architecture::OTHER;
}
vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
props2.pNext = &shader_core_props_amd;
shader_core_props_amd.pNext = &integer_dot_props;
integer_dot_props.pNext = &subgroup_size_control_props;
device.getProperties2(&props2);
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
return vk_device_architecture::AMD_GCN;
}
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
// RDNA
if (shader_core_props_amd.wavefrontsPerSimd == 20) {
return vk_device_architecture::AMD_RDNA1;
}
if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
return vk_device_architecture::AMD_RDNA3;
}
return vk_device_architecture::AMD_RDNA2;
}
}
return vk_device_architecture::OTHER;
}
struct vk_device_struct { struct vk_device_struct {
std::mutex mutex; std::mutex mutex;
@ -165,6 +226,7 @@ struct vk_device_struct {
bool pipeline_robustness; bool pipeline_robustness;
vk::Device device; vk::Device device;
uint32_t vendor_id; uint32_t vendor_id;
vk_device_architecture architecture;
vk_queue compute_queue; vk_queue compute_queue;
vk_queue transfer_queue; vk_queue transfer_queue;
bool single_queue; bool single_queue;
@ -246,6 +308,7 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;
vk_pipeline pipeline_gelu_f32; vk_pipeline pipeline_gelu_f32;
vk_pipeline pipeline_gelu_quick_f32; vk_pipeline pipeline_gelu_quick_f32;
vk_pipeline pipeline_silu_f32; vk_pipeline pipeline_silu_f32;
@ -270,6 +333,7 @@ struct vk_device_struct {
vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_adamw_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@ -372,6 +436,7 @@ struct vk_mat_mat_push_constants {
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t k_split; uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
}; };
struct vk_mat_vec_push_constants { struct vk_mat_vec_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@ -384,6 +449,7 @@ struct vk_mat_mat_id_push_constants {
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
uint32_t padded_N;
}; };
struct vk_mat_vec_id_push_constants { struct vk_mat_vec_id_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@ -569,6 +635,13 @@ struct vk_op_rwkv_wkv6_push_constants {
uint32_t H; uint32_t H;
}; };
struct vk_op_rwkv_wkv7_push_constants {
uint32_t B;
uint32_t T;
uint32_t C;
uint32_t H;
};
// Allow pre-recording command buffers // Allow pre-recording command buffers
struct vk_staging_memcpy { struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@ -1449,6 +1522,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
return supported; return supported;
} }
struct GpuPipelineConfig {
// GPU architecture identifier.
// Example: vk_device_architecture::AMD_GCN
vk_device_architecture arch;
// Mapping of pipeline names to their specific subgroup sizes.
// Example: {"soft_max_f32", 64}
std::unordered_map<std::string, uint32_t> pipelines;
// Default subgroup size for this GPU.
// Defaults to 0 if not explicitly provided.
uint32_t default_subgroup_size = 0;
};
// Pipeline configuration for RDNA1 GPUs.
static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
{"soft_max", 64}, {"im2col", 64},
{"argmax", 64}, {"mul_mat_vec", 64},
{"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
};
// Pipeline configuration for RDNA2 GPUs.
static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
{"soft_max", 64}, {"im2col", 64},
};
static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
// Define configurations for different GPUs.
static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
{
vk_device_architecture::AMD_RDNA1,
{
rdna1_pipelines,
},
RDNA_DEFAULT_SUBGROUP_SIZE
},
{
vk_device_architecture::AMD_RDNA2,
{
rdna2_pipelines,
},
RDNA_DEFAULT_SUBGROUP_SIZE
},
};
static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
for (const auto &config : gpu_pipeline_configs) {
if (config.arch == arch) {
auto pipIt = config.pipelines.find(pipeline_name);
if (pipIt != config.pipelines.end()) {
return pipIt->second;
}
std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
[](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
for (const auto &entry : sorted_pipelines) {
if (pipeline_name.find(entry.first) != std::string::npos) {
return entry.second;
}
}
return config.default_subgroup_size;
}
}
return 0; // If no matching configuration is found
}
static void ggml_vk_load_shaders(vk_device& device) { static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@ -1470,36 +1610,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t l_align, m_align, s_align; uint32_t l_align, m_align, s_align;
if (device->coopmat2) { if (device->coopmat2) {
// spec constants and tile sizes for non-quant matmul/matmul_id // spec constants and tile sizes for non-quant matmul/matmul_id
l_warptile = { 256, 128, 256, 64 }; l_warptile = { 256, 128, 256, 64, 1 };
m_warptile = { 256, 128, 128, 64 }; m_warptile = { 256, 128, 128, 64, 0 };
s_warptile = { 128, 64, 64, 64 }; s_warptile = { 128, 64, 64, 64, 0 };
l_wg_denoms = {128, 256, 1 }; l_wg_denoms = {128, 256, 1 };
m_wg_denoms = {128, 128, 1 }; m_wg_denoms = {128, 128, 1 };
s_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 64, 64, 1 };
// spec constants and tile sizes for quant matmul (non-Qi_K) // spec constants and tile sizes for quant matmul (non-Qi_K)
l_warptile_mmq = { 256, 128, 256, 64 }; l_warptile_mmq = { 256, 128, 256, 64, 1 };
m_warptile_mmq = { 256, 128, 128, 64 }; m_warptile_mmq = { 256, 128, 128, 64, 1 };
s_warptile_mmq = { 256, 128, 128, 64 }; s_warptile_mmq = { 256, 32, 64, 128, 0 };
l_mmq_wg_denoms = { 128, 256, 1 }; l_mmq_wg_denoms = { 128, 256, 1 };
m_mmq_wg_denoms = { 128, 128, 1 }; m_mmq_wg_denoms = { 128, 128, 1 };
s_mmq_wg_denoms = { 128, 128, 1 }; s_mmq_wg_denoms = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul (Qi_K) // spec constants and tile sizes for quant matmul (Qi_K)
l_warptile_mmq_k = { 256, 128, 512, 16 }; l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
m_warptile_mmq_k = { 256, 128, 256, 16 }; m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
s_warptile_mmq_k = { 256, 32, 128, 64 }; s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
l_mmq_wg_denoms_k = { 128, 512, 1 }; l_mmq_wg_denoms_k = { 64, 128, 1 };
m_mmq_wg_denoms_k = { 128, 256, 1 }; m_mmq_wg_denoms_k = { 32, 64, 1 };
s_mmq_wg_denoms_k = { 32, 128, 1 }; s_mmq_wg_denoms_k = { 32, 32, 1 };
// spec constants and tile sizes for quant matmul_id // spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16 }; l_warptile_mmqid = { 256, 128, 64, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16 }; m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 64, 64, 16 }; s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_mmqid_wg_denoms = { 128, 128, 1 }; l_mmqid_wg_denoms = { 128, 64, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 64, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 };
l_align = 128; l_align = 128;
m_align = 64; m_align = 64;
@ -1575,6 +1715,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
if (!require_full_subgroups && required_subgroup_size == 0) {
required_subgroup_size = get_subgroup_size(name, device->architecture);
}
if (!pipeline) { if (!pipeline) {
pipeline = std::make_shared<vk_pipeline_struct>(); pipeline = std::make_shared<vk_pipeline_struct>();
pipeline->name = name; pipeline->name = name;
@ -2132,6 +2276,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -2243,6 +2388,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
for (auto &c : compiles) { for (auto &c : compiles) {
@ -2251,7 +2398,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
device->need_compiles = false; device->need_compiles = false;
} }
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
static vk_device ggml_vk_get_device(size_t idx) { static vk_device ggml_vk_get_device(size_t idx) {
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@ -2280,6 +2427,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->physical_device = physical_devices[dev_num]; device->physical_device = physical_devices[dev_num];
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties(); const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
device->architecture = get_device_architecture(device->physical_device);
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
@ -2292,7 +2441,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool coopmat2_support = false; bool coopmat2_support = false;
device->coopmat_support = false; device->coopmat_support = false;
// Check if maintenance4 is supported
for (const auto& properties : ext_props) { for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
maintenance4_support = true; maintenance4_support = true;
@ -2380,13 +2528,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
#if defined(_WIN32) } else {
} else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
// Limit batching of allocations to 1GB by default to avoid fragmentation issues // Limit batching of allocations to 1GB by default to avoid fragmentation issues
device->suballocation_block_size = 1024*1024*1024; device->suballocation_block_size = 1024*1024*1024;
#endif
} else {
device->suballocation_block_size = device->max_memory_allocation_size;
} }
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
@ -2405,7 +2549,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
device->coopmat_support = false; device->coopmat_support = false;
} }
@ -2787,7 +2931,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
subgroup_props.pNext = &driver_props; subgroup_props.pNext = &driver_props;
physical_device.getProperties2(&props2); physical_device.getProperties2(&props2);
const size_t subgroup_size = subgroup_props.subgroupSize; vk_device_architecture arch = get_device_architecture(physical_device);
uint32_t default_subgroup_size = get_subgroup_size("", arch);
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
bool fp16_storage = false; bool fp16_storage = false;
@ -2813,7 +2960,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
} }
} }
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { const vk_device_architecture device_architecture = get_device_architecture(physical_device);
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
coopmat_support = false; coopmat_support = false;
} }
@ -3858,10 +4007,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) { if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { // Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l; return aligned ? mmp->a_l : mmp->l;
} }
if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) { // Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m; return aligned ? mmp->a_m : mmp->m;
} }
return aligned ? mmp->a_s : mmp->s; return aligned ? mmp->a_s : mmp->s;
@ -3886,18 +4039,19 @@ static void ggml_vk_matmul(
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) { uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
if (split_k == 1) { if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 }; const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
return; return;
} }
GGML_ASSERT(batch_stride_d == m * n); GGML_ASSERT(batch_stride_d == m * n);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 }; const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work // Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
@ -3906,13 +4060,17 @@ static void ggml_vk_matmul(
} }
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) { if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { // Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l; return aligned ? mmp->a_l : mmp->l;
} }
if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) { // Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m; return aligned ? mmp->a_m : mmp->m;
} }
return aligned ? mmp->a_s : mmp->s; return aligned ? mmp->a_s : mmp->s;
@ -3937,14 +4095,15 @@ static void ggml_vk_matmul_id(
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11 }; nei0, nei1, nbi1, ne11, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
} }
@ -4106,15 +4265,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
// Not implemented // Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
const int x_ne = ne01 * ne00;
const int y_ne = padded_n * ne10;
const int d_ne = ne11 * ne01;
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
@ -4237,7 +4398,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10, ne01, ne11, ne10,
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
split_k, ne12*ne13, ne02, ne12, r2, r3 split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
); // NOLINT ); // NOLINT
} }
@ -4688,15 +4849,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Not implemented // Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint64_t x_ne = ne01 * ne00;
const uint64_t y_ne = ne11 * ne10;
const uint64_t d_ne = ne21 * ne20;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
const uint64_t x_ne = ne01 * ne00;
const uint64_t y_ne = padded_n * ne10;
const uint64_t d_ne = ne21 * ne20;
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@ -4815,7 +4978,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01, ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21, stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11 n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
); // NOLINT ); // NOLINT
} }
@ -5326,6 +5489,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rms_norm_back_f32; return ctx->device->pipeline_rms_norm_back_f32;
} }
return nullptr; return nullptr;
case GGML_OP_L2_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_l2_norm_f32;
}
return nullptr;
case GGML_OP_UNARY: case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) { switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
@ -5465,6 +5633,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv6_f32; return ctx->device->pipeline_rwkv_wkv6_f32;
} }
return nullptr; return nullptr;
case GGML_OP_RWKV_WKV7:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32; return ctx->device->pipeline_opt_step_adamw_f32;
@ -5712,6 +5885,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
@ -5961,23 +6135,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun); }, dryrun);
} }
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
const ggml_tensor * k = dst->src[0]; GGML_ASSERT(version == 6 || version == 7);
const ggml_tensor * v = dst->src[1]; int num_srcs = version == 6 ? 6 : 7;
const ggml_tensor * r = dst->src[2];
const ggml_tensor * tf = dst->src[3]; for (int i = 0; i < num_srcs; i++) {
const ggml_tensor * td = dst->src[4]; GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
const ggml_tensor * state = dst->src[5]; }
GGML_ASSERT(!ggml_is_quantized(k->type));
GGML_ASSERT(!ggml_is_quantized(v->type));
GGML_ASSERT(!ggml_is_quantized(r->type));
GGML_ASSERT(!ggml_is_quantized(tf->type));
GGML_ASSERT(!ggml_is_quantized(td->type));
GGML_ASSERT(!ggml_is_quantized(state->type));
GGML_ASSERT(dst->buffer != nullptr); GGML_ASSERT(dst->buffer != nullptr);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr); GGML_ASSERT(pipeline != nullptr);
if (dryrun) { if (dryrun) {
@ -5986,89 +6154,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
} }
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; for (int i = 0; i < num_srcs; i++) {
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; }
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
if (ctx->device->uma) { if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); for (int i = 0; i < num_srcs; i++) {
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); srcs_uma[i] = d_srcs[i] != nullptr;
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); }
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
R_uma = d_R != nullptr;
TF_uma = d_TF != nullptr;
TD_uma = d_TD != nullptr;
STATE_uma = d_State != nullptr;
DST_uma = d_D != nullptr;
} }
if (!K_uma) { uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
d_K = k_buf_ctx->dev_buffer; for (int i = 0; i < num_srcs; i++) {
k_offset = vk_tensor_offset(k) + k->view_offs; src_sizes[i] = ggml_nbytes(dst->src[i]);
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
} }
if (!V_uma) {
d_V = v_buf_ctx->dev_buffer;
v_offset = vk_tensor_offset(v) + v->view_offs;
} }
if (!R_uma) {
d_R = r_buf_ctx->dev_buffer; const uint64_t dst_size = ggml_nbytes(dst);
r_offset = vk_tensor_offset(r) + r->view_offs; if (!dst_uma) {
}
if (!TF_uma) {
d_TF = tf_buf_ctx->dev_buffer;
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
}
if (!TD_uma) {
d_TD = td_buf_ctx->dev_buffer;
td_offset = vk_tensor_offset(td) + td->view_offs;
}
if (!STATE_uma) {
d_State = state_buf_ctx->dev_buffer;
state_offset = vk_tensor_offset(state) + state->view_offs;
}
if (!DST_uma) {
d_D = dst_buf_ctx->dev_buffer; d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs; dst_offset = vk_tensor_offset(dst) + dst->view_offs;
} }
const uint64_t k_size = ggml_nbytes(k);
const uint64_t v_size = ggml_nbytes(v);
const uint64_t r_size = ggml_nbytes(r);
const uint64_t tf_size = ggml_nbytes(tf);
const uint64_t td_size = ggml_nbytes(td);
const uint64_t state_size = ggml_nbytes(state);
const uint64_t dst_size = ggml_nbytes(dst);
std::array<uint32_t, 3> elements = { std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H), (uint32_t)(pc.B * pc.H),
1, 1,
1 1
}; };
if (version == 6) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_K, k_offset, k_size }, vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_V, v_offset, v_size }, vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_R, r_offset, r_size }, vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_TF, tf_offset, tf_size }, vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_TD, td_offset, td_size }, vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_State, state_offset, state_size }, vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_D, dst_offset, dst_size } vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
} else if (version == 7) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
} else {
// shouldn't happen
GGML_ASSERT(false);
}
} }
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@ -6077,7 +6229,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
const size_t n_heads = dst->src[0]->ne[1]; const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1]; const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_rwkv6( ggml_vk_op_f32_wkv(
ctx, subctx, dst, ctx, subctx, dst,
{ {
(uint32_t)n_seqs, (uint32_t)n_seqs,
@ -6085,6 +6237,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
(uint32_t)n_embed, (uint32_t)n_embed,
(uint32_t)n_heads, (uint32_t)n_heads,
}, },
6,
dryrun
);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
7,
dryrun dryrun
); );
} }
@ -6386,6 +6558,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
} }
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
} }
@ -6775,7 +6952,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k, m, n, k,
k, k, m, k*m, k*n, m*n, k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1 split_k, batch, batch, batch, 1, 1, n
); );
} }
ggml_vk_ctx_end(subctx); ggml_vk_ctx_end(subctx);
@ -7120,7 +7297,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k, m, n, k,
k, k, m, k*m, k*n, m*n, k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1 split_k, batch, batch, batch, 1, 1, n
); );
} }
ggml_vk_ctx_end(subctx); ggml_vk_ctx_end(subctx);
@ -7381,6 +7558,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
@ -7397,6 +7575,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
@ -7443,6 +7622,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_UNARY: case GGML_OP_UNARY:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
@ -7560,6 +7740,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_L2_NORM:
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) { switch (ggml_get_unary_op(node)) {
@ -7650,6 +7834,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
break; break;
case GGML_OP_RWKV_WKV7:
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
@ -7723,6 +7912,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
@ -7742,6 +7932,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK: case GGML_OP_REPEAT_BACK:
@ -8253,8 +8444,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
} }
if (ctx->device->need_compiles) { if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device); ggml_vk_load_shaders(ctx->device);
@ -8275,17 +8470,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
bool first_node_in_batch = true; // true if next node will be first node in a batch bool first_node_in_batch = true; // true if next node will be first node in a batch
int submit_node_idx = 0; // index to first node in a batch int submit_node_idx = 0; // index to first node in a batch
// Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Start with a smaller count to get work submitted right away, and increase it after each submit. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
int nodes_per_submit = 20; // (and scaled down based on model size, so smaller models submit earlier).
// Also submit at least every 100 nodes, in case there are workloads without as much matmul.
int nodes_per_submit = 100;
int submitted_nodes = 0; int submitted_nodes = 0;
int submit_count = 0; int submit_count = 0;
uint64_t mul_mat_bytes = 0;
uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
if (first_node_in_batch) { if (first_node_in_batch) {
submit_node_idx = i; submit_node_idx = i;
} }
bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i == last_node);
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
@ -8302,13 +8507,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
if (submit) { if (submit) {
first_node_in_batch = true; first_node_in_batch = true;
submitted_nodes = 0; submitted_nodes = 0;
switch (submit_count) { mul_mat_bytes = 0;
case 0: if (submit_count < 3) {
nodes_per_submit = 50; mul_mat_bytes_per_submit *= 2;
break;
default:
nodes_per_submit = 100;
break;
} }
submit_count++; submit_count++;
} }
@ -8659,6 +8860,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_SUB: case GGML_OP_SUB:
@ -8688,6 +8890,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
return true; return true;
@ -8834,7 +9037,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
UNUSED(instance_extensions); UNUSED(instance_extensions);
} }
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
switch (props.vendorID) { switch (props.vendorID) {
case VK_VENDOR_ID_INTEL: case VK_VENDOR_ID_INTEL:
// Intel drivers don't support coopmat properly yet // Intel drivers don't support coopmat properly yet
@ -8842,10 +9045,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
case VK_VENDOR_ID_AMD: case VK_VENDOR_ID_AMD:
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
// Workaround for AMD proprietary driver reporting support on all GPUs // Workaround for AMD proprietary driver reporting support on all GPUs
const std::string name = props.deviceName; return arch == vk_device_architecture::AMD_RDNA3;
return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
} }
return true; return true;
default: default:
@ -9075,6 +9275,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
} else if (tensor->op == GGML_OP_SILU_BACK) { } else if (tensor->op == GGML_OP_SILU_BACK) {
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_L2_NORM) {
const float eps = ((float *) tensor->op_params)[0];
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
} else if (tensor->op == GGML_OP_SOFT_MAX) { } else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) { if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
@ -9194,6 +9397,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_RWKV_WKV6) { } else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]); src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = src0->flags; src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],

View file

@ -178,7 +178,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
uvec4 v = bl128.block.q4k[0]; uvec4 v = bl128.block.q4k[0];
const f16vec2 loadd = unpackFloat2x16(v.x); const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc; uint32_t sc;
uint32_t mbyte; uint32_t mbyte;
@ -199,15 +199,15 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
sc &= 0x3F; sc &= 0x3F;
mbyte &= 0x3F; mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc); const float d = loadd.x * float(sc);
const float16_t m = loadd.y * float16_t(mbyte); const float m = loadd.y * float(mbyte);
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
float16_t ret = d * float16_t(qs) - m; float ret = d * float(qs) - m;
return ret; return float16_t(ret);
} }
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {

View file

@ -0,0 +1,41 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
sum[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
}
barrier();
}
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
}
}

View file

@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64; layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
uint M; uint M;
@ -48,6 +52,8 @@ layout (push_constant) uniform parameter
uint broadcast2; uint broadcast2;
uint broadcast3; uint broadcast3;
#endif #endif
// N dimension for the B matrix can be >= p.N
uint padded_N;
} p; } p;
@ -166,15 +172,13 @@ void main() {
const uint end_k = min(p.K, (ik + 1) * p.k_split); const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif #endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
uint pos_b = 0; uint pos_b = 0;
#else #else
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
uint pos_b = batch_idx * p.batch_stride_b; uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif #endif
uint stride_a = p.stride_a / QUANT_K; uint stride_a = p.stride_a / QUANT_K;
@ -195,6 +199,7 @@ void main() {
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
#if QUANT_K > 1 #if QUANT_K > 1
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@ -202,18 +207,19 @@ void main() {
#endif #endif
// Use end_k rather than p.K as the dimension because that's what // Use end_k rather than p.K as the dimension because that's what
// we need to bound check against when using split_k // we need to bound check against when using split_k.
// Bounds check B against padded_N, but bounds check D against N.
tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if !defined(MUL_MAT_ID) #if !defined(MUL_MAT_ID)
// Detect a fast path where all loads are entirely in bounds and no clamping is required // Detect a fast path where all loads are entirely in bounds and no clamping is required
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1 #if QUANT_K == 1
(stride_a % 8) == 0 && (stride_a % 8) == 0 &&
#endif #endif
@ -229,7 +235,40 @@ void main() {
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
uint k_iters = (end_k - start_k + BK - 1) / BK; uint k_iters = (end_k - start_k + BK - 1) / BK;
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
return;
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
return;
} else {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
@ -240,6 +279,11 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum); sum = coopMatMulAdd(mat_a, mat_b, sum);
} }
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
return;
}
} else } else
#endif // !defined(MUL_MAT_ID) #endif // !defined(MUL_MAT_ID)
{ {
@ -251,6 +295,9 @@ void main() {
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
[[dont_unroll]] [[dont_unroll]]
for (uint block_k = start_k; block_k < end_k; block_k += BK) { for (uint block_k = start_k; block_k < end_k; block_k += BK) {
@ -263,7 +310,7 @@ void main() {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
bool unclampedB = true; bool unclampedB = true;
#else #else
bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
#endif #endif
if (unclampedA && unclampedB) { if (unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
@ -293,7 +340,6 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum); sum = coopMatMulAdd(mat_a, mat_b, sum);
} }
} }
}
// Convert from ACC_TYPE to D_TYPE // Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d; coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
@ -303,9 +349,7 @@ void main() {
// Call callback to store each element, remapping row through shared memory // Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
#else #else
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
#endif #endif
} }
}

View file

@ -434,6 +434,7 @@ void process_shaders() {
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@ -528,6 +529,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
for (auto &c : compiles) { for (auto &c : compiles) {

View file

@ -0,0 +1,91 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint B;
uint T;
uint C;
uint H;
};
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
void main() {
const uint head_size = BLOCK_SIZE;
const uint batch_id = gl_WorkGroupID.x / H;
const uint head_id = gl_WorkGroupID.x % H;
const uint tid = gl_LocalInvocationID.x;
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
if (batch_id >= B || head_id >= H) {
return;
}
A_TYPE state[BLOCK_SIZE];
[[unroll]] for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
barrier();
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
barrier();
A_TYPE sa = 0.0;
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
sa += dot(s_vec, a_vec);
}
const A_TYPE v_val = v[t];
A_TYPE y = 0.0;
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
vec4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(r_vec, s_vec);
state[j] = s_vec.x;
state[j+1] = s_vec.y;
state[j+2] = s_vec.z;
state[j+3] = s_vec.w;
}
dst[t] = y;
}
[[unroll]] for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}

View file

@ -942,6 +942,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RMS_NORM", "RMS_NORM",
"RMS_NORM_BACK", "RMS_NORM_BACK",
"GROUP_NORM", "GROUP_NORM",
"L2_NORM",
"MUL_MAT", "MUL_MAT",
"MUL_MAT_ID", "MUL_MAT_ID",
@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ADD_REL_POS", "ADD_REL_POS",
"RWKV_WKV6", "RWKV_WKV6",
"GATED_LINEAR_ATTN", "GATED_LINEAR_ATTN",
"RWKV_WKV7",
"UNARY", "UNARY",
@ -1009,7 +1011,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW", "OPT_STEP_ADAMW",
}; };
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -1039,6 +1041,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm(x)", "rms_norm(x)",
"rms_norm_back(x)", "rms_norm_back(x)",
"group_norm(x)", "group_norm(x)",
"l2_norm(x)",
"X*Y", "X*Y",
"X[i]*Y", "X[i]*Y",
@ -1087,6 +1090,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"add_rel_pos(x)", "add_rel_pos(x)",
"rwkv_wkv6(k, v, r, tf, td, s)", "rwkv_wkv6(k, v, r, tf, td, s)",
"gated_linear_attn(k, v, q, gate, s)", "gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"unary(x)", "unary(x)",
@ -1106,7 +1110,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)", "adamw(x)",
}; };
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -2699,6 +2703,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
return ggml_group_norm_impl(ctx, a, n_groups, eps, true); return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
} }
// ggml_l2_norm
static struct ggml_tensor * ggml_l2_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps,
bool inplace) {
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params_f32(result, 0, eps);
result->op = GGML_OP_L2_NORM;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_l2_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps) {
return ggml_l2_norm_impl(ctx, a, eps, false);
}
struct ggml_tensor * ggml_l2_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps) {
return ggml_l2_norm_impl(ctx, a, eps, true);
}
// ggml_mul_mat // ggml_mul_mat
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@ -4733,6 +4768,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
return result; return result;
} }
// ggml_rwkv_wkv7
struct ggml_tensor * ggml_rwkv_wkv7(
struct ggml_context * ctx,
struct ggml_tensor * r,
struct ggml_tensor * w,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state) {
GGML_ASSERT(ggml_is_contiguous(r));
GGML_ASSERT(ggml_is_contiguous(w));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(ggml_is_contiguous(b));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S = k->ne[0];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t n_seqs = state->ne[1];
{
GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}
// concat output and new_state
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_RWKV_WKV7;
result->src[0] = r;
result->src[1] = w;
result->src[2] = k;
result->src[3] = v;
result->src[4] = a;
result->src[5] = b;
result->src[6] = state;
return result;
}
// ggml_unary // ggml_unary
static struct ggml_tensor * ggml_unary_impl( static struct ggml_tensor * ggml_unary_impl(

View file

@ -131,6 +131,10 @@ class Keys:
CAUSAL = "{arch}.attention.causal" CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank" Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank" KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window" SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale" SCALE = "{arch}.attention.scale"
@ -257,6 +261,8 @@ class MODEL_ARCH(IntEnum):
STARCODER2 = auto() STARCODER2 = auto()
RWKV6 = auto() RWKV6 = auto()
RWKV6QWEN2 = auto() RWKV6QWEN2 = auto()
RWKV7 = auto()
ARWKV7 = auto()
MAMBA = auto() MAMBA = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
@ -329,8 +335,20 @@ class MODEL_TENSOR(IntEnum):
SSM_A = auto() SSM_A = auto()
SSM_D = auto() SSM_D = auto()
SSM_OUT = auto() SSM_OUT = auto()
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto() TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto() TIME_MIX_W2 = auto()
TIME_MIX_A0 = auto()
TIME_MIX_A1 = auto()
TIME_MIX_A2 = auto()
TIME_MIX_V0 = auto()
TIME_MIX_V1 = auto()
TIME_MIX_V2 = auto()
TIME_MIX_G1 = auto()
TIME_MIX_G2 = auto()
TIME_MIX_K_K = auto()
TIME_MIX_K_A = auto()
TIME_MIX_R_K = auto()
TIME_MIX_LERP_X = auto() TIME_MIX_LERP_X = auto()
TIME_MIX_LERP_K = auto() TIME_MIX_LERP_K = auto()
TIME_MIX_LERP_V = auto() TIME_MIX_LERP_V = auto()
@ -445,6 +463,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
@ -517,8 +537,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x", MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k", MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v", MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
@ -1172,6 +1204,68 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
], ],
MODEL_ARCH.RWKV7: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_W0,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_A0,
MODEL_TENSOR.TIME_MIX_A1,
MODEL_TENSOR.TIME_MIX_A2,
MODEL_TENSOR.TIME_MIX_V0,
MODEL_TENSOR.TIME_MIX_V1,
MODEL_TENSOR.TIME_MIX_V2,
MODEL_TENSOR.TIME_MIX_G1,
MODEL_TENSOR.TIME_MIX_G2,
MODEL_TENSOR.TIME_MIX_K_K,
MODEL_TENSOR.TIME_MIX_K_A,
MODEL_TENSOR.TIME_MIX_R_K,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.CHANNEL_MIX_LERP_K,
MODEL_TENSOR.CHANNEL_MIX_KEY,
MODEL_TENSOR.CHANNEL_MIX_VALUE,
],
MODEL_ARCH.ARWKV7: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_W0,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_A0,
MODEL_TENSOR.TIME_MIX_A1,
MODEL_TENSOR.TIME_MIX_A2,
MODEL_TENSOR.TIME_MIX_V0,
MODEL_TENSOR.TIME_MIX_V1,
MODEL_TENSOR.TIME_MIX_V2,
MODEL_TENSOR.TIME_MIX_G1,
MODEL_TENSOR.TIME_MIX_G2,
MODEL_TENSOR.TIME_MIX_K_K,
MODEL_TENSOR.TIME_MIX_K_A,
MODEL_TENSOR.TIME_MIX_R_K,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.MAMBA: [ MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View file

@ -767,6 +767,18 @@ class GGUFWriter:
def add_kv_lora_rank(self, length: int) -> None: def add_kv_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length) self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
def add_decay_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
def add_iclr_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
def add_value_residual_mix_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
def add_gate_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
def add_relative_attn_buckets_count(self, value: int) -> None: def add_relative_attn_buckets_count(self, value: int) -> None:
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value) self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)

View file

@ -27,7 +27,8 @@ class TensorNameMap:
"embedding.word_embeddings", # chatglm "embedding.word_embeddings", # chatglm
"transformer.token_embeddings", # openelm "transformer.token_embeddings", # openelm
"shared", # t5 "shared", # t5
"rwkv.embeddings", # rwkv "rwkv.embeddings", # rwkv6
"model.embeddings", # rwkv7
), ),
# Token type embeddings # Token type embeddings
@ -42,6 +43,9 @@ class TensorNameMap:
"emb_ln", # nomic-bert "emb_ln", # nomic-bert
"transformer.norm", # openelm "transformer.norm", # openelm
"rwkv.blocks.0.pre_ln", # rwkv "rwkv.blocks.0.pre_ln", # rwkv
"rwkv.blocks.0.pre_ln", # rwkv6
"model.pre_ln", # rwkv7
"model.layers.0.pre_norm", # rwkv7
"backbone.norm", # wavtokenizer "backbone.norm", # wavtokenizer
), ),
@ -81,7 +85,8 @@ class TensorNameMap:
"encoder.final_layernorm", # chatglm "encoder.final_layernorm", # chatglm
"transformer.norm", # openelm "transformer.norm", # openelm
"model.norm", # nemotron "model.norm", # nemotron
"rwkv.ln_out", # rwkv "rwkv.ln_out", # rwkv6
"model.ln_out", # rwkv7
"backbone.final_layer_norm", # wavtokenizer "backbone.final_layer_norm", # wavtokenizer
), ),
@ -122,14 +127,16 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm "encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm "transformer.layers.{bid}.attn_norm", # openelm
"rwkv.blocks.{bid}.ln1", # rwkv "rwkv.blocks.{bid}.ln1", # rwkv6
"model.layers.{bid}.ln1", # rwkv7
), ),
# Attention norm 2 # Attention norm 2
MODEL_TENSOR.ATTN_NORM_2: ( MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn", # falcon40b "transformer.h.{bid}.ln_attn", # falcon40b
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code "encoder.layer.{bid}.layer_norm_1", # jina-v2-code
"rwkv.blocks.{bid}.ln2", # rwkv "rwkv.blocks.{bid}.ln2", # rwkv6
"model.layers.{bid}.ln2", # rwkv7
), ),
# Attention query-key-value # Attention query-key-value
@ -462,112 +469,174 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.out_proj", "backbone.layers.{bid}.mixer.out_proj",
), ),
MODEL_TENSOR.TIME_MIX_W0: (
"model.layers.{bid}.attention.w0", # rwkv7
),
MODEL_TENSOR.TIME_MIX_W1: ( MODEL_TENSOR.TIME_MIX_W1: (
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
"model.layers.{bid}.attention.w1", # rwkv7
), ),
MODEL_TENSOR.TIME_MIX_W2: ( MODEL_TENSOR.TIME_MIX_W2: (
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
"model.layers.{bid}.attention.w2", # rwkv7
),
MODEL_TENSOR.TIME_MIX_A0: (
"model.layers.{bid}.attention.a0", # rwkv7
),
MODEL_TENSOR.TIME_MIX_A1: (
"model.layers.{bid}.attention.a1", # rwkv7
),
MODEL_TENSOR.TIME_MIX_A2: (
"model.layers.{bid}.attention.a2", # rwkv7
),
MODEL_TENSOR.TIME_MIX_V0: (
"model.layers.{bid}.attention.v0", # rwkv7
),
MODEL_TENSOR.TIME_MIX_V1: (
"model.layers.{bid}.attention.v1", # rwkv7
),
MODEL_TENSOR.TIME_MIX_V2: (
"model.layers.{bid}.attention.v2", # rwkv7
),
MODEL_TENSOR.TIME_MIX_G1: (
"model.layers.{bid}.attention.g1", # rwkv7
),
MODEL_TENSOR.TIME_MIX_G2: (
"model.layers.{bid}.attention.g2", # rwkv7
),
MODEL_TENSOR.TIME_MIX_K_K: (
"model.layers.{bid}.attention.k_k", # rwkv7
),
MODEL_TENSOR.TIME_MIX_K_A: (
"model.layers.{bid}.attention.k_a", # rwkv7
),
MODEL_TENSOR.TIME_MIX_R_K: (
"model.layers.{bid}.attention.r_k", # rwkv7
), ),
MODEL_TENSOR.TIME_MIX_LERP_X: ( MODEL_TENSOR.TIME_MIX_LERP_X: (
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_K: ( MODEL_TENSOR.TIME_MIX_LERP_K: (
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_V: ( MODEL_TENSOR.TIME_MIX_LERP_V: (
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_R: ( MODEL_TENSOR.TIME_MIX_LERP_R: (
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_G: ( MODEL_TENSOR.TIME_MIX_LERP_G: (
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_W: ( MODEL_TENSOR.TIME_MIX_LERP_W: (
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_FIRST: ( MODEL_TENSOR.TIME_MIX_FIRST: (
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6 "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6
), ),
MODEL_TENSOR.TIME_MIX_DECAY: ( MODEL_TENSOR.TIME_MIX_DECAY: (
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6 "rwkv.blocks.{bid}.attention.time_decay", # rwkv6
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_DECAY_W1: ( MODEL_TENSOR.TIME_MIX_DECAY_W1: (
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6 "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_DECAY_W2: ( MODEL_TENSOR.TIME_MIX_DECAY_W2: (
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6 "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2 "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_KEY: ( MODEL_TENSOR.TIME_MIX_KEY: (
"rwkv.blocks.{bid}.attention.key", # rwkv "rwkv.blocks.{bid}.attention.key", # rwkv6
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2 "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
"model.layers.{bid}.attention.key", # rwkv7
"model.layers.{bid}.attention.k_proj", # rwkv7
), ),
MODEL_TENSOR.TIME_MIX_VALUE: ( MODEL_TENSOR.TIME_MIX_VALUE: (
"rwkv.blocks.{bid}.attention.value", # rwkv "rwkv.blocks.{bid}.attention.value", # rwkv6
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2 "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
"model.layers.{bid}.attention.value", # rwkv7
"model.layers.{bid}.attention.v_proj", # rwkv7
), ),
MODEL_TENSOR.TIME_MIX_RECEPTANCE: ( MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.attention.receptance", # rwkv "rwkv.blocks.{bid}.attention.receptance", # rwkv6
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2 "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
"model.layers.{bid}.attention.receptance", # rwkv7
"model.layers.{bid}.attention.r_proj", # rwkv7
), ),
MODEL_TENSOR.TIME_MIX_GATE: ( MODEL_TENSOR.TIME_MIX_GATE: (
"rwkv.blocks.{bid}.attention.gate", # rwkv "rwkv.blocks.{bid}.attention.gate", # rwkv6
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2 "model.layers.{bid}.self_attn.gate", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LN: ( MODEL_TENSOR.TIME_MIX_LN: (
"rwkv.blocks.{bid}.attention.ln_x", # rwkv "rwkv.blocks.{bid}.attention.ln_x", # rwkv6
"model.layers.{bid}.attention.ln_x" # rwkv7
), ),
MODEL_TENSOR.TIME_MIX_OUTPUT: ( MODEL_TENSOR.TIME_MIX_OUTPUT: (
"rwkv.blocks.{bid}.attention.output", # rwkv "rwkv.blocks.{bid}.attention.output", # rwkv6
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2 "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
"model.layers.{bid}.attention.output", # rwkv7
"model.layers.{bid}.attention.o_proj", # rwkv7
), ),
MODEL_TENSOR.CHANNEL_MIX_LERP_K: ( MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6 "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
"model.layers.{bid}.feed_forward.x_k", # rwkv7
), ),
MODEL_TENSOR.CHANNEL_MIX_LERP_R: ( MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6 "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
), ),
MODEL_TENSOR.CHANNEL_MIX_KEY: ( MODEL_TENSOR.CHANNEL_MIX_KEY: (
"rwkv.blocks.{bid}.feed_forward.key", # rwkv "rwkv.blocks.{bid}.feed_forward.key", # rwkv6
"model.layers.{bid}.feed_forward.key", # rwkv7
), ),
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: ( MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6
), ),
MODEL_TENSOR.CHANNEL_MIX_VALUE: ( MODEL_TENSOR.CHANNEL_MIX_VALUE: (
"rwkv.blocks.{bid}.feed_forward.value", # rwkv "rwkv.blocks.{bid}.feed_forward.value", # rwkv6
"model.layers.{bid}.feed_forward.value", # rwkv7
), ),
MODEL_TENSOR.ATTN_Q_A: ( MODEL_TENSOR.ATTN_Q_A: (

View file

@ -154,7 +154,12 @@ class SpecialVocab:
return True return True
with open(tokenizer_config_file, encoding = 'utf-8') as f: with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f) tokenizer_config = json.load(f)
chat_template = tokenizer_config.get('chat_template') chat_template_alt = None
chat_template_file = path / 'chat_template.json'
if chat_template_file.is_file():
with open(chat_template_file, encoding = 'utf-8') as f:
chat_template_alt = json.load(f).get('chat_template')
chat_template = tokenizer_config.get('chat_template', chat_template_alt)
if chat_template is None or isinstance(chat_template, (str, list)): if chat_template is None or isinstance(chat_template, (str, list)):
self.chat_template = chat_template self.chat_template = chat_template
else: else:

View file

@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
--> -->
<script> <script>
const LITEVER = 224; const LITEVER = 225;
const urlParams = new URLSearchParams(window.location.search); const urlParams = new URLSearchParams(window.location.search);
var localflag = true; var localflag = true;
const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_"; const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";
@ -2941,6 +2941,7 @@ Current version indicated by LITEVER below.
const OAI_TTS_ID = 1002; const OAI_TTS_ID = 1002;
const KCPP_TTS_ID = 1003; const KCPP_TTS_ID = 1003;
const HD_RES_PX = 768; const HD_RES_PX = 768;
const VHD_RES_PX = 960;
const NO_HD_RES_PX = 512; const NO_HD_RES_PX = 512;
const AVATAR_PX = 384; const AVATAR_PX = 384;
const SAVE_SLOTS = 8; const SAVE_SLOTS = 8;
@ -6271,7 +6272,7 @@ Current version indicated by LITEVER below.
document.getElementById("enhancedchatinterface").classList.add("transparentbg"); document.getElementById("enhancedchatinterface").classList.add("transparentbg");
document.getElementById("enhancedchatinterface_inner").classList.add("transparentbg"); document.getElementById("enhancedchatinterface_inner").classList.add("transparentbg");
indexeddb_save("bgimg", compressedImageURI); indexeddb_save("bgimg", compressedImageURI);
}, true, false, 1024, 0.5); }, false, 1024, 0.5);
} }
}; };
function clear_bg_img() function clear_bg_img()
@ -6564,7 +6565,7 @@ Current version indicated by LITEVER below.
document.getElementById('portrait_ratio_AI').value = aspectratio.toFixed(2); document.getElementById('portrait_ratio_AI').value = aspectratio.toFixed(2);
refreshAestheticPreview(true); refreshAestheticPreview(true);
render_gametext(); render_gametext();
}, true, true, AVATAR_PX); }, true, AVATAR_PX);
} }
else { else {
//attempt to read as WEBP //attempt to read as WEBP
@ -7838,7 +7839,7 @@ Current version indicated by LITEVER below.
temp_scenario.image = compressedImageURI; temp_scenario.image = compressedImageURI;
temp_scenario.image_aspect = aspectratio; temp_scenario.image_aspect = aspectratio;
preview_temp_scenario(); preview_temp_scenario();
}, true, true, AVATAR_PX); }, true, AVATAR_PX);
}) })
.catch(error => { .catch(error => {
console.error(error); console.error(error);
@ -7907,7 +7908,7 @@ Current version indicated by LITEVER below.
temp_scenario.image = compressedImageURI; temp_scenario.image = compressedImageURI;
temp_scenario.image_aspect = aspectratio; temp_scenario.image_aspect = aspectratio;
preview_temp_scenario(); preview_temp_scenario();
}, true, true, AVATAR_PX); }, true, AVATAR_PX);
} }
}else{ }else{
throw new Error("Selected scenario is invalid."); throw new Error("Selected scenario is invalid.");
@ -7986,7 +7987,7 @@ Current version indicated by LITEVER below.
temp_scenario.image = compressedImageURI; temp_scenario.image = compressedImageURI;
temp_scenario.image_aspect = aspectratio; temp_scenario.image_aspect = aspectratio;
preview_temp_scenario(); preview_temp_scenario();
}, true, true, AVATAR_PX); }, true, AVATAR_PX);
}) })
.catch(error => { .catch(error => {
throw new Error("Selected scenario is invalid."); throw new Error("Selected scenario is invalid.");
@ -10045,6 +10046,7 @@ Current version indicated by LITEVER below.
} }
} }
var adminrebootflag = 0;
function trigger_admin_reload() function trigger_admin_reload()
{ {
document.getElementById("admincontainer").classList.add("hidden"); document.getElementById("admincontainer").classList.add("hidden");
@ -10070,9 +10072,32 @@ Current version indicated by LITEVER below.
.then(values => { .then(values => {
let success = (values && values.success); let success = (values && values.success);
if (success) { if (success) {
msgbox("KoboldCpp is now restarting!\n\nIt may take some time before the new instance is ready to use. Please wait a moment, then press OK to refresh the page.", "KoboldCpp Reload Started", false,false,()=>{ msgbox("KoboldCpp is now restarting!\n\nIt may take some time before the new instance is ready to use.\n\nYour browser should automatically refresh after a few moments...", "KoboldCpp Reload Started", false,true);
setInterval(function () {
++adminrebootflag;
if(adminrebootflag>1 && adminrebootflag<20)
{
fetch(apply_proxy_url(custom_kobold_endpoint + kobold_custom_version_endpoint),
{
method: 'GET',
headers: get_kobold_header(),
})
.then(response => response.json())
.then((data) => {
adminrebootflag = 999;
location.reload(true); location.reload(true);
})
.catch((error) => {
if(adminrebootflag>2)
{
adminrebootflag -= 1;
}
console.error('Not Ready to Restart:', error);
}); });
}
}, 3000);
} else { } else {
msgbox("The request to reload KoboldCpp with a new configuration failed!\n\nPlease check if the feature is enabled, the admin directory is set, and selected config and password are correct.", "KoboldCpp Reload Failed"); msgbox("The request to reload KoboldCpp with a new configuration failed!\n\nPlease check if the feature is enabled, the admin directory is set, and selected config and password are correct.", "KoboldCpp Reload Failed");
} }
@ -12764,7 +12789,7 @@ Current version indicated by LITEVER below.
image_db[imgid] = { done: false, queue: "Generating", result: "", prompt:"", poll_category:0 }; image_db[imgid] = { done: false, queue: "Generating", result: "", prompt:"", poll_category:0 };
image_db[imgid].aspect = 0; image_db[imgid].aspect = 0;
image_db[imgid].imsource = 1; //0=generated,1=uploaded image_db[imgid].imsource = 1; //0=generated,1=uploaded
let imgres = localsettings.img_allowhd?HD_RES_PX:NO_HD_RES_PX; let imgres = localsettings.img_allowhd?VHD_RES_PX:NO_HD_RES_PX;
compressImage(origImg, (newDataUri, outAspect) => { compressImage(origImg, (newDataUri, outAspect) => {
image_db[imgid].done = true; image_db[imgid].done = true;
image_db[imgid].result = newDataUri; image_db[imgid].result = newDataUri;
@ -12784,7 +12809,7 @@ Current version indicated by LITEVER below.
{ {
image_db[imgid].aspect = 2; //landscape image_db[imgid].aspect = 2; //landscape
} }
}, true, false, imgres,0.35,true); }, false, imgres,0.35,true);
} }
function clear_paste_window() function clear_paste_window()
@ -15375,7 +15400,7 @@ Current version indicated by LITEVER below.
compressImage(origImg, (newDataUri) => { compressImage(origImg, (newDataUri) => {
image_db[imgid].done = true; image_db[imgid].done = true;
image_db[imgid].result = newDataUri; image_db[imgid].result = newDataUri;
}, true, false, imgres,0.35,false); }, false, imgres);
}else{ }else{
image_db[imgid].queue = "Failed"; image_db[imgid].queue = "Failed";
msgbox("Image Generation Failed!\n\nPlease make sure KoboldCpp / Forge / A1111 is running and properly configured!\nIn your local install of Automatic1111 WebUi, modify webui-user.bat and add these flags to enable API access:\n\nset COMMANDLINE_ARGS= --api --listen --cors-allow-origins=*\n"); msgbox("Image Generation Failed!\n\nPlease make sure KoboldCpp / Forge / A1111 is running and properly configured!\nIn your local install of Automatic1111 WebUi, modify webui-user.bat and add these flags to enable API access:\n\nset COMMANDLINE_ARGS= --api --listen --cors-allow-origins=*\n");
@ -15415,7 +15440,7 @@ Current version indicated by LITEVER below.
compressImage(origImg, (newDataUri) => { compressImage(origImg, (newDataUri) => {
image_db[imgid].done = true; image_db[imgid].done = true;
image_db[imgid].result = newDataUri; image_db[imgid].result = newDataUri;
}, true, true, imgres,0.35,false); }, true, imgres);
}else{ }else{
image_db[imgid].queue = "Failed"; image_db[imgid].queue = "Failed";
msgbox("Image Generation Failed!\n\nPlease make sure your OpenAI key is set correctly and you are allowed to use DALL-E.\n"); msgbox("Image Generation Failed!\n\nPlease make sure your OpenAI key is set correctly and you are allowed to use DALL-E.\n");
@ -16131,7 +16156,7 @@ Current version indicated by LITEVER below.
let imgres = localsettings.img_allowhd?(localsettings.img_aspect==0?NO_HD_RES_PX:HD_RES_PX):NO_HD_RES_PX; let imgres = localsettings.img_allowhd?(localsettings.img_aspect==0?NO_HD_RES_PX:HD_RES_PX):NO_HD_RES_PX;
compressImage(origImg, (newDataUri) => { compressImage(origImg, (newDataUri) => {
img.result = newDataUri; img.result = newDataUri;
}, true, false, imgres,0.35,false); }, false, imgres);
} }
}) })
.catch((error) => { .catch((error) => {
@ -16181,7 +16206,7 @@ Current version indicated by LITEVER below.
let imgres = localsettings.img_allowhd?(localsettings.img_aspect==0?NO_HD_RES_PX:HD_RES_PX):NO_HD_RES_PX; let imgres = localsettings.img_allowhd?(localsettings.img_aspect==0?NO_HD_RES_PX:HD_RES_PX):NO_HD_RES_PX;
compressImage(origImg, (newDataUri) => { compressImage(origImg, (newDataUri) => {
img.result = newDataUri; img.result = newDataUri;
}, true, false, imgres,0.35,false); }, false, imgres);
}; };
reader.readAsDataURL(finalimg); reader.readAsDataURL(finalimg);
}) })
@ -16235,10 +16260,11 @@ Current version indicated by LITEVER below.
} }
} }
function compressImage(inputDataUri, onDone, isJpeg=true, fixedSize=true, maxSize=NO_HD_RES_PX, quality = 0.35, forceAspect=false) { function compressImage(inputDataUri, onDone, fixedSize=true, maxSize=NO_HD_RES_PX, quality = 0.35, letterboxAspect=false) {
let img = document.createElement('img'); let img = document.createElement('img');
let wantedWidth = maxSize; let wantedWidth = maxSize;
let wantedHeight = maxSize; let wantedHeight = maxSize;
const isJpeg = true;
// When the event "onload" is triggered we can resize the image. // When the event "onload" is triggered we can resize the image.
img.onload = function () { img.onload = function () {
@ -16269,44 +16295,54 @@ Current version indicated by LITEVER below.
canvas.height = wantedHeight; canvas.height = wantedHeight;
// We resize the image with the canvas method // We resize the image with the canvas method
if(forceAspect) if(letterboxAspect)
{ {
let minsizeW = Math.min(origW, origH); let minsizeW = Math.min(origW, origH);
let minsizeH = Math.min(origW, origH); let minsizeH = Math.min(origW, origH);
let targetMaxSize = maxSize;
//a bit of a hack, but if the input image is much smaller than the target canvas, we can use a smaller canvas
if(targetMaxSize>=VHD_RES_PX && origW<=HD_RES_PX && origH<=HD_RES_PX)
{
targetMaxSize = HD_RES_PX;
}
if(targetMaxSize>=HD_RES_PX && origW<=NO_HD_RES_PX && origH<=NO_HD_RES_PX)
{
targetMaxSize = NO_HD_RES_PX;
}
if(aspectratio<=0.5) if(aspectratio<=0.5)
{ {
//portrait //portrait
minsizeH *= 2; minsizeH *= 2;
canvas.width = wantedWidth = maxSize/2; canvas.width = wantedWidth = targetMaxSize/2;
canvas.height = wantedHeight = maxSize; canvas.height = wantedHeight = targetMaxSize;
} }
else if(aspectratio<0.7) else if(aspectratio<0.7)
{ {
//portrait //portrait
minsizeH *= 1.5; minsizeH *= 1.5;
canvas.width = wantedWidth = maxSize/1.5; canvas.width = wantedWidth = targetMaxSize/1.5;
canvas.height = wantedHeight = maxSize; canvas.height = wantedHeight = targetMaxSize;
} }
else if(aspectratio>=2) else if(aspectratio>=2)
{ {
//landscape //landscape
minsizeW *= 2; minsizeW *= 2;
canvas.width = wantedWidth = maxSize; canvas.width = wantedWidth = targetMaxSize;
canvas.height = wantedHeight = maxSize/2; canvas.height = wantedHeight = targetMaxSize/2;
} }
else if(aspectratio>1.4) else if(aspectratio>1.4)
{ {
//landscape //landscape
minsizeW *= 1.5; minsizeW *= 1.5;
canvas.width = wantedWidth = maxSize; canvas.width = wantedWidth = targetMaxSize;
canvas.height = wantedHeight = maxSize/1.5; canvas.height = wantedHeight = targetMaxSize/1.5;
} }
else else
{ {
//square //square
canvas.width = wantedWidth = maxSize; canvas.width = wantedWidth = targetMaxSize;
canvas.height = wantedHeight = maxSize; canvas.height = wantedHeight = targetMaxSize;
} }
let newWidth, newHeight, mx, my; let newWidth, newHeight, mx, my;
@ -19639,7 +19675,7 @@ Current version indicated by LITEVER below.
const file = event.target.files[0]; const file = event.target.files[0];
const reader = new FileReader(); const reader = new FileReader();
reader.onload = function(img) { reader.onload = function(img) {
compressImage(img.target.result, loadCompressedImage, true, true, AVATAR_PX); compressImage(img.target.result, loadCompressedImage, true, AVATAR_PX);
function loadCompressedImage(compressedImageURI, aspectratio) { function loadCompressedImage(compressedImageURI, aspectratio) {
if(isSelfPortrait) if(isSelfPortrait)
@ -22412,6 +22448,8 @@ Current version indicated by LITEVER below.
<option value="command-r-plus">command-r-plus</option> <option value="command-r-plus">command-r-plus</option>
<option value="command-r-08-2024">command-r-08-2024</option> <option value="command-r-08-2024">command-r-08-2024</option>
<option value="command-r-plus-08-2024">command-r-plus-08-2024</option> <option value="command-r-plus-08-2024">command-r-plus-08-2024</option>
<option value="command-r7b-12-2024">command-r7b-12-2024</option>
<option value="command-a-03-2025">command-a-03-2025</option>
</select> </select>
<span class="color_green" style="font-weight: bold;">Please input Cohere API Key.</span><br><br> <span class="color_green" style="font-weight: bold;">Please input Cohere API Key.</span><br><br>
<input class="form-control" type="password" id="custom_cohere_key" placeholder="Cohere API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br> <input class="form-control" type="password" id="custom_cohere_key" placeholder="Cohere API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br>

View file

@ -49,7 +49,7 @@ logit_bias_max = 512
dry_seq_break_max = 128 dry_seq_break_max = 128
# global vars # global vars
KcppVersion = "1.86.2" KcppVersion = "1.87"
showdebug = True showdebug = True
kcpp_instance = None #global running instance kcpp_instance = None #global running instance
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False} global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False}

View file

@ -59,6 +59,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_CHAMELEON, "chameleon" },
@ -123,6 +125,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
@ -1238,6 +1244,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_RWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
},
},
{
LLM_ARCH_ARWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
{ {
@ -1397,6 +1471,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -1415,6 +1495,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@ -1422,6 +1505,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View file

@ -63,6 +63,8 @@ enum llm_arch {
LLM_ARCH_EXAONE, LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6, LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON, LLM_ARCH_CHAMELEON,
@ -127,6 +129,10 @@ enum llm_kv {
LLM_KV_ATTENTION_CAUSAL, LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ATTENTION_Q_LORA_RANK, LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_DECAY_LORA_RANK,
LLM_KV_ATTENTION_ICLR_LORA_RANK,
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
LLM_KV_ATTENTION_GATE_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_SCALE,
@ -250,8 +256,20 @@ enum llm_tensor {
LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2, LLM_TENSOR_TIME_MIX_W2,
LLM_TENSOR_TIME_MIX_A0,
LLM_TENSOR_TIME_MIX_A1,
LLM_TENSOR_TIME_MIX_A2,
LLM_TENSOR_TIME_MIX_V0,
LLM_TENSOR_TIME_MIX_V1,
LLM_TENSOR_TIME_MIX_V2,
LLM_TENSOR_TIME_MIX_G1,
LLM_TENSOR_TIME_MIX_G2,
LLM_TENSOR_TIME_MIX_K_K,
LLM_TENSOR_TIME_MIX_K_A,
LLM_TENSOR_TIME_MIX_R_K,
LLM_TENSOR_TIME_MIX_LERP_X, LLM_TENSOR_TIME_MIX_LERP_X,
LLM_TENSOR_TIME_MIX_LERP_W, LLM_TENSOR_TIME_MIX_LERP_W,
LLM_TENSOR_TIME_MIX_LERP_K, LLM_TENSOR_TIME_MIX_LERP_K,

View file

@ -285,11 +285,15 @@ llama_context::llama_context(
// reserve worst-case graph // reserve worst-case graph
if (!hparams.vocab_only) { if (!hparams.vocab_only) {
uint32_t n_seqs = 1; // TODO: worst-case number of sequences const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
// restore later
// TODO: something cleaner
const auto n_outputs_save = n_outputs;
// max number of outputs // max number of outputs
n_outputs = n_tokens; n_outputs = n_tokens;
@ -341,6 +345,8 @@ llama_context::llama_context(
} }
} }
n_outputs = n_outputs_save;
for (size_t i = 0; i < backend_ptrs.size(); ++i) { for (size_t i = 0; i < backend_ptrs.size(); ++i) {
ggml_backend_t backend = backend_ptrs[i]; ggml_backend_t backend = backend_ptrs[i];
ggml_backend_buffer_type_t buft = backend_buft[i]; ggml_backend_buffer_type_t buft = backend_buft[i];
@ -1052,6 +1058,13 @@ int llama_context::encode(llama_batch & inp_batch) {
ggml_backend_sched_reset(sched.get()); ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
const auto causal_attn_org = cparams.causal_attn;
// always use non-causal attention for encoder graphs
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
cparams.causal_attn = false;
auto * gf = graph_init(); auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
@ -1059,6 +1072,8 @@ int llama_context::encode(llama_batch & inp_batch) {
res->set_inputs(&ubatch); res->set_inputs(&ubatch);
cparams.causal_attn = causal_attn_org;
const auto compute_status = graph_compute(gf, n_tokens > 1); const auto compute_status = graph_compute(gf, n_tokens > 1);
switch (compute_status) { switch (compute_status) {
case GGML_STATUS_SUCCESS: case GGML_STATUS_SUCCESS:
@ -1129,6 +1144,8 @@ int llama_context::encode(llama_batch & inp_batch) {
if (model.arch == LLM_ARCH_T5 && t_embd) { if (model.arch == LLM_ARCH_T5 && t_embd) {
//cross.t_embd = t_embd; //cross.t_embd = t_embd;
synchronize();
cross.n_embd = t_embd->ne[0]; cross.n_embd = t_embd->ne[0];
cross.n_enc = t_embd->ne[1]; cross.n_enc = t_embd->ne[1];
cross.v_embd.resize(cross.n_embd*cross.n_enc); cross.v_embd.resize(cross.n_embd*cross.n_enc);

View file

@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
// note: storing RoPE-ed version of K in the KV cache // note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
ggml_tensor * v_cache_view = nullptr; ggml_tensor * v_cache_view = nullptr;

View file

@ -487,9 +487,9 @@ struct llm_graph_context {
ggml_tensor * build_attn_mha( ggml_tensor * build_attn_mha(
ggml_cgraph * gf, ggml_cgraph * gf,
ggml_tensor * q, ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
ggml_tensor * k, ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
ggml_tensor * v, ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * kq_b, ggml_tensor * kq_b,
ggml_tensor * kq_mask, ggml_tensor * kq_mask,
bool v_trans, bool v_trans,
@ -502,9 +502,9 @@ struct llm_graph_context {
ggml_cgraph * gf, ggml_cgraph * gf,
ggml_tensor * wo, ggml_tensor * wo,
ggml_tensor * wo_b, ggml_tensor * wo_b,
ggml_tensor * q_cur, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b, ggml_tensor * kq_b,
float kq_scale, float kq_scale,
int il) const; int il) const;
@ -516,9 +516,9 @@ struct llm_graph_context {
ggml_cgraph * gf, ggml_cgraph * gf,
ggml_tensor * wo, ggml_tensor * wo,
ggml_tensor * wo_b, ggml_tensor * wo_b,
ggml_tensor * q_cur, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b, ggml_tensor * kq_b,
float kq_scale, float kq_scale,
int il) const; int il) const;
@ -530,9 +530,9 @@ struct llm_graph_context {
ggml_cgraph * gf, ggml_cgraph * gf,
ggml_tensor * wo, ggml_tensor * wo,
ggml_tensor * wo_b, ggml_tensor * wo_b,
ggml_tensor * q_cur, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b, ggml_tensor * kq_b,
float kq_scale, float kq_scale,
int il) const; int il) const;

View file

@ -76,6 +76,10 @@ struct llama_hparams {
uint32_t time_decay_extra_dim = 0; uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0; uint32_t wkv_head_size = 0;
uint32_t token_shift_count = 2; uint32_t token_shift_count = 2;
uint32_t n_lora_decay = 0;
uint32_t n_lora_iclr = 0;
uint32_t n_lora_value_res_mix = 0;
uint32_t n_lora_gate = 0;
float rope_attn_factor = 1.0f; float rope_attn_factor = 1.0f;
float rope_freq_base_train; float rope_freq_base_train;

File diff suppressed because it is too large Load diff

View file

@ -29,6 +29,7 @@ enum llm_type {
LLM_TYPE_109M, LLM_TYPE_109M,
LLM_TYPE_137M, LLM_TYPE_137M,
LLM_TYPE_160M, LLM_TYPE_160M,
LLM_TYPE_190M,
LLM_TYPE_220M, LLM_TYPE_220M,
LLM_TYPE_250M, LLM_TYPE_250M,
LLM_TYPE_270M, LLM_TYPE_270M,
@ -45,6 +46,7 @@ enum llm_type {
LLM_TYPE_1_6B, LLM_TYPE_1_6B,
LLM_TYPE_2B, LLM_TYPE_2B,
LLM_TYPE_2_8B, LLM_TYPE_2_8B,
LLM_TYPE_2_9B,
LLM_TYPE_3B, LLM_TYPE_3B,
LLM_TYPE_4B, LLM_TYPE_4B,
LLM_TYPE_6B, LLM_TYPE_6B,
@ -260,6 +262,20 @@ struct llama_layer {
struct ggml_tensor * time_mix_receptance_b = nullptr; struct ggml_tensor * time_mix_receptance_b = nullptr;
struct ggml_tensor * time_mix_gate = nullptr; struct ggml_tensor * time_mix_gate = nullptr;
// rwkv7
struct ggml_tensor * time_mix_w0 = nullptr;
struct ggml_tensor * time_mix_a0 = nullptr;
struct ggml_tensor * time_mix_a1 = nullptr;
struct ggml_tensor * time_mix_a2 = nullptr;
struct ggml_tensor * time_mix_v0 = nullptr;
struct ggml_tensor * time_mix_v1 = nullptr;
struct ggml_tensor * time_mix_v2 = nullptr;
struct ggml_tensor * time_mix_g1 = nullptr;
struct ggml_tensor * time_mix_g2 = nullptr;
struct ggml_tensor * time_mix_k_k = nullptr;
struct ggml_tensor * time_mix_k_a = nullptr;
struct ggml_tensor * time_mix_r_k = nullptr;
struct ggml_tensor * time_mix_ln = nullptr; struct ggml_tensor * time_mix_ln = nullptr;
struct ggml_tensor * time_mix_ln_b = nullptr; struct ggml_tensor * time_mix_ln_b = nullptr;
struct ggml_tensor * time_mix_output = nullptr; struct ggml_tensor * time_mix_output = nullptr;

View file

@ -759,10 +759,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// NOTE: can't use LLM_TN here because the layer number is not known // NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ssm_conv1d.weight") == std::string::npos; quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
// do not quantize RWKV's time_mix_first tensors // do not quantize RWKV's small yet 2D weights
quantize &= name.find("time_mix_first.weight") == std::string::npos; quantize &= name.find("time_mix_first.weight") == std::string::npos;
quantize &= name.find("time_mix_w0.weight") == std::string::npos;
quantize &= name.find("time_mix_w1.weight") == std::string::npos; quantize &= name.find("time_mix_w1.weight") == std::string::npos;
quantize &= name.find("time_mix_w2.weight") == std::string::npos; quantize &= name.find("time_mix_w2.weight") == std::string::npos;
quantize &= name.find("time_mix_v0.weight") == std::string::npos;
quantize &= name.find("time_mix_v1.weight") == std::string::npos;
quantize &= name.find("time_mix_v2.weight") == std::string::npos;
quantize &= name.find("time_mix_a0.weight") == std::string::npos;
quantize &= name.find("time_mix_a1.weight") == std::string::npos;
quantize &= name.find("time_mix_a2.weight") == std::string::npos;
quantize &= name.find("time_mix_g1.weight") == std::string::npos;
quantize &= name.find("time_mix_g2.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;