Kt minimax (#1742)

[feat]: fp8 kernel and kt-cli support
This commit is contained in:
ErvinXie 2025-12-24 15:39:44 +08:00 committed by GitHub
parent e7d277d163
commit d8046e1bb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 12111 additions and 2502 deletions

View file

@ -17,6 +17,7 @@ Example:
>>> os.environ['KT_KERNEL_CPU_VARIANT'] = 'avx2'
>>> import kt_kernel # Will use AVX2 variant
"""
import os
import sys
from pathlib import Path
@ -35,82 +36,82 @@ def detect_cpu_features():
str: 'amx', 'avx512', or 'avx2'
"""
# Check environment override
variant = os.environ.get('KT_KERNEL_CPU_VARIANT', '').lower()
if variant in ['amx', 'avx512', 'avx2']:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
variant = os.environ.get("KT_KERNEL_CPU_VARIANT", "").lower()
if variant in ["amx", "avx512", "avx2"]:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Using environment override: {variant}")
return variant
# Try to read /proc/cpuinfo on Linux
try:
with open('/proc/cpuinfo', 'r') as f:
with open("/proc/cpuinfo", "r") as f:
cpuinfo = f.read().lower()
# Check for AMX support (Intel Sapphire Rapids+)
# AMX requires amx_tile, amx_int8, and amx_bf16
amx_flags = ['amx_tile', 'amx_int8', 'amx_bf16']
amx_flags = ["amx_tile", "amx_int8", "amx_bf16"]
has_amx = all(flag in cpuinfo for flag in amx_flags)
if has_amx:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AMX support via /proc/cpuinfo")
return 'amx'
return "amx"
# Check for AVX512 support
# AVX512F is the foundation for all AVX512 variants
if 'avx512f' in cpuinfo:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if "avx512f" in cpuinfo:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX512 support via /proc/cpuinfo")
return 'avx512'
return "avx512"
# Check for AVX2 support
if 'avx2' in cpuinfo:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if "avx2" in cpuinfo:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX2 support via /proc/cpuinfo")
return 'avx2'
return "avx2"
# Fallback to AVX2 (should be rare on modern CPUs)
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] No AVX2/AVX512/AMX detected, using AVX2 fallback")
return 'avx2'
return "avx2"
except FileNotFoundError:
# /proc/cpuinfo doesn't exist (not Linux or in container)
# Try cpufeature package as fallback
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] /proc/cpuinfo not found, trying cpufeature package")
try:
import cpufeature
# Check for AMX
if cpufeature.CPUFeature.get('AMX_TILE', False):
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if cpufeature.CPUFeature.get("AMX_TILE", False):
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AMX support via cpufeature")
return 'amx'
return "amx"
# Check for AVX512
if cpufeature.CPUFeature.get('AVX512F', False):
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if cpufeature.CPUFeature.get("AVX512F", False):
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX512 support via cpufeature")
return 'avx512'
return "avx512"
# Fallback to AVX2
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Using AVX2 fallback via cpufeature")
return 'avx2'
return "avx2"
except ImportError:
# cpufeature not available - ultimate fallback
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] cpufeature not available, using AVX2 fallback")
return 'avx2'
return "avx2"
except Exception as e:
# Any other error - safe fallback
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Error during CPU detection: {e}, using AVX2 fallback")
return 'avx2'
return "avx2"
def load_extension(variant):
@ -148,51 +149,53 @@ def load_extension(variant):
kt_kernel_dir = os.path.dirname(os.path.abspath(__file__))
# Try multi-variant naming first
pattern = os.path.join(kt_kernel_dir, f'_kt_kernel_ext_{variant}.*.so')
pattern = os.path.join(kt_kernel_dir, f"_kt_kernel_ext_{variant}.*.so")
so_files = glob.glob(pattern)
if not so_files:
# Try single-variant naming (fallback for builds without CPUINFER_BUILD_ALL_VARIANTS)
pattern = os.path.join(kt_kernel_dir, 'kt_kernel_ext.*.so')
pattern = os.path.join(kt_kernel_dir, "kt_kernel_ext.*.so")
so_files = glob.glob(pattern)
if so_files:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Multi-variant {variant} not found, using single-variant build")
else:
raise ImportError(f"No .so file found for variant {variant} (tried patterns: {kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)")
raise ImportError(
f"No .so file found for variant {variant} (tried patterns: {kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)"
)
so_file = so_files[0]
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Loading {variant} from: {so_file}")
# Load the module manually
# The module exports PyInit_kt_kernel_ext, so we use that as the module name
spec = importlib.util.spec_from_file_location('kt_kernel_ext', so_file)
spec = importlib.util.spec_from_file_location("kt_kernel_ext", so_file)
if spec is None or spec.loader is None:
raise ImportError(f"Failed to create spec for {so_file}")
ext = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ext)
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Successfully loaded {variant.upper()} variant")
return ext
except (ImportError, ModuleNotFoundError, FileNotFoundError) as e:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Failed to load {variant} variant: {e}")
# Automatic fallback to next best variant
if variant == 'amx':
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if variant == "amx":
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Falling back from AMX to AVX512")
return load_extension('avx512')
elif variant == 'avx512':
if os.environ.get('KT_KERNEL_DEBUG') == '1':
return load_extension("avx512")
elif variant == "avx512":
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Falling back from AVX512 to AVX2")
return load_extension('avx2')
return load_extension("avx2")
else:
# AVX2 is the last fallback - if this fails, we can't continue
raise ImportError(
@ -221,13 +224,13 @@ def initialize():
# Detect CPU features
variant = detect_cpu_features()
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Selected CPU variant: {variant}")
# Load the appropriate extension
ext = load_extension(variant)
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Extension module loaded: {ext.__name__}")
return ext, variant