|
| 1 | +"""WASM SIMD target for elementwise kernels compiler.""" |
| 2 | + |
| 3 | +# pylint: disable=undefined-variable |
| 4 | +from ynnpack.kernels.elementwise.compiler import * # pylint: disable=wildcard-import |
| 5 | +from ynnpack.kernels.elementwise.rules import * # pylint: disable=wildcard-import |
| 6 | + |
| 7 | + |
class WASM(Target):
    """Elementwise kernel compiler target for WebAssembly SIMD."""

    def __init__(self, features):
        """Configures the target for the requested feature set.

        Args:
          features: Feature names to enable. "WASM_SIMD128" is the only name
            this target accepts.

        Raises:
          ValueError: If the expanded feature list contains any name other
            than "WASM_SIMD128".
        """
        Target.__init__(self)

        # Append each rule set to the pattern list, keeping the registration
        # order select -> saturating cast -> shift.
        for build_rules in (
            add_select_rules,
            add_saturating_cast_rules,
            add_shift_rules,
        ):
            self.patterns += build_rules()

        self.features = features
        self.vector_bits = 128
        self.tail_strategy = TailStrategy.VECTOR

        # Implications are transitive; SIMD128 currently implies nothing.
        implications = {"WASM_SIMD128": []}
        # Passed as an output list; presumably filled in by
        # compute_all_features (defined outside this class) -- TODO confirm.
        expanded = []
        self.compute_all_features(features, implications, expanded)

        supported = ["WASM_SIMD128"]
        for name in expanded:
            if name in supported:
                continue
            raise ValueError(f"Unknown feature: {name}")

        # The SIMD wrapper header is included regardless of the feature set.
        self.header += '#include "ynnpack/base/simd/wasm_simd128.h"\n'

        if "WASM_SIMD128" in expanded:
            self.update_for_simd128()

    def update_for_simd128(self):
        """Appends the SIMD128-specific C++ helpers to the generated header.

        The emitted code specializes `select_greater_than` for
        `simd::vec<float, 4>`: it builds a mask from `a > b` and uses it to
        bit-select between `c` and `d`.
        """
        simd128_helpers = """
namespace ynn {
namespace {
template <>
YNN_INTRINSIC simd::vec<float, 4> select_greater_than(simd::vec<float, 4> a, simd::vec<float, 4> b, simd::vec<float, 4> c, simd::vec<float, 4> d) {
  v128_t mask = wasm_f32x4_gt(a.v, b.v);
  return simd::vec<float, 4>{wasm_v128_bitselect(c.v, d.v, mask)};
}
}
} // namespace ynn
"""
        self.header += simd128_helpers

    def arch_flags(self):
        """Returns the requested features as `arch_flag::<name>` joined by `|`."""
        flags = ["arch_flag::" + name.lower() for name in self.features]
        return "|".join(flags)

    def arch_string(self):
        """Returns the requested feature names, lower-cased and joined by `_`."""
        return "_".join(name.lower() for name in self.features)
0 commit comments