-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdevice.h
More file actions
72 lines (59 loc) · 1.72 KB
/
device.h
File metadata and controls
72 lines (59 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// device.h
// Device type
#ifndef TINYTENSOR_DEVICE_H_
#define TINYTENSOR_DEVICE_H_
#include <tt/exception.h>
#include <tt/export.h>
#include <format>
#include <ostream>
#include <string>
namespace tinytensor {
// Supported backend types.
// The set of enumerators depends on the build configuration: `cuda` exists only
// when the TT_CUDA macro is defined, so code that names Backend::cuda must be
// guarded by the same #ifdef.
enum class TINYTENSOR_EXPORT Backend {
cpu,
#ifdef TT_CUDA
cuda,
#endif
};
// Returns the textual name of a backend: "cpu", or "cuda" when built with TT_CUDA.
// A value that matches no enumerator reaches TT_EXCEPTION with the message
// "Unknown device type." (TT_EXCEPTION presumably comes from <tt/exception.h>
// and raises an error -- confirm against that header).
TINYTENSOR_EXPORT constexpr auto to_string(Backend backend) -> std::string {
    // No default label: every enumerator is listed explicitly, so the compiler
    // can flag a newly added Backend value that is not handled here.
    const char *name = nullptr;
    switch (backend) {
    case Backend::cpu:
        name = "cpu";
        break;
#ifdef TT_CUDA
    case Backend::cuda:
        name = "cuda";
        break;
#endif
    }
    if (name != nullptr) {
        return std::string(name);
    }
    TT_EXCEPTION("Unknown device type.");
}
// Device is an enum backend + device ID (multi-device support)
struct TINYTENSOR_EXPORT Device {
#ifdef TT_CUDA
constexpr static auto CUDA(int dev_id) -> Device {
return {.backend = Backend::cuda, .id = dev_id};
}
#endif
// Operator!= provided by compiler since C++20
constexpr auto operator==(const Device &other) const -> bool {
return backend == other.backend && id == other.id;
}
Backend backend;
int id;
};
// Short names for the default (ID 0) device of each backend.
// Declared `inline constexpr` so that every translation unit including this
// header refers to the same object; a plain namespace-scope `constexpr`
// variable has internal linkage and would be duplicated per translation unit.
inline constexpr Device kCPU = Device{.backend = Backend::cpu, .id = 0};
#ifdef TT_CUDA
// Only available when TT_CUDA is defined.
inline constexpr Device kCUDA = Device{.backend = Backend::cuda, .id = 0};
#endif
// Streams a Device as "<backend name>:<id>", where the backend name comes from
// to_string(Backend). Returns the same stream to allow chaining.
TINYTENSOR_EXPORT inline auto operator<<(std::ostream &os, const Device &device) -> std::ostream & {
    const std::string backend_name = to_string(device.backend);
    os << backend_name;
    os << ":";
    os << device.id;
    return os;
}
} // namespace tinytensor
// std::format support for tinytensor::Device.
// Inherits parsing of the format specification from std::formatter<std::string>,
// so a Device accepts the same format specs as a string.
template <>
struct TINYTENSOR_EXPORT std::formatter<tinytensor::Device> : std::formatter<std::string> {
    // Renders the device as "<backend name>:<id>" and delegates the final
    // output to the inherited string formatter.
    auto format(const tinytensor::Device &device, format_context &ctx) const {
        const std::string rendered = std::format("{}:{}", to_string(device.backend), device.id);
        return std::formatter<std::string>::format(rendered, ctx);
    }
};
#endif // TINYTENSOR_DEVICE_H_