114 lines
3.4 KiB
Python
114 lines
3.4 KiB
Python
"""
Copyright (c) 2023 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
|
|
|
|
import os
|
|
import platform
|
|
import re
|
|
import subprocess
|
|
from pathlib import Path
|
|
from typing import List, Mapping
|
|
|
|
import setuptools
|
|
from setuptools.dist import Distribution
|
|
|
|
# Absolute path of the directory that holds this setup script.
root = Path(__file__).parent.resolve()

# Directory checked for ahead-of-time (AOT) ops; presumably populated by a
# separate build step before this script runs -- confirm against the build docs.
aot_ops_package_dir = root.joinpath("build", "aot-ops-package-dir")

# AOT mode is on only when that directory exists and has at least one entry.
# `any()` is short-circuited by `and`, so `iterdir()` is never called on a
# missing directory.
enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir())
|
|
|
|
|
|
def write_if_different(path: Path, content: str) -> None:
    """Write ``content`` to ``path`` unless the file already holds that text.

    When the file exists and its current text equals ``content`` nothing is
    written, so the file is left untouched.  Otherwise any missing parent
    directories are created and the file is (over)written.

    Args:
        path: Destination file.
        content: Text the file should contain afterwards.
    """
    already_current = path.exists() and path.read_text() == content
    if not already_current:
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(content)
|
|
|
|
|
|
def get_version() -> str:
    """Return the package version string used for this build.

    The base version is the stripped text of ``version.txt`` located in
    ``root``.  When the ``FLASHINFER_LOCAL_VERSION`` environment variable is
    set to a non-empty value, it is appended as a local version segment.

    Returns:
        ``"<base>"`` or ``"<base>+<local>"``.

    Raises:
        OSError: If ``version.txt`` cannot be read.
    """
    package_version = (root / "version.txt").read_text().strip()
    local_version = os.environ.get("FLASHINFER_LOCAL_VERSION")
    # An exported-but-empty variable is treated like an unset one; otherwise
    # the result would end in a bare "+", which is not a valid version string.
    if not local_version:
        return package_version
    return f"{package_version}+{local_version}"
|
|
|
|
|
|
def generate_build_meta(aot_build_meta: dict) -> None:
    """Generate ``flashinfer/_build_meta.py`` under ``root``.

    The generated module always defines ``__version__`` (from
    ``get_version()``).  A ``build_meta`` assignment holding the ``repr`` of
    ``aot_build_meta`` is added only when that dict is non-empty.  The file
    is written through ``write_if_different``.

    Args:
        aot_build_meta: Build metadata to embed; pass an empty dict to emit
            the version line only.
    """
    lines = [f"__version__ = {get_version()!r}\n"]
    if aot_build_meta:
        lines.append(f"build_meta = {aot_build_meta!r}\n")
    write_if_different(root / "flashinfer" / "_build_meta.py", "".join(lines))
|
|
|
|
|
|
# Both containers start empty and are handed to `setuptools.setup()` as-is;
# nothing in the visible part of this script adds entries to them.
ext_modules: List[setuptools.Extension] = []
cmdclass: Mapping[str, type[setuptools.Command]] = {}

# Baseline runtime requirements.  The AOT branch further down may rebind this
# name to adjust the torch requirement.
install_requires = [
    "numpy",
    "torch",
    "ninja",
    "requests",
    "pynvml",
    "einops",
    "click",
    "tqdm",
    "tabulate",
    "packaging>=24.2",
    "nvidia-cudnn-frontend>=1.13.0",
]

# Always emit the version-only build metadata first; the AOT branch calls
# this again with a populated dict when AOT mode is enabled.
generate_build_meta({})
|
|
|
|
if enable_aot:
    # Imported here rather than at the top of the file so that a non-AOT run
    # of this script does not need torch or packaging to be importable.
    import torch
    import torch.utils.cpp_extension as torch_cpp_ext
    from packaging.requirements import Requirement
    from packaging.version import Version

    def get_cuda_version() -> Version:
        """Return the CUDA toolkit release reported by ``nvcc --version``.

        Uses ``<CUDA_HOME>/bin/nvcc`` when torch's ``CUDA_HOME`` is set, and
        falls back to whatever ``nvcc`` is found on ``PATH`` otherwise.

        Returns:
            The ``major.minor`` release parsed from the nvcc output.

        Raises:
            FileNotFoundError: If the nvcc executable cannot be found.
            subprocess.CalledProcessError: If nvcc exits with a non-zero status.
            RuntimeError: If the output contains no ``release X.Y,`` string.
        """
        if torch_cpp_ext.CUDA_HOME is None:
            nvcc = "nvcc"
        else:
            nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc")
        txt = subprocess.check_output([nvcc, "--version"], text=True)
        match = re.search(r"release (\d+\.\d+),", txt)
        if match is None:
            # Fail with a message that names the problem instead of the bare
            # IndexError an empty findall() result would produce.
            raise RuntimeError(
                f"Could not parse the CUDA release from `{nvcc} --version` "
                f"output:\n{txt}"
            )
        return Version(match.group(1))

    cuda_version = get_cuda_version()
    torch_full_version = Version(torch.__version__)
    torch_version = f"{torch_full_version.major}.{torch_full_version.minor}"

    # Replace any existing torch requirement with a pin to the major.minor
    # series of the torch installed at build time.  Filtering on the parsed
    # requirement name also drops the bare "torch" entry, which a
    # `startswith("torch ")` test does not match, while leaving other
    # packages whose names merely begin with "torch" untouched.
    install_requires = [
        req for req in install_requires if Requirement(req).name.lower() != "torch"
    ]
    install_requires.append(f"torch == {torch_version}.*")

    # Record the toolchain this AOT package was built against.
    aot_build_meta = {}
    aot_build_meta["cuda_major"] = cuda_version.major
    aot_build_meta["cuda_minor"] = cuda_version.minor
    aot_build_meta["torch"] = torch_version
    aot_build_meta["python"] = platform.python_version()
    # Stored as None when the environment variable is not set.
    aot_build_meta["FLASHINFER_CUDA_ARCH_LIST"] = os.environ.get(
        "FLASHINFER_CUDA_ARCH_LIST"
    )
    generate_build_meta(aot_build_meta)
|
|
|
|
|
|
class AotDistribution(Distribution):
    """Distribution whose "has extension modules" answer follows ``enable_aot``.

    NOTE(review): presumably this makes setuptools treat an AOT build as a
    platform-specific (non-pure) distribution even though ``ext_modules`` is
    empty -- confirm against the setuptools documentation.
    """

    def has_ext_modules(self) -> bool:
        """Return the module-level ``enable_aot`` flag."""
        return enable_aot
|
|
|
|
|
|
# Entry point: every value below is computed earlier in this script.
setuptools.setup(
    distclass=AotDistribution,
    version=get_version(),
    install_requires=install_requires,
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    # Passed through to the `bdist_wheel` command as its `py_limited_api` option.
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
|