Description
I was testing rmm's UVM managed allocator with PyTorch and unsloth:

```python
import os
import torch
import rmm
from rmm.allocators.torch import rmm_torch_allocator

rmm.reinitialize(pool_allocator=True)
custom_pool = torch.cuda.MemPool(rmm_torch_allocator.allocator())
mr = rmm.mr.PrefetchResourceAdaptor(
    rmm.mr.PoolMemoryResource(
        rmm.mr.ManagedMemoryResource()
    )
)
rmm.mr.set_current_device_resource(mr)

# Unsloth stuff
from unsloth import PatchDPOTrainer
PatchDPOTrainer()
from unsloth import FastLanguageModel
from datasets import load_dataset
from unsloth import apply_chat_template
from unsloth.chat_templates import get_chat_template

model, tokenizer = FastLanguageModel.from_pretrained(
    ...
model = FastLanguageModel.get_peft_model(
    ...
trainer = SFTTrainer(
    ...
# end of unsloth stuff

with torch.cuda.use_mem_pool(custom_pool):
    trainer_stats = trainer.train()
```

This code does work, but for some reason the memory usage isn't shown correctly in either nvtop or nvidia-smi.

The process I highlighted is the script, and as you can see it shows only 158 MiB allocated, even though the GPU's memory is full (and used by it).
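For reference, the per-process figures above come from nvidia-smi's compute-apps query (`nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader`). A small helper to parse that output and show the gap against total device usage; the sample string below just mirrors the table in this report:

```python
def parse_compute_apps(csv_text: str) -> dict[int, int]:
    """Map each compute PID to its reported GPU memory usage in MiB.

    Expects lines like "5619, 158 MiB", as produced by
    `nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader`.
    """
    usage = {}
    for line in csv_text.strip().splitlines():
        pid_str, mem_str = [field.strip() for field in line.split(",")]
        usage[int(pid_str)] = int(mem_str.split()[0])  # "158 MiB" -> 158
    return usage

# Sample output mirroring the process table in this report:
sample = "5619, 158 MiB\n"
per_process = parse_compute_apps(sample)
print(per_process)                           # {5619: 158}
# Total device usage (7805 MiB) minus attributed usage is the unexplained gap:
print(7805 - sum(per_process.values()))      # 7647
```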
nvidia-smi:

```
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 590.48.01              Driver Version: 590.48.01      CUDA Version: 13.1     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4060 ...    Off |   00000000:01:00.0 Off |                  N/A |
| N/A   55C    P0             49W / 139W  |    7805MiB /  8188MiB  |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2018    G     /usr/lib/Xorg                                   4MiB |
|    0   N/A  N/A      5619    C     python                                        158MiB |
+-----------------------------------------------------------------------------------------+
```
The model should take 10 GB+ of VRAM, so there's no way this process actually holds only 158 MiB of GPU RAM.
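As a rough sanity check on the "10 GB+" figure, a back-of-the-envelope estimate for the weights alone (the ~7B parameter count and the dtypes below are assumptions for illustration, not details from my setup):

```python
# Back-of-the-envelope VRAM estimate for model weights alone.
# Assumed values for illustration; the actual model and dtype may differ.
params = 7_000_000_000                       # assumed ~7B-parameter model
bytes_per_param = 2                          # fp16/bf16 weights
weights_gib = params * bytes_per_param / 2**30
print(f"{weights_gib:.1f} GiB")              # ~13.0 GiB for fp16 weights

# Even an aggressively quantized 4-bit variant dwarfs the reported 158 MiB:
quantized_gib = params * 0.5 / 2**30
print(f"{quantized_gib:.1f} GiB")            # ~3.3 GiB
```

Either way, optimizer state and activations during training only add to this, so the 158 MiB reading cannot be the real footprint.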
However, when using my simple allocator, the usage is shown correctly in nvtop/nvidia-smi as expected:
```python
def get_cuda_malloc_allocator():
    """Get the cuda_malloc_with_pluggable allocator."""
    # print("current directory: ", os.getcwd())
    allocator_path = './alloc.so'
    new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
        allocator_path, 'custom_malloc', 'custom_free')
    return new_alloc

allocator = get_cuda_malloc_allocator()
custom_pool = torch.cuda.MemPool(allocator.allocator())
```

alloc.cpp:
```cpp
// g++ alloc.cpp -o alloc.so -I/usr/local/cuda/include -shared -fPIC -lcuda
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
#include <stdexcept>  // for std::runtime_error

void checkCudaError(cudaError_t error, const char* message) {
    if (error != cudaSuccess) {
        std::cerr << message << ": " << cudaGetErrorString(error) << std::endl;
        throw std::runtime_error("CUDA error");
    }
}

extern "C" {

void* custom_malloc(ssize_t size, int device, cudaStream_t stream) {
    void* ptr = nullptr;
    cudaError_t error = cudaMallocManaged(&ptr, size);
    // std::cout << "CUDA DEBUG cuda_malloc " << ptr << " " << size << std::endl;
    checkCudaError(error, "Failed to allocate memory");
    return ptr;
}

void custom_free(void* ptr, int device, cudaStream_t stream) {
    cudaError_t error = cudaFree(ptr);
    checkCudaError(error, "Failed to free memory");
}

}
```

Again, I can train the model just fine, and rmm's performance is much better than the simple allocator I wrote, but it seems something might be wrong here, and I don't know whether it will actually cause any trouble.
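One way to confirm the memory really is flowing through the allocator, independently of what nvidia-smi attributes to the process, is to count bytes at the allocator level, which is the idea behind rmm's `StatisticsResourceAdaptor`. Below is a plain-Python sketch of that idea against a fake backend; the `CountingAllocator` class and the stand-in backend are hypothetical, not rmm's actual API:

```python
class CountingAllocator:
    """Wrap an allocator and track current/peak bytes, similar in spirit to
    rmm.mr.StatisticsResourceAdaptor (this class itself is hypothetical)."""

    def __init__(self, backend_alloc, backend_free):
        self._alloc = backend_alloc
        self._free = backend_free
        self._sizes = {}          # ptr -> size of the live allocation
        self.current_bytes = 0
        self.peak_bytes = 0

    def malloc(self, size):
        ptr = self._alloc(size)
        self._sizes[ptr] = size
        self.current_bytes += size
        self.peak_bytes = max(self.peak_bytes, self.current_bytes)
        return ptr

    def free(self, ptr):
        self.current_bytes -= self._sizes.pop(ptr)
        self._free(ptr)

# Fake backend standing in for cudaMallocManaged/cudaFree: hands out
# dummy "pointers" so the example runs without a GPU.
_next_ptr = iter(range(1, 1_000_000))
alloc = CountingAllocator(lambda size: next(_next_ptr), lambda ptr: None)

a = alloc.malloc(10 * 2**20)   # 10 MiB
b = alloc.malloc(4 * 2**30)    # 4 GiB
alloc.free(a)
print(alloc.current_bytes // 2**20, "MiB live")   # 4096 MiB live
print(alloc.peak_bytes // 2**20, "MiB peak")      # 4106 MiB peak
```

If a counter like this shows gigabytes live while nvidia-smi reports 158 MiB for the process, the discrepancy is in how the tool attributes the memory, not in the allocator.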
I'm using rmm 26.02.00, CUDA 13.1, PyTorch 2.10.0 (git main branch, built yesterday), and unsloth's git main branch.
My environment is Arch Linux x86-64.
For some reason rmm won't work with the PyTorch 2.10.0 stable release and triggers OOM even though it's using UVM, but with main-branch PyTorch it works fine.