-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest_util.h
More file actions
45 lines (43 loc) · 1.97 KB
/
test_util.h
File metadata and controls
45 lines (43 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <tt/device.h>
#include <tt/scalar.h>
#include <tt/tensor.h>
/**
 * Invokes the templated call operator of `f`, instantiated with `T`, once per
 * backend device.
 *
 * The CPU device is always exercised; the CUDA device is exercised as well when
 * the build defines TT_CUDA.
 *
 * @tparam T Type supplied as the explicit template argument of `f`'s call operator
 * @param f Callable with a const-invocable `template <typename T> operator()`
 *          accepting a device followed by `args...`
 * @param args Extra arguments passed to every invocation after the device
 */
template <typename T, typename F, typename... Args>
void runner_single_type(const F &f, Args... args) {
    // One shared call site, so every backend is invoked in exactly the same way
    const auto invoke_on = [&](const auto &device) {
        f.template operator()<T>(device, args...);
    };

    invoke_on(tinytensor::kCPU);
#ifdef TT_CUDA
    invoke_on(tinytensor::kCUDA);
#endif
}
/**
 * Runs `f` with `bool` as the element type, delegating to runner_single_type.
 *
 * @param f Callable with a templated call operator, invoked via runner_single_type
 * @param args Extra arguments passed along by value to the invocation
 */
template <typename F, typename... Args>
void runner_boolean(F &&f, Args... args) {
    // runner_single_type binds the callable to a const reference, so handing it
    // over as a plain lvalue is equivalent to forwarding it
    runner_single_type<bool>(f, args...);
}
/**
 * Runs `f` once for each integral element type: the C++ types that
 * tinytensor::to_ctype_t maps from kU8, kI16, kI32 and kI64, in that order.
 * Each run is delegated to runner_single_type.
 *
 * @param f Callable with a templated call operator, invoked via runner_single_type
 * @param args Extra arguments copied into every invocation
 */
template <typename F, typename... Args>
void runner_integral(F &&f, Args... args) {
    // `f` is reused for every type, so it is passed as an lvalue instead of being
    // forwarded repeatedly: runner_single_type only binds it to a const reference,
    // and forwarding the same object more than once is a use-after-forward hazard.
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kU8>>(f, args...);
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kI16>>(f, args...);
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kI32>>(f, args...);
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kI64>>(f, args...);
}
/**
 * Runs `f` once for each of the C++ types that tinytensor::to_ctype_t maps from
 * kI16, kI32 and kI64, in that order. Each run is delegated to
 * runner_single_type.
 *
 * @param f Callable with a templated call operator, invoked via runner_single_type
 * @param args Extra arguments copied into every invocation
 */
template <typename F, typename... Args>
void runner_signed_integral(F &&f, Args... args) {
    // `f` is reused for every type, so it is passed as an lvalue instead of being
    // forwarded repeatedly: runner_single_type only binds it to a const reference,
    // and forwarding the same object more than once is a use-after-forward hazard.
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kI16>>(f, args...);
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kI32>>(f, args...);
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kI64>>(f, args...);
}
/**
 * Runs `f` once for each of the C++ types that tinytensor::to_ctype_t maps from
 * kF32 and kF64, in that order. Each run is delegated to runner_single_type.
 *
 * @param f Callable with a templated call operator, invoked via runner_single_type
 * @param args Extra arguments copied into every invocation
 */
template <typename F, typename... Args>
void runner_floating_point(F &&f, Args... args) {
    // `f` is reused for every type, so it is passed as an lvalue instead of being
    // forwarded repeatedly: runner_single_type only binds it to a const reference,
    // and forwarding the same object more than once is a use-after-forward hazard.
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kF32>>(f, args...);
    runner_single_type<tinytensor::to_ctype_t<tinytensor::kF64>>(f, args...);
}
/**
 * Runs `f` through runner_boolean, runner_integral and runner_floating_point,
 * in that order.
 *
 * @param f Callable with a templated call operator, handed to each runner
 * @param args Extra arguments copied into every runner call
 */
template <typename F, typename... Args>
void runner_all(F &&f, Args... args) {
    // `f` is shared by all three runners, so it is passed as an lvalue each time;
    // forwarding the same object more than once is a use-after-forward hazard.
    runner_boolean(f, args...);
    runner_integral(f, args...);
    runner_floating_point(f, args...);
}
/**
 * Runs `f` through runner_integral and runner_floating_point, in that order,
 * without calling runner_boolean.
 *
 * @param f Callable with a templated call operator, handed to each runner
 * @param args Extra arguments copied into every runner call
 */
template <typename F, typename... Args>
void runner_all_except_bool(F &&f, Args... args) {
    // `f` is shared by both runners, so it is passed as an lvalue each time;
    // forwarding the same object more than once is a use-after-forward hazard.
    runner_integral(f, args...);
    runner_floating_point(f, args...);
}