Skip to content

Accessing cuCollections from CUDA.jl --- static_map in particular #2966

@mytraya-gattu

Description

@mytraya-gattu

I have a very particular problem, whose variations others might find useful as well.
I need to look up the indices of unsigned integers from a large constant array (typical sizes between 1e7 and 4e9) on the GPU from GPU threads: The elements of the array are known to be unique (and can be sorted beforehand). I think a static map from https://github.com/NVIDIA/cuCollections would be ideal for my use case. I want to be able to call from an arbitrary thread on the GPU the following function: Given x (which is guaranteed to be a key), what value does it correspond to?

Describe the solution you'd like

Ideally, I would like to call a function in CUDA.jl that can build a hash table for me based on arrays of keys and values, and then access the hash table from within a kernel.

Describe alternatives you've considered

I have tried to use ChatGPT and examples from cuCollections to piece together some C++ code that I can build into a library and access from Julia using ccall, but I keep running into issues. In particular, I do not understand how to pass the pointer of the static_map back to Julia --- in Julia, I would simply have returned a function like lookup(x) which would give out my value.

// hash_lookup.cu
#include <cuco/static_map.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <cuda/std/limits>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <unordered_set>
#include <random>
#include <vector>
#include <cmath>
#include <type_traits>
#include <utility>

// Target load factor for the hash table: the table capacity is computed as
// num_keys / kLoadFactor, so 0.5 sizes the table at twice the key count.
inline constexpr double kLoadFactor = 0.5;

// ---------------- CUDA error helper ----------------
// Abort with a file:line diagnostic if a CUDA runtime call fails.
// Wrap every cudaXxx() call; expands to a single statement (do/while(0)).
#define CUDA_CHECK(call) do {                                     \
  const cudaError_t status_ = (call);                             \
  if (status_ == cudaSuccess) break;                              \
  fprintf(stderr,"CUDA error %s:%d: %s\n",                        \
          __FILE__, __LINE__, cudaGetErrorString(status_));       \
  std::exit(1);                                                   \
} while(0)

// ---------------- map builder (from DEVICE keys): key -> index ----------------
// Builds a cuco::static_map on the GPU that maps each key to its index in
// d_keys, i.e. map[d_keys[i]] == Value(i).
//
// d_keys   : device pointer to num_keys unique keys; the value 0 must never
//            appear (it is used as the empty-slot sentinel below).
// num_keys : number of keys at d_keys.
//
// Returns the map by value; it owns its device storage, so the caller may
// free d_keys after this function returns.
//
// NOTE(review): the __device__ lambda below requires compiling with nvcc's
// --extended-lambda flag — confirm the build enables it.
template <class Key, class Value>
cuco::static_map<Key, Value>
build_map_from_device_keys(const Key* d_keys, std::size_t num_keys)
{
  static_assert(std::is_unsigned_v<Value>, "Value should hold indices (e.g., uint32_t/uint64_t).");

  // Sentinels marking empty slots in the open-addressing table.
  constexpr Key   kEmptyKey   = Key{0};  // you guaranteed 0 never appears
  constexpr Value kEmptyValue = cuda::std::numeric_limits<Value>::max();

  // Size the table for the target load factor (0.5 => 2x num_keys slots).
  const std::size_t capacity = static_cast<std::size_t>(std::ceil(num_keys / kLoadFactor));

  cuco::static_map<Key, Value> map{
      capacity, cuco::empty_key{kEmptyKey}, cuco::empty_value{kEmptyValue}};

  // Iterator of pairs { d_keys[i], Value(i) }: generates the (key, index)
  // pairs on the fly so no extra device-side pair buffer is needed.
  auto pairs = thrust::make_transform_iterator(
      thrust::counting_iterator<std::size_t>{0},
      [k = d_keys] __device__ (std::size_t i) {
        return cuco::pair<Key, Value>{k[i], static_cast<Value>(i)};
      });

  map.insert(pairs, pairs + num_keys);
  return map;  // map owns its device memory; d_keys can be freed after insert
}

// ---------------- device lookup + kernel (pass device-view by arg) ----------------
// Device-side lookup of a key that is guaranteed to be present in the map.
// Precondition: k was inserted; dereferencing find()'s result without a
// presence check is undefined behavior for a missing key.
template <class Ref, class Key>
__device__ __forceinline__
auto map_lookup(Ref ref, Key k) -> decltype(ref.find(k)->second) {
  auto slot = ref.find(k);  // guaranteed hit under the precondition above
  return slot->second;
}

// One lookup per thread: out[i] = value stored in the map for query_keys[i].
//
// ref        : device map handle (pass by value), e.g. from get_ref().
// query_keys : device pointer to N keys, all guaranteed present in the map.
// out        : device pointer to N result slots.
// Launch with any 1D grid/block covering at least N threads; extra threads
// are guarded out.
template <class Ref, class Key, class Value>
__global__ void lookup_kernel(Ref ref, const Key* query_keys, Value* out, std::size_t N)
{
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in unsigned int and wraps past 2^32 threads — reachable at the 4e9
  // element sizes this code targets.
  std::size_t i = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < N) out[i] = map_lookup(ref, query_keys[i]);
}

// ---- detect how to get the device view across cuco versions (C++17) ----
// SFINAE trait: true when M has a zero-argument .ref() member function.
// Intended to distinguish newer cuco (ref API) from older cuco
// (get_device_view). NOTE(review): currently unused by get_ref below, and it
// probes ref() with no arguments while the call site uses ref(cuco::find) —
// confirm the no-arg overload exists before relying on this trait.
template <class M, class = void>
struct has_ref_method : std::false_type {};
template <class M>
struct has_ref_method<M, std::void_t<decltype(std::declval<M&>().ref())>> : std::true_type {};

// Returns a device-side handle for the map that can be passed by value into
// a kernel. Uses the newer cuco "ref" API with the find operator enabled.
template <class M>
auto get_ref(M& m) {
    // Older cuco releases expose get_device_view() instead of ref(); on such
    // a version replace this with: return m.get_device_view();
    // (has_ref_method above could select between the two with if constexpr.)
    return m.ref(cuco::find);
}

// ---------------- main test ----------------
// Self-test: build a map from N unique random keys, look every key back up
// from a kernel, and verify map[h_keys[i]] == i for all i.
// Returns 0 on success, 1 on any mismatch.
int main() {
  using Key   = uint32_t;
  using Value = uint32_t;

  constexpr std::size_t N = 10'000;

  // 1) generate N unique, non-zero random keys on the host.
  //    0 is excluded because build_map_from_device_keys uses it as the
  //    empty-key sentinel; fixed seed keeps the run reproducible.
  std::mt19937_64 rng{1234567};
  std::uniform_int_distribution<uint32_t> dist(1u, 0xFFFFFFFEu); // exclude 0 sentinel

  std::unordered_set<uint32_t> used;
  used.reserve(N*2);
  std::vector<uint32_t> h_keys;
  h_keys.reserve(N);
  while (h_keys.size() < N) {
    uint32_t k = dist(rng);
    if (used.insert(k).second) h_keys.push_back(k);
  }
  // optional: shuffle so stored order is independent of generation order
  std::shuffle(h_keys.begin(), h_keys.end(), rng);

  // 2) copy keys to device
  Key* d_keys = nullptr;
  CUDA_CHECK(cudaMalloc(&d_keys, N * sizeof(Key)));
  CUDA_CHECK(cudaMemcpy(d_keys, h_keys.data(), N * sizeof(Key), cudaMemcpyHostToDevice));

  // 3) build the map on host (uses device pointers internally);
  //    the map must stay alive until the lookup kernel has finished.
  auto map = build_map_from_device_keys<Key, Value>(d_keys, N);

  // 4) get a device handle that can be passed by value to the kernel
  auto ref = get_ref(map);

  // 5) device output buffer
  Value* d_out = nullptr;
  CUDA_CHECK(cudaMalloc(&d_out, N * sizeof(Value)));

  // 6) launch lookup kernel: one thread per query, grid rounded up
  dim3 block(256);
  dim3 grid(static_cast<unsigned>((N + block.x - 1) / block.x));
  lookup_kernel<<<grid, block>>>(ref, d_keys, d_out, N);
  CUDA_CHECK(cudaPeekAtLastError());     // catches launch-config errors
  CUDA_CHECK(cudaDeviceSynchronize());   // catches in-kernel faults

  // 7) copy results back (synchronize above guarantees the kernel is done)
  std::vector<Value> h_out(N);
  CUDA_CHECK(cudaMemcpy(h_out.data(), d_out, N * sizeof(Value), cudaMemcpyDeviceToHost));

  // 8) verify: since we queried with the original keys in order,
  //    we expect map[key[i]] == i for all i (first 10 mismatches printed)
  std::size_t mismatches = 0;
  for (std::size_t i = 0; i < N; ++i) {
    if (h_out[i] != static_cast<Value>(i)) {
      if (mismatches < 10) {
        std::fprintf(stderr, "mismatch at i=%zu: got %u, want %u\n",
                     i, (unsigned)h_out[i], (unsigned)i);
      }
      ++mismatches;
    }
  }

  if (mismatches == 0) {
    std::printf("OK: %zu lookups matched indices.\n", N);
  } else {
    std::printf("FAIL: %zu mismatches out of %zu.\n", mismatches, N);
  }

  // 9) cleanup (the map frees its own device storage in its destructor)
  CUDA_CHECK(cudaFree(d_out));
  CUDA_CHECK(cudaFree(d_keys));
  return (mismatches == 0) ? 0 : 1;
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancementNew feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions