-
Notifications
You must be signed in to change notification settings - Fork 50
Expand file tree
/
Copy pathsetup.py
More file actions
103 lines (91 loc) · 3.18 KB
/
setup.py
File metadata and controls
103 lines (91 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Setup script for shapiq package with C extensions."""
from __future__ import annotations
import sys
from pathlib import Path
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
# Custom build_ext that postpones importing numpy until the build command is
# finalized, so that numpy (requested via ``setup_requires`` below) can be made
# available before its C headers are needed to compile the extensions.
# see https://stackoverflow.com/questions/19919905/how-to-bootstrap-numpy-installation-in-setup-py
class BuildExt(_build_ext):
    """``build_ext`` variant that puts numpy's C header directory on the include path."""

    def finalize_options(self) -> None:
        """Finalize the base options, then append ``numpy.get_include()`` to ``include_dirs``.

        Before numpy is imported, ``__NUMPY_SETUP__`` is set to ``False`` in the
        builtins namespace, following the bootstrap recipe in the linked answer.
        """
        super().finalize_options()

        # ``__builtins__`` may be either the builtins module or its dict depending
        # on how this file is executed (a CPython implementation detail), so the
        # flag is written through whichever interface is available.
        namespace = __builtins__
        if not isinstance(namespace, dict):
            namespace.__NUMPY_SETUP__ = False
        else:
            namespace["__NUMPY_SETUP__"] = False

        import numpy as np

        numpy_headers = np.get_include()
        self.include_dirs.append(numpy_headers)
def get_openmp_flags() -> dict[str, list[str]]:
    """Return OpenMP-enabled build settings for a :class:`setuptools.Extension`.

    The result always has the keys ``extra_compile_args``, ``extra_link_args``,
    ``include_dirs`` and ``library_dirs``, so it can be splatted directly into
    ``Extension(...)``. A fresh dict with fresh lists is built on every call,
    so callers may mutate the result without affecting other extensions.

    Platform handling:

    * Windows (MSVC): ``/openmp`` with C++17 and ``/O2``.
    * macOS: OpenMP needs libomp. It is looked up in the default Homebrew
      ``opt`` directories (``/opt/homebrew`` first, then ``/usr/local``).
      If ``HOMEBREW_PREFIX`` is set, ``$HOMEBREW_PREFIX/opt/libomp`` is tried
      last, which supports Homebrew installed under a custom prefix.
    * Everything else (Linux, ...): GCC/Clang-style ``-fopenmp``.

    Returns:
        Mapping of ``Extension`` keyword argument names to lists of strings.

    Raises:
        RuntimeError: On macOS when no libomp installation can be found.
    """
    if sys.platform == "win32":  # Windows (MSVC)
        return {
            "extra_compile_args": ["/std:c++17", "/openmp", "/O2"],
            "extra_link_args": [],
            "include_dirs": [],
            "library_dirs": [],
        }

    if sys.platform == "darwin":  # macOS
        import os  # only needed on this branch

        # Probe well-known libomp locations on disk instead of running
        # ``brew --prefix`` as a subprocess from setup.py.
        candidates = [Path("/opt/homebrew/opt/libomp"), Path("/usr/local/opt/libomp")]
        homebrew_prefix = os.environ.get("HOMEBREW_PREFIX")
        if homebrew_prefix:
            # Appended last so the default locations keep their precedence.
            candidates.append(Path(homebrew_prefix) / "opt" / "libomp")

        for brew_prefix in candidates:
            include_dir = brew_prefix / "include"
            library_dir = brew_prefix / "lib"
            if include_dir.exists() and library_dir.exists():
                return {
                    # -Xpreprocessor forwards -fopenmp to the preprocessor only;
                    # the OpenMP runtime is linked explicitly via -lomp.
                    "extra_compile_args": [
                        "-std=c++17",
                        "-Xpreprocessor",
                        "-fopenmp",
                        "-O3",
                        "-ffast-math",
                    ],
                    "extra_link_args": ["-lomp"],
                    "include_dirs": [str(include_dir)],
                    "library_dirs": [str(library_dir)],
                }
        msg = (
            "OpenMP support on macOS requires libomp. Please install it via Homebrew: "
            "brew install libomp"
        )
        raise RuntimeError(msg)

    # Linux and others
    return {
        "extra_compile_args": ["-std=c++17", "-fopenmp", "-O3", "-ffast-math"],
        "extra_link_args": ["-fopenmp"],
        "include_dirs": [],
        "library_dirs": [],
    }
# One OpenMP-enabled C++ extension per tree subpackage. get_openmp_flags() is
# called once per extension so that no two Extension objects share the same
# flag lists.
ext_modules = [
    Extension(
        module_name,
        sources=[source_path],
        language="c++",
        **get_openmp_flags(),
    )
    for module_name, source_path in (
        ("shapiq.tree.conversion.cext", "src/shapiq/tree/conversion/cext/cext.cc"),
        ("shapiq.tree.interventional.cext", "src/shapiq/tree/interventional/cext/cext.cc"),
        ("shapiq.tree.linear.cext", "src/shapiq/tree/linear/cext/cext.cc"),
    )
]
# Route the build_ext command through BuildExt so numpy's header directory is
# added to the include path when the C++ extensions are compiled.
# NOTE(review): ``setup_requires`` is deprecated in setuptools; PEP 517 builds
# normally take build dependencies from ``[build-system] requires`` in
# pyproject.toml -- confirm numpy is listed there as well.
setup(
    name="shapiq",
    cmdclass={"build_ext": BuildExt},
    setup_requires=["numpy"],
    ext_modules=ext_modules,
)