-
Notifications
You must be signed in to change notification settings - Fork 270
Description
I have a very particular problem, whose variations others might find useful as well.
I need to look up the indices of unsigned integers from a large constant array (typical sizes between 1e7-4e9) on the GPU from GPU threads: The elements of the array are known to be unique (and can be sorted beforehand). I think a static map from https://github.com/NVIDIA/cuCollections would be ideal for my use case. I want to be able to call from an arbitrary thread on the GPU the following function: Given x (which is guaranteed to be a key), what value does it correspond to?
Describe the solution you'd like
Ideally, I would like to call a function in CUDA.jl that can build a hash table for me based on arrays of keys and values, and then access the hash table from within a kernel.
Describe alternatives you've considered
I have tried to use ChatGPT and examples from cuCollections to piece together some C++ code that I can build into a library and access from Julia using ccall, but I keep running into issues. In particular, I do not understand how to pass the pointer of the static_map back to Julia --- in Julia, I would simply have returned a function like lookup(x) that yields the corresponding value.
// hash_lookup.cu
#include <cuco/static_map.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <cuda/std/limits>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <unordered_set>
#include <random>
#include <vector>
#include <cmath>
#include <type_traits>
#include <utility>
// Target table occupancy: capacity = num_keys / kLoadFactor, i.e. the hash
// table is sized at 2x the key count to keep probe sequences short.
inline constexpr double kLoadFactor = 0.5;
// ---------------- CUDA error helper ----------------
// Wraps a CUDA runtime call; on any failure prints the file, line, and the
// runtime's error string to stderr, then terminates with a nonzero status.
// The do/while(0) shell makes the macro a single statement, so it is safe
// inside unbraced if/else bodies.
#define CUDA_CHECK(call) do { \
    const cudaError_t cuda_status_ = (call); \
    if (cuda_status_ != cudaSuccess) { \
        fprintf(stderr,"CUDA error %s:%d: %s\n", \
                __FILE__, __LINE__, cudaGetErrorString(cuda_status_)); \
        std::exit(1); \
    } \
} while(0)
// ---------------- map builder (from DEVICE keys): key -> index ----------------
// Builds a cuco::static_map mapping each device-resident key to its position
// in the input array, i.e. map[d_keys[i]] == Value(i).
//
// Preconditions (not checked here):
//   * d_keys points to device memory holding num_keys UNIQUE keys.
//   * The key value 0 never appears (it is the empty-slot sentinel below).
//   * num_keys - 1 is representable in Value (indices are stored as Value).
// NOTE(review): the __device__ lambda below requires nvcc's
// --extended-lambda flag — confirm the build passes it.
template <class Key, class Value>
cuco::static_map<Key, Value>
build_map_from_device_keys(const Key* d_keys, std::size_t num_keys)
{
static_assert(std::is_unsigned_v<Value>, "Value should hold indices (e.g., uint32_t/uint64_t).");
constexpr Key kEmptyKey = Key{0}; // you guaranteed 0 never appears
// Max Value marks an empty slot's payload; safe because stored values are
// indices < num_keys, which the precondition keeps below the max.
constexpr Value kEmptyValue = cuda::std::numeric_limits<Value>::max();
// Oversize the table by 1/kLoadFactor so open addressing stays fast.
const std::size_t capacity = static_cast<std::size_t>(std::ceil(num_keys / kLoadFactor));
cuco::static_map<Key, Value> map{
capacity, cuco::empty_key{kEmptyKey}, cuco::empty_value{kEmptyValue}};
// Iterator of pairs { d_keys[i], Value(i) } — generated on the fly, so no
// extra device buffer of pairs is materialized.
auto pairs = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{0},
[k = d_keys] __device__ (std::size_t i) {
return cuco::pair<Key, Value>{k[i], static_cast<Value>(i)};
});
map.insert(pairs, pairs + num_keys);
return map; // map owns its device memory; d_keys can be freed after insert
}
// ---------------- device lookup + kernel (pass device-view by arg) ----------------
// Returns the value stored under key `k`. The caller must guarantee the key
// is present: find() on a missing key would leave us dereferencing an
// invalid slot iterator here (undefined behavior).
template <class Ref, class Key>
__device__ __forceinline__
auto map_lookup(Ref ref, Key k) -> decltype(ref.find(k)->second) {
    auto slot = ref.find(k); // assumed to hit — keys are guaranteed present
    return slot->second;
}
// One lookup per thread: out[i] = value stored for query_keys[i].
// Launch with a 1-D grid covering at least N threads; surplus threads exit
// via the bounds check. All pointers must reference device memory.
template <class Ref, class Key, class Value>
__global__ void lookup_kernel(Ref ref, const Key* __restrict__ query_keys,
                              Value* __restrict__ out, std::size_t N)
{
    // Widen blockIdx.x BEFORE multiplying: the product blockIdx.x * blockDim.x
    // is otherwise evaluated in 32-bit unsigned and silently wraps once the
    // grid spans more than 2^32 threads — the stated use case goes to ~4e9
    // keys, where that overflow corrupts every index past the wrap point.
    std::size_t i = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < N) out[i] = map_lookup(ref, query_keys[i]);
}
// // ---- detect how to get the device view across cuco versions (C++17) ----
// Detection trait: has_ref_method<M>::value is true iff M exposes a callable
// ZERO-argument ref() member. Currently unused — see the commented-out
// dispatch in get_ref below.
// NOTE(review): this probes ref() with no arguments, but get_ref actually
// calls ref(cuco::find); newer cuco maps may only accept the operator-tag
// overload, making this trait report false for them — confirm against the
// cuco version in use before re-enabling the version dispatch.
template <class M, class = void>
struct has_ref_method : std::false_type {};
template <class M>
struct has_ref_method<M, std::void_t<decltype(std::declval<M&>().ref())>> : std::true_type {};
// Obtains a non-owning, device-copyable handle ("ref") to the map for use
// inside kernels. Creating it with the cuco::find operator tag enables
// ref.find() in device code. The ref is only valid while `m` is alive and
// unmodified.
// NOTE(review): this targets the current cuco API; older releases exposed
// get_device_view() instead — dispatch on has_ref_method<M> above if support
// for those versions is ever needed.
template <class M>
auto get_ref(M& m) {
    return m.ref(cuco::find);
}
// ---------------- main test ----------------
// Self-test: builds a key -> index map from N unique random keys, looks every
// key back up on the GPU, and verifies map[key[i]] == i for all i.
// Returns 0 on success, 1 on any mismatch (CUDA errors exit(1) via CUDA_CHECK).
int main() {
    using Key = uint32_t;
    using Value = uint32_t;
    constexpr std::size_t N = 10'000;
    // 1) generate N unique, non-zero random keys (fixed seed => reproducible)
    std::mt19937_64 rng{1234567};
    // Range [1, 0xFFFFFFFE] keeps keys clear of BOTH sentinels: 0 is the
    // empty-key sentinel and 0xFFFFFFFF is numeric_limits<Value>::max(),
    // used as the empty-value sentinel in the map builder.
    std::uniform_int_distribution<uint32_t> dist(1u, 0xFFFFFFFEu); // exclude 0 sentinel
    std::unordered_set<uint32_t> used;
    used.reserve(N*2);
    std::vector<uint32_t> h_keys;
    h_keys.reserve(N);
    while (h_keys.size() < N) {
        uint32_t k = dist(rng);
        if (used.insert(k).second) h_keys.push_back(k); // keep first occurrence only
    }
    // optional: shuffle so insertion order is unrelated to draw order
    std::shuffle(h_keys.begin(), h_keys.end(), rng);
    // 2) copy keys to device
    Key* d_keys = nullptr;
    CUDA_CHECK(cudaMalloc(&d_keys, N * sizeof(Key)));
    CUDA_CHECK(cudaMemcpy(d_keys, h_keys.data(), N * sizeof(Key), cudaMemcpyHostToDevice));
    // 3) build the map on host (uses device pointers internally)
    auto map = build_map_from_device_keys<Key, Value>(d_keys, N);
    // 4) get device handle usable inside kernels
    auto ref = get_ref(map);
    // 5) device output buffer
    Value* d_out = nullptr;
    CUDA_CHECK(cudaMalloc(&d_out, N * sizeof(Value)));
    // 6) launch lookup kernel, querying with the original keys in order
    dim3 block(256);
    dim3 grid(static_cast<unsigned>((N + block.x - 1) / block.x)); // ceil-div
    lookup_kernel<<<grid, block>>>(ref, d_keys, d_out, N);
    // cudaGetLastError (not Peek) reads AND clears launch-config errors, so a
    // failure here cannot linger as a sticky error poisoning later calls.
    CUDA_CHECK(cudaGetLastError());
    // Synchronize to surface asynchronous in-kernel faults before reading back.
    CUDA_CHECK(cudaDeviceSynchronize());
    // 7) copy results back (blocking copy: results are complete afterward)
    std::vector<Value> h_out(N);
    CUDA_CHECK(cudaMemcpy(h_out.data(), d_out, N * sizeof(Value), cudaMemcpyDeviceToHost));
    // 8) verify: since we queried with the original keys in order,
    //    we expect map[key[i]] == i for all i; print at most 10 mismatches.
    std::size_t mismatches = 0;
    for (std::size_t i = 0; i < N; ++i) {
        if (h_out[i] != static_cast<Value>(i)) {
            if (mismatches < 10) {
                std::fprintf(stderr, "mismatch at i=%zu: got %u, want %u\n",
                             i, (unsigned)h_out[i], (unsigned)i);
            }
            ++mismatches;
        }
    }
    if (mismatches == 0) {
        std::printf("OK: %zu lookups matched indices.\n", N);
    } else {
        std::printf("FAIL: %zu mismatches out of %zu.\n", mismatches, N);
    }
    // 9) cleanup (the map frees its own device memory in its destructor)
    CUDA_CHECK(cudaFree(d_out));
    CUDA_CHECK(cudaFree(d_keys));
    return (mismatches == 0) ? 0 : 1;
}