Skip to content

Commit 250e89a

Browse files
vksnkxnnpack-bot
authored andcommitted
Add basic WASM SIMD128 support to elementwise kernel generator.
PiperOrigin-RevId: 891805395
1 parent 051c5c1 commit 250e89a

File tree

6 files changed

+79
-0
lines changed

6 files changed

+79
-0
lines changed

ynnpack/kernels/binary/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ ARCHS = [
2727
"x86_avx512",
2828
"x86_avx512bf16",
2929
"arm_neon",
30+
"wasm_simd128",
3031
]
3132

3233
[ynn_generate_src_hdr(

ynnpack/kernels/binary/generator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@ def main(argv: Sequence[str]) -> None:
6565
(max_fp32, (8, 1)),
6666
(min_fp32, (8, 1)),
6767
],
68+
"wasm_simd128": [
69+
(add_fp32, (8, 1)),
70+
(subtract_fp32, (8, 1)),
71+
(multiply_fp32, (8, 1)),
72+
# (multiply_int32_fp32, (8, 1)),
73+
(divide_fp32, (8, 1)),
74+
(copysign_fp32, (8, 1)),
75+
(max_fp32, (8, 1)),
76+
(min_fp32, (8, 1)),
77+
],
78+
6879
}[target]
6980

7081
generate_elementwise_kernels(output_src, output_inc, target, kernels)

ynnpack/kernels/binary/kernels.inc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,8 @@
2020
#include "ynnpack/kernels/binary/arm_neon.inc"
2121
#endif // YNN_ARCH_ARM_NEON
2222

23+
#ifdef YNN_ARCH_WASM_SIMD128
24+
#include "ynnpack/kernels/binary/wasm_simd128.inc"
25+
#endif // YNN_ARCH_ARM_NEON
26+
27+

ynnpack/kernels/elementwise/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ py_library(
1313
"compiler.py",
1414
"generator.py",
1515
"rules.py",
16+
"wasm.py",
1617
"x86.py",
1718
],
1819
tags = ["manual"],

ynnpack/kernels/elementwise/generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
# pylint: disable=undefined-variable
77
from ynnpack.kernels.elementwise.arm import * # pylint: disable=wildcard-import
8+
from ynnpack.kernels.elementwise.wasm import * # pylint: disable=wildcard-import
89
from ynnpack.kernels.elementwise.x86 import * # pylint: disable=wildcard-import
910

1011
arch_to_target = {
@@ -22,6 +23,7 @@
2223
"arm_neon": ARM(["NEON"]),
2324
"arm_neonfp16": ARM(["NEONFP16"]),
2425
"arm_neon_fma": ARM(["NEON", "FMA"]),
26+
"wasm_simd128": WASM(["WASM_SIMD128"]),
2527
}
2628

2729

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""WASM SIMD target for elementwise kernels compiler."""
2+
3+
# pylint: disable=undefined-variable
4+
from ynnpack.kernels.elementwise.compiler import * # pylint: disable=wildcard-import
5+
from ynnpack.kernels.elementwise.rules import * # pylint: disable=wildcard-import
6+
7+
8+
class WASM(Target):
9+
"""WASM SIMD target for elementwise kernels compiler."""
10+
11+
def update_for_simd128(self):
12+
"""Updates the target for WASM SIMD128 support."""
13+
self.header += """
14+
namespace ynn {
15+
namespace {
16+
template <>
17+
YNN_INTRINSIC simd::vec<float, 4> select_greater_than(simd::vec<float, 4> a, simd::vec<float, 4> b, simd::vec<float, 4> c, simd::vec<float, 4> d) {
18+
v128_t mask = wasm_f32x4_gt(a.v, b.v);
19+
return simd::vec<float, 4>{wasm_v128_bitselect(c.v, d.v, mask)};
20+
}
21+
}
22+
} // namespace ynn
23+
"""
24+
25+
def __init__(self, features):
26+
Target.__init__(self)
27+
self.patterns += add_select_rules()
28+
self.patterns += add_saturating_cast_rules()
29+
self.patterns += add_shift_rules()
30+
31+
self.features = features
32+
self.vector_bits = 128
33+
self.tail_strategy = TailStrategy.VECTOR
34+
35+
# These are transitive.
36+
implied_features = {
37+
"WASM_SIMD128": [],
38+
}
39+
all_features = []
40+
self.compute_all_features(features, implied_features, all_features)
41+
42+
known_features = ["WASM_SIMD128"]
43+
for feature in all_features:
44+
if feature not in known_features:
45+
raise ValueError(f"Unknown feature: {feature}")
46+
47+
self.header += (
48+
'#include "ynnpack/base/simd/wasm_simd128.h"\n'
49+
)
50+
51+
if "WASM_SIMD128" in all_features:
52+
self.update_for_simd128()
53+
54+
def arch_flags(self):
55+
return "|".join(["arch_flag::" + i.lower() for i in self.features])
56+
57+
def arch_string(self):
58+
features_str = "_".join([i.lower() for i in self.features])
59+
return features_str

0 commit comments

Comments
 (0)