diff --git a/install.sh b/install.sh index c46c506..573826e 100644 --- a/install.sh +++ b/install.sh @@ -31,13 +31,6 @@ pip install -r ktransformers/server/requirements.txt echo "Installing ktransformers" KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation -# XPU-specific fix for triton -if [[ "$DEV_BACKEND" == "xpu" ]]; then - echo "Replacing triton for XPU backend" - pip uninstall -y triton pytorch-triton-xpu || true - pip install pytorch-triton-xpu==3.3.0 --extra-index-url https://download.pytorch.org/whl/xpu -fi - if [[ "$DEV_BACKEND" == "cuda" ]]; then echo "Installing custom_flashinfer for CUDA backend" pip install third_party/custom_flashinfer/ @@ -47,5 +40,4 @@ fi # cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/ # patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython* - -echo "Installation completed successfully" +echo "Installation completed successfully" \ No newline at end of file diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index dd3a206..25afaef 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -7,4 +7,3 @@ cpufeature; sys_platform == 'win32' or sys_platform == 'Windows' protobuf tiktoken blobfile -triton>=3.2 diff --git a/setup.py b/setup.py index 0961d93..c91d9dc 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,15 @@ except ImportError: MUSA_HOME=None KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available() +# 检测 DEV_BACKEND 环境变量 +dev_backend = os.environ.get("DEV_BACKEND", "").lower() +if dev_backend == "xpu": + triton_dep = [ + "pytorch-triton-xpu==3.3.0" + ] +else: + triton_dep = ["triton>=3.2"] + with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1" class CpuInstructInfo: @@ -659,6 +668,7 @@ else: setup( name=VersionInfo.PACKAGE_NAME, version=VersionInfo().get_package_version(), + install_requires=triton_dep, cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, ext_modules=ext_modules )