Conversation
📝 Walkthrough

Adds a full OpenReg backend: native OpenReg runtime and third-party library, PyTorch integration (ATen wrappers, allocator, streams/events, guards, serialization), Python package and bindings, build/packaging, extensive tests, and CI adjustments to install and run the new backend.

Changes
Sequence Diagram(s)

sequenceDiagram
participant App as Application
participant PyTorch as PyTorch Core
participant OpenRegPy as torch_openreg (Python)
participant OpenRegC as torch_openreg._C (C ext)
participant OpenRegRT as c10::openreg runtime
participant OpenRegLib as openreg backend lib
App->>PyTorch: set device "openreg" / create tensor
PyTorch->>OpenRegPy: backend hooks -> _lazy_init()
OpenRegPy->>OpenRegC: _init()
OpenRegC->>OpenRegRT: initialize device context
OpenRegRT->>OpenRegLib: orGetDeviceCount()
OpenRegLib-->>OpenRegRT: device count
OpenRegRT-->>OpenRegC: init complete
OpenRegC-->>OpenRegPy: init ack
OpenRegPy-->>PyTorch: backend ready
App->>PyTorch: allocate tensor (nbytes)
PyTorch->>OpenRegRT: allocate(nbytes)
OpenRegRT->>OpenRegLib: orMalloc()
OpenRegLib-->>OpenRegRT: device pointer
OpenRegRT-->>PyTorch: DataPtr with deleter
PyTorch-->>App: tensor on "openreg"
App->>PyTorch: get stream / run op on stream
PyTorch->>OpenRegRT: getStreamFromPool() / dispatch op
OpenRegRT->>OpenRegLib: orStreamCreate() / enqueue task
OpenRegLib-->>OpenRegRT: stream handle / task executed
OpenRegRT->>OpenRegRT: MemoryGuard unprotect/re-protect
OpenRegRT-->>PyTorch: op complete
PyTorch-->>App: result ready
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Pull request overview
This PR integrates OpenReg, a PyTorch CPU-based accelerator simulator, to enable CUDA-like testing on CPUs. The integration adds a complete third-party device backend implementation using PyTorch's PrivateUse1 mechanism, allowing GPU-style tests to run without expensive GPU hardware.
Changes:
- Added torch_openreg as a third-party dependency providing a complete PrivateUse1 backend implementation
- Integrated OpenReg device support into DeePMD-kit's PyTorch utilities and CI/CD pipeline
- Pinned PyTorch version to 2.10.0 to ensure compatibility with the OpenReg implementation
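For readers unfamiliar with the PrivateUse1 mechanism referenced above, the following is a minimal sketch of how an out-of-tree backend typically hooks a kernel into it. The `openreg_abs` name and its placeholder body are illustrative assumptions, not OpenReg's actual code; only `TORCH_LIBRARY_IMPL` and the `PrivateUse1` dispatch key are standard PyTorch APIs.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical kernel used purely for illustration; a real backend would
// launch its own device implementation here instead of returning a blank
// output tensor.
at::Tensor openreg_abs(const at::Tensor& self) {
  return at::empty_like(self);
}

// Kernels registered under the PrivateUse1 key are dispatched whenever a
// tensor lives on the renamed device (here, "openreg").
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("abs", openreg_abs);
}
```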
Reviewed changes
Copilot reviewed 75 out of 75 changed files in this pull request and generated 22 comments.
Show a summary per file
| File | Description |
|---|---|
| source/3rdparty/torch_openreg/* | Complete OpenReg backend implementation with Python bindings, C++ runtime, operators, and device simulation |
| pyproject.toml | Pinned PyTorch version to 2.10.0 for OpenReg compatibility |
| deepmd/pt/utils/env.py | Added OpenReg device detection and initialization |
| .github/workflows/test_python.yml | Added OpenReg installation and enabled OpenReg device in CI |
| source/3rdparty/README.md | Added torch_openreg to third-party dependencies list |
CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
Actionable comments posted: 15
Note
Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.
🤖 Fix all issues with AI agents
In `@source/3rdparty/torch_openreg/csrc/aten/native/Minimal.cpp`:
- Around line 131-136: The _copy_from_and_resize function performs resize_/copy_
on openreg tensors without unprotecting memory; wrap the operations with a
MemoryGuard so protected pages are unprotected for the duration of the
operations. Specifically, in _copy_from_and_resize create a MemoryGuard
(matching usage in _copy_from, _local_scalar_dense, view) before calling
at::native::resize_ and at::native::copy_, ensuring the guard's lifetime spans
both calls and using the same dst/self tensor references as shown in the file.
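A sketch of what the guarded version could look like, assuming the `MemoryGuard` helper and the `at::native` calls already used elsewhere in `Minimal.cpp` (exact signatures are approximations, not the file's actual code):

```cpp
at::Tensor _copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
  // Keep the guard alive for both calls so the protected openreg pages stay
  // accessible for the whole resize-and-copy sequence.
  MemoryGuard guard(self, dst);
  at::native::resize_(dst, self.sizes(), std::nullopt);
  at::native::copy_(const_cast<at::Tensor&>(dst), self, /*non_blocking=*/false);
  return dst;
}
```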
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp`:
- Around line 169-188: The code calls device_allocators_[current_device_index]
without validating current_device_index; update OpenRegDeviceAllocator::allocate
to check that orGetDevice succeeded and that current_device_index is within [0,
device_allocators_.size()). If the index is out of bounds, fail fast (use
TORCH_CHECK or equivalent) or return an empty/zero DataPtr consistently; only
dereference device_allocators_ after the bounds check. Apply the same defensive
bounds validation to other methods that index device_allocators_ (e.g.,
freeMemory(), getDeviceStats()) and protect accesses to allocated_blocks_ and
mutex_ accordingly.
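One possible shape for the fail-fast check, using the member names mentioned above (a sketch only):

```cpp
at::DataPtr OpenRegDeviceAllocator::allocate(size_t nbytes) {
  int current_device_index = -1;
  OPENREG_CHECK(orGetDevice(&current_device_index));
  // Validate the index before touching the per-device allocator table.
  TORCH_CHECK(
      current_device_index >= 0 &&
          static_cast<size_t>(current_device_index) < device_allocators_.size(),
      "Invalid OpenReg device index: ",
      current_device_index);
  return device_allocators_[current_device_index]->allocate(nbytes);
}
```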
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegEvent.h`:
- Around line 10-12: The constructor for OpenRegEvent maps the bool parameter
enable_timing to an unsigned int field enable_timing_ implicitly; change this to
an explicit mapping to OpenReg flags by checking enable_timing in the
OpenRegEvent(bool enable_timing) noexcept constructor and assign enable_timing_
to the appropriate flag constant (e.g., orEventEnableTiming when true,
orEventDisableTiming when false) so the class OpenRegEvent no longer relies on
implicit bool-to-int conversion of enable_timing.
- Around line 13-16: The OpenRegEvent destructor currently calls
OPENREG_CHECK(orEventDestroy(event_)) which can throw; change it to a
non-throwing cleanup: call orEventDestroy(event_) directly, capture its return
value, and handle failures without throwing (e.g., log the error and continue or
ignore the result). Update the destructor in ~OpenRegEvent to use is_created_,
call orEventDestroy(event_), and on non-success use a non-throwing logger or
noexcept-safe error handling instead of OPENREG_CHECK so destruction remains
noexcept; reference symbols: ~OpenRegEvent, is_created_, OPENREG_CHECK,
orEventDestroy, event_.
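Minimal sketches of both `OpenRegEvent` suggestions (flag constants and member names are taken from the comments above and may not match the header exactly):

```cpp
// Explicit flag mapping instead of an implicit bool-to-int conversion.
OpenRegEvent(bool enable_timing) noexcept
    : enable_timing_(
          enable_timing ? orEventEnableTiming : orEventDisableTiming) {}

// Non-throwing destructor: report failures instead of letting OPENREG_CHECK throw.
~OpenRegEvent() {
  if (is_created_) {
    if (orEventDestroy(event_) != orSuccess) {
      TORCH_WARN("orEventDestroy failed in ~OpenRegEvent");
    }
  }
}
```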
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegException.cpp`:
- Around line 3-8: The orCheckFail function currently throws ::c10::Error using
the raw msg C-string which can be null and cause UB; update orCheckFail to guard
against a null msg (e.g., if msg == nullptr use a safe fallback like "unknown
error" or "(null)") before constructing ::c10::Error, and ensure callers such as
OPENREG_CHECK still work with the non-null fallback.
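A sketch of the guarded error path (the `c10::Error` construction is an approximation of what the file likely does):

```cpp
void orCheckFail(
    const char* func, const char* file, uint32_t line, const char* msg) {
  // Substitute a safe fallback before building the error message.
  const char* safe_msg = (msg != nullptr) ? msg : "unknown error";
  throw ::c10::Error(
      {func, file, line}, c10::str("OpenReg error: ", safe_msg));
}
```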
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegException.h`:
- Around line 9-15: The OPENREG_CHECK macro uses the non-standard GNU
'##__VA_ARGS__' token-pasting which breaks when C++ extensions are disabled;
change the macro so it always calls orCheckFail(__func__, __FILE__,
static_cast<uint32_t>(__LINE__)) without the variadic argument. Update the
OPENREG_CHECK definition (which currently evaluates EXPR into an orError_t and
checks against orSuccess using C10_UNLIKELY) to remove the trailing comma and
'##__VA_ARGS__' so callers only pass EXPR, relying on orCheckFail's default
parameter instead.
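The macro without the GNU extension could look roughly like this, relying on `orCheckFail`'s default message parameter (a sketch):

```cpp
#define OPENREG_CHECK(EXPR)                                             \
  do {                                                                  \
    const orError_t __err = (EXPR);                                     \
    if (C10_UNLIKELY(__err != orSuccess)) {                             \
      orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__)); \
    }                                                                   \
  } while (0)
```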
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegGenerator.cpp`:
- Around line 9-26: The function getDefaultOpenRegGenerator may index
default_generators with idx==-1 resolved via current_device() without ensuring
device_count()>0; validate the resolved c10::DeviceIndex idx against
device_count() before returning default_generators[idx]. Specifically, after
computing idx (and after the existing else branch), call device_count() and
TORCH_CHECK that device_count() > 0 and idx >= 0 && idx < device_count(), and
handle the failure path consistently (e.g., TORCH_CHECK with explanatory
message) so accessing default_generators is safe; update references in
getDefaultOpenRegGenerator to use this validated idx.
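A rough sketch of the validated lookup (function and variable names taken from the comment above, not the actual file contents):

```cpp
const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
  auto idx = device_index;
  if (idx == -1) {
    idx = current_device();
  }
  // Validate before indexing the per-device generator table.
  TORCH_CHECK(device_count() > 0, "No OpenReg devices are available");
  TORCH_CHECK(
      idx >= 0 && idx < device_count(), "Invalid OpenReg device index: ", idx);
  return default_generators[idx];
}
```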
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegGuard.h`:
- Around line 91-93: synchronizeDevice currently ignores the DeviceIndex
argument; update OpenRegGuard::synchronizeDevice to synchronize the specified
device_index instead of the current device by calling the device-aware API
(e.g., orDeviceSynchronize(device_index)) if available, or by temporarily
setting the device (e.g., save current device, call orSetDevice(device_index),
call orDeviceSynchronize(), then restore previous device) so the provided
DeviceIndex is honored; modify the implementation referencing synchronizeDevice
and orDeviceSynchronize (and orSetDevice if used) accordingly.
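The save/set/restore variant described above could be sketched as follows, assuming `orDeviceSynchronize` has no device-aware overload:

```cpp
void synchronizeDevice(c10::DeviceIndex device_index) const override {
  int prev = -1;
  OPENREG_CHECK(orGetDevice(&prev));
  OPENREG_CHECK(orSetDevice(static_cast<int>(device_index)));
  OPENREG_CHECK(orDeviceSynchronize());
  // Restore whichever device was current before the call.
  OPENREG_CHECK(orSetDevice(prev));
}
```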
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegSerialization.cpp`:
- Around line 17-25: The dynamic_cast result o_meta_ptr may be nullptr and is
dereferenced unconditionally; update the block that currently does
dynamic_cast<OpenRegBackendMeta*>(meta_ptr) to check that o_meta_ptr != nullptr
before reading o_meta_ptr->version_number_ and o_meta_ptr->format_number_, and
only set m["version_number"] and m["format_number"] when the cast succeeds;
reference the variables meta_ptr and o_meta_ptr and the class OpenRegBackendMeta
to locate and safeguard the accesses.
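The guarded cast would look roughly like this:

```cpp
auto* o_meta_ptr = dynamic_cast<OpenRegBackendMeta*>(meta_ptr);
if (o_meta_ptr != nullptr) {
  // Only serialize the backend metadata when the cast actually succeeds.
  m["version_number"] = o_meta_ptr->version_number_;
  m["format_number"] = o_meta_ptr->format_number_;
}
```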
In `@source/3rdparty/torch_openreg/csrc/runtime/OpenRegStream.cpp`:
- Around line 230-239: The functions getStreamFromPool, getStreamFromExternal,
and setCurrentOpenRegStream currently index per-device arrays (priority_counters
and current_streams) without validating device_index; after resolving
device_index (i.e., after the device_index == -1 -> current_device() fallback)
call check_device(device_index) to validate it before any access, and return or
throw appropriately if invalid; update getStreamFromPool (around the
device_index resolution and before using priority_counters),
getStreamFromExternal (before accessing priority_counters/current_streams), and
setCurrentOpenRegStream (before mutating current_streams) to perform this check
so array indexing is safe.
- Around line 200-224: OpenRegStream::stream() can access uninitialized globals
(streams, max_stream_priorities) when an OpenRegStream is constructed directly
(e.g., via unpack3), and its bounds check uses <= which allows OOB; ensure the
global stream state is initialized at the start of OpenRegStream::stream() (call
whatever initializer is used elsewhere, e.g.,
ensureStreamsInitialized()/initializeGlobalStreams()) before using streams or
max_stream_priorities, and change the priority bounds check from "streamType >=
0 && streamType <= max_stream_priorities" to "streamType >= 0 && streamType <
max_stream_priorities" (also add defensive checks for device_index and si to
avoid OOB against streams[device_index][streamType][si]).
In `@source/3rdparty/torch_openreg/third_party/openreg/csrc/memory.h`:
- Around line 20-94: The functions defined in this header (mmap, munmap,
mprotect, alloc, free, get_pagesize) must be marked inline to avoid
ODR/multiple-definition linker errors when the header is included in multiple
translation units; update each function declaration/definition in this header to
prepend the inline keyword (e.g., inline void* mmap(...), inline void
munmap(...), inline int mprotect(...), inline int alloc(...), inline void
free(...), inline long get_pagesize(...)) so each definition becomes an inline
function.
In `@source/3rdparty/torch_openreg/third_party/openreg/csrc/stream.cpp`:
- Around line 33-58: The worker currently pops a task and runs it outside the
lock which allows orStreamQuery to see an empty queue while a task is still
executing; add an std::atomic<int> in struct orStream (e.g., in_flight_count)
and update the worker to increment in_flight_count immediately after popping the
task (before releasing the lock) and decrement it after task() returns; update
orStreamQuery to consider both tasks.empty() and in_flight_count==0 when
deciding completion; ensure stop logic still waits for in_flight_count to reach
zero before exiting.
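A self-contained sketch of the suggested bookkeeping (the struct and member names here are placeholders, not the real `orStream` layout):

```cpp
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>

struct orStreamSketch {
  std::queue<std::function<void()>> tasks;
  std::mutex mutex;
  std::condition_variable cv;
  std::atomic<int> in_flight_count{0};
  bool stop = false;

  void worker() {
    while (true) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(mutex);
        cv.wait(lock, [&] { return stop || !tasks.empty(); });
        if (stop && tasks.empty()) {
          return;
        }
        task = std::move(tasks.front());
        tasks.pop();
        in_flight_count.fetch_add(1);  // mark busy before releasing the lock
      }
      task();
      in_flight_count.fetch_sub(1);
    }
  }

  // orStreamQuery-style check: done only when nothing is queued or running.
  bool idle() {
    std::lock_guard<std::mutex> lock(mutex);
    return tasks.empty() && in_flight_count.load() == 0;
  }
};
```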
In `@source/3rdparty/torch_openreg/third_party/openreg/include/openreg.h`:
- Around line 34-38: The struct orPointerAttributes currently uses a C++11
default member initializer on the member "type"
(orMemoryType::orMemoryTypeUnmanaged) which breaks C compatibility because the
header is exposed under extern "C"; remove the default initializer from
orPointerAttributes so all members are plain declarations (e.g., keep fields:
orMemoryType type; int device; void* pointer;), and update callers to initialize
instances explicitly (e.g., orPointerAttributes attr = {0}; or memset) to ensure
deterministic defaults.
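A sketch of the C-compatible declaration (the enum values are placeholders for whatever `openreg.h` actually defines):

```cpp
/* Plain declarations only: no default member initializers under extern "C". */
typedef enum orMemoryType {
  orMemoryTypeUnmanaged = 0,
  orMemoryTypeHost,
  orMemoryTypeDevice
} orMemoryType;

typedef struct orPointerAttributes {
  orMemoryType type;
  int device;
  void* pointer;
} orPointerAttributes;
```

Callers then zero-initialize explicitly, e.g. `orPointerAttributes attr = {0};` in C or `orPointerAttributes attr{};` in C++.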
In `@source/3rdparty/torch_openreg/torch_openreg/_utils.py`:
- Around line 9-18: The call to os.add_dll_directory(openreg_dll_path) must be
guarded by the with_load_library_flags check so it is only invoked when the
platform supports AddDllDirectory; update the block around kernel32,
with_load_library_flags, and os.add_dll_directory to call
os.add_dll_directory(openreg_dll_path) only if with_load_library_flags is True,
and otherwise skip that call so the PATH fallback is used (ensure no exceptions
are raised when AddDllDirectory is unavailable).
🟡 Minor comments (29)
source/3rdparty/torch_openreg/README.md-125-126 (1)
125-126: ⚠️ Potential issue | 🟡 Minor
Fix typos in operator/AMP references.
These misspellings can mislead readers looking for the referenced APIs/files.
📝 Suggested corrections
- - Fallback Registration for `AutogradPriavateUse1`: See `custom_abs`
+ - Fallback Registration for `AutogradPrivateUse1`: See `custom_abs`
...
-- Register specific operator conversion rules: See `autocat_mode.cpp` in `csrc/amp`.
+- Register specific operator conversion rules: See `autocast_mode.cpp` in `csrc/amp`.
Also applies to: 144-145
source/3rdparty/torch_openreg/pyproject.toml-18-20 (1)
18-20: ⚠️ Potential issue | 🟡 Minor
Constrain `torch` to a compatible minimum version.
The code uses `torch.utils.rename_privateuse1_backend()` and `torch.utils.generate_methods_for_privateuse1_backend()`, which require PyTorch 2.1.0 or later. The suggested minimum should be 2.1.0, not 2.10.0.
♻️ Corrected minimum version
 dependencies = [
-    "torch",
+    "torch>=2.1.0",
 ]
source/3rdparty/torch_openreg/tests/test_utils.py-12-19 (1)
12-19: ⚠️ Potential issue | 🟡 Minor
Use `assertEqual` for clearer device equality diagnostics.
`assertTrue(x_out.device == x_in.device)` produces weaker failure messages than `assertEqual`. When the assertion fails, unittest can provide more detailed output using `assertEqual`.
🔧 Suggested fix
- self.assertTrue(x_out.device == x_in.device)
+ self.assertEqual(x_out.device, x_in.device)
source/3rdparty/torch_openreg/torch_openreg/_utils.py-11-42 (1)
11-42: ⚠️ Potential issue | 🟡 Minor
Restore `SetErrorMode` even when DLL loading fails.
At lines 29 and 40, exceptions can be raised which will prevent the `SetErrorMode` restoration at line 42 from executing. This leaves the process error mode modified for the remainder of the process. Wrap the code from line 13 through line 40 in a `try/finally` block to ensure the error mode is always restored.
SetErrorModerestoration at line 42 from executing. This leaves the process error mode modified for the remainder of the process. Wrap the code from line 13 through line 40 in atry/finallyblock to ensure the error mode is always restored.source/3rdparty/torch_openreg/tests/test_autograd.py-26-35 (1)
26-35: ⚠️ Potential issue | 🟡 Minor
Fix intermittent /proc read failures by handling thread exit race.
Threads can exit between `psutil.Process(pid).threads()` (line 14) and reading `/proc/.../comm` (line 19), causing `FileNotFoundError`. Wrap the read in a try-except to skip exited threads and keep the test stable.
🔧 Suggested fix
 for t in all_threads:
-    with open(f"{task_path}/{t.id}/comm") as file:
-        thread_name = file.read().strip()
-        all_thread_names.add(thread_name)
+    try:
+        with open(f"{task_path}/{t.id}/comm") as file:
+            thread_name = file.read().strip()
+            all_thread_names.add(thread_name)
+    except (FileNotFoundError, PermissionError):
+        # Thread exited or proc entry not accessible; skip
+        continue
source/3rdparty/torch_openreg/csrc/aten/native/Common.h-51-72 (1)
51-72: ⚠️ Potential issue | 🟡 Minor
Consider adding an overload for direct `at::TensorList` inputs to guard against silent failures.
The `find_and_unprotect_tensors` template currently handles `TensorBase` and `c10::IValue` (which internally manages tensor lists when wrapped). However, if any code path directly passes `at::TensorList` as an argument, it will match neither type check and silently become a no-op, leaving those tensors unprotected. While current call sites only pass individual `Tensor` objects, an explicit overload would prevent this potential issue.
🔧 Suggested overload
+ void find_and_unprotect_tensors(at::TensorList tensors) {
+   for (const auto& tensor : tensors) {
+     unprotect_if_needed(tensor);
+   }
+ }
source/3rdparty/torch_openreg/setup.py-76-86 (1)
76-86: ⚠️ Potential issue | 🟡 Minor
Clean should also remove platform-specific extension artifacts (`.pyd` on Windows and `.dylib` on macOS).
The current clean removes only `.so` files, leaving stale compiled extensions behind on other platforms. Windows uses `.pyd` for Python extensions and macOS uses `.dylib`, both of which are listed in `package_data` but never cleaned.
🔧 Suggested fix
- if filename.endswith(".so"):
+ if filename.endswith((".so", ".pyd", ".dylib")):
      os.remove(os.path.join(dirpath, filename))
source/3rdparty/torch_openreg/third_party/openreg/csrc/stream.cpp-215-221 (1)
215-221: ⚠️ Potential issue | 🟡 Minor
Guard against null `priority` in `orStreamGetPriority`.
Consistent with other API entry points in this file (`orEventCreateWithFlags`, `orStreamDestroy`, `orEventElapsedTime`, etc.), add a null pointer check and return `orErrorUnknown`.
✅ Suggested fix
 orError_t orStreamGetPriority(
     [[maybe_unused]] orStream_t stream, int* priority) {
+  if (!priority) {
+    return orErrorUnknown;
+  }
   // Since OpenReg has only one priority level, the following code
   // returns 0 directly for convenience.
   *priority = 0;
   return orSuccess;
 }
deepmd/pt/utils/env.py-53-56 (1)
53-56: ⚠️ Potential issue | 🟡 Minor
Consider handling missing `torch_openreg` module gracefully.
If `DEVICE=openreg` is set but `torch_openreg` is not installed, accessing `torch.openreg.is_available()` will raise an `AttributeError` with a confusing message. Consider wrapping this in a try-except to provide a clearer error:
🛡️ Proposed fix for clearer error handling
 elif os.environ.get("DEVICE") == "openreg":
-    if not torch.openreg.is_available():
-        raise RuntimeError("OpenReg backend is not available in this build.")
-    DEVICE = torch.device("openreg")
+    try:
+        if not torch.openreg.is_available():
+            raise RuntimeError("OpenReg backend is not available in this build.")
+        DEVICE = torch.device("openreg")
+    except AttributeError:
+        raise RuntimeError("torch_openreg package is not installed. Install it to use DEVICE=openreg.")
source/3rdparty/torch_openreg/torch_openreg/openreg/meta.py-5-5 (1)
5-5: ⚠️ Potential issue | 🟡 Minor
Remove the unused `noqa` directive.
Ruff reports the `TOR901` suppression as unused.
🧹 Suggested change
-lib = torch.library.Library("openreg", "IMPL", "Meta")  # noqa: TOR901
+lib = torch.library.Library("openreg", "IMPL", "Meta")
source/3rdparty/torch_openreg/tests/test_ops.py-183-187 (1)
183-187: ⚠️ Potential issue | 🟡 Minor
`assertTrue(x.device.type, "openreg")` doesn't validate the device type.
In Python's unittest, `assertTrue(expr, msg)` treats the second argument as an optional error message, not as a comparison value. This assertion passes for any truthy value, so it would pass whether the device type is "openreg" or "cpu". Use `assertEqual` to actually check the device type value.
✅ Suggested change
- self.assertTrue(x.device.type, "openreg")
+ self.assertEqual(x.device.type, "openreg")
source/3rdparty/torch_openreg/tests/test_ops.py-88-98 (1)
88-98: ⚠️ Potential issue | 🟡 Minor
Use `assertEqual` instead of `assertTrue` with equality comparisons.
These assertions produce weaker failure messages; `assertEqual` clearly shows expected vs. actual values.
Suggested changes
- self.assertTrue(tensor_openreg.size() == torch.Size([4, 4]))
+ self.assertEqual(tensor_openreg.size(), torch.Size([4, 4]))
- self.assertTrue(storage_openreg.size() == 16)
+ self.assertEqual(storage_openreg.size(), 16)
- self.assertTrue(tensor_openreg.size() == torch.Size([2, 2, 2, 2]))
+ self.assertEqual(tensor_openreg.size(), torch.Size([2, 2, 2, 2]))
- self.assertTrue(storage_openreg.size() == 16)
+ self.assertEqual(storage_openreg.size(), 16)
source/3rdparty/torch_openreg/third_party/openreg/README.md-69-138 (1)
69-138: ⚠️ Potential issue | 🟡 Minor
Fix a few typos in API/prose to avoid confusion.
There are a couple of doc slips that could mislead readers (e.g., `orMemcpyAsyn`/`cudaMemcpyAsyn`, "steam" vs "stream", and "as follow"). Consider tightening these for clarity.
✍️ Suggested doc edits
-| `orMemcpyAsyn` | `cudaMemcpyAsyn` | Asynchronous memory copy |
+| `orMemcpyAsync` | `cudaMemcpyAsync` | Asynchronous memory copy |
-### Stream&Event Principles
+### Stream & Event Principles
-Simulating creation, release and synchronization for event and steam:
+Simulating creation, release and synchronization for event and stream:
-The command to compile example.cpp is as follow:
+The command to compile example.cpp is as follows.
source/3rdparty/torch_openreg/third_party/openreg/example/example.cpp-7-17 (1)
7-17: ⚠️ Potential issue | 🟡 Minor
Avoid MemoryGuard protecting freed pointers.
Device buffers are freed at lines 100-102 while guard objects remain in scope; their destructors execute after `orFree` returns and call `orMemoryProtect` on already-freed memory. Add a release method to nullify the pointer before freeing.
🔧 Suggested fix
 struct MemoryGuard {
   MemoryGuard(void* ptr) : ptr_(ptr) { orMemoryUnprotect(ptr_); }
   ~MemoryGuard() {
-    orMemoryProtect(ptr_);
+    if (ptr_) {
+      orMemoryProtect(ptr_);
+    }
   }
+  void release() noexcept { ptr_ = nullptr; }
 private:
   void* ptr_{};
 };
@@
+ a.release();
+ b.release();
+ c.release();
  orFree(dev_a);
  orFree(dev_b);
  orFree(dev_out);
source/3rdparty/torch_openreg/third_party/openreg/tests/event_tests.cpp-71-95 (1)
71-95: ⚠️ Potential issue | 🟡 Minor
Destroy the stream in `EventElapsedTime` to avoid resource leaks.
The test creates a stream at line 73 but never destroys it. All other tests in this file properly clean up streams they create, and this test should follow the same pattern.
Suggested fix
 EXPECT_EQ(orEventDestroy(start), orSuccess);
 EXPECT_EQ(orEventDestroy(end), orSuccess);
+EXPECT_EQ(orStreamDestroy(stream), orSuccess);
source/3rdparty/torch_openreg/tests/test_event.py-53-66 (1)
53-66: ⚠️ Potential issue | 🟡 Minor
Use `assertGreaterEqual` to allow 0ms elapsed time for back-to-back events.
Events recorded consecutively on the same stream with no workload can legitimately have 0ms elapsed time. The current `assertTrue(ms > 0)` assertion may flake on fast systems. Use `assertGreaterEqual(ms, 0)` instead.
Suggested change
- self.assertTrue(ms > 0)
+ self.assertGreaterEqual(ms, 0)
source/3rdparty/torch_openreg/csrc/runtime/OpenRegFunctions.cpp-68-77 (1)
68-77: ⚠️ Potential issue | 🟡 Minor
Inconsistent error handling in `ExchangeDevice`.
Unlike `SetDevice` (line 22) which uses `OPENREG_CHECK`, `ExchangeDevice` ignores the return values of `orGetDevice` and `orSetDevice`. This inconsistency could mask errors. Consider adding error checking for consistency with other functions in this file.
🛡️ Proposed fix to add error checking
 OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
   int current_device = -1;
-  orGetDevice(&current_device);
+  OPENREG_CHECK(orGetDevice(&current_device));
   if (device != current_device) {
-    orSetDevice(device);
+    OPENREG_CHECK(orSetDevice(device));
   }
   return current_device;
 }
source/3rdparty/torch_openreg/csrc/runtime/OpenRegFunctions.cpp-13-18 (1)
13-18: ⚠️ Potential issue | 🟡 Minor
`GetDevice` should verify return value before dereferencing.
If `orGetDevice` fails, the function still assigns `tmp_device` to `*device`. Consider checking the error before the assignment to avoid propagating an invalid value.
🛡️ Proposed fix
 orError_t GetDevice(DeviceIndex* device) {
   int tmp_device = -1;
   auto err = orGetDevice(&tmp_device);
-  *device = static_cast<DeviceIndex>(tmp_device);
+  if (err == orSuccess) {
+    *device = static_cast<DeviceIndex>(tmp_device);
+  }
   return err;
 }
source/3rdparty/torch_openreg/csrc/runtime/OpenRegHostAllocator.h-21-28 (1)
21-28: ⚠️ Potential issue | 🟡 Minor
Missing error handling for `orMallocHost` return value.
The `orMallocHost` call doesn't check its return value for errors. While you verify that `data` is non-null, the allocation function returns an `orError_t` that should be checked to distinguish between allocation failures and other errors.
🛡️ Proposed fix to check allocation errors
 at::DataPtr allocate(size_t nbytes) override {
   void* data = nullptr;
   if (nbytes > 0) {
-    orMallocHost(&data, nbytes);
-    TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host.");
+    orError_t err = orMallocHost(&data, nbytes);
+    TORCH_CHECK(err == orSuccess && data, "Failed to allocate ", nbytes, " bytes on host.");
   }
   return {data, data, &ReportAndDelete, at::Device(at::kCPU)};
 }
source/3rdparty/torch_openreg/csrc/runtime/OpenRegHooks.h-74-83 (1)
74-83: ⚠️ Potential issue | 🟡 Minor
Unreachable code after `TORCH_CHECK(false, ...)`.
Line 82 is unreachable because `TORCH_CHECK(false, ...)` on line 80 always throws. This dead code should be removed.
🧹 Proposed fix
 at::Device getDeviceFromPtr(void* data) const override {
   orPointerAttributes attr{};
   auto err = orPointerGetAttributes(&attr, data);
   if (err == orSuccess && attr.type == orMemoryTypeDevice) {
     return at::Device(at::DeviceType::PrivateUse1, static_cast<int>(attr.device));
-  } else {
-    TORCH_CHECK(false, "failed to get device from pointer");
   }
-  return at::Device(at::DeviceType::PrivateUse1, current_device());
+  TORCH_CHECK(false, "failed to get device from pointer");
 }
source/3rdparty/torch_openreg/csrc/runtime/OpenRegHooks.h-67-72 (1)
67-72: ⚠️ Potential issue | 🟡 Minor
Error return from `orPointerGetAttributes` is not checked.
Unlike `getDeviceFromPtr` which checks the error return, `isPinnedPtr` ignores the return value of `orPointerGetAttributes`. If the call fails, `attr.type` may contain uninitialized or stale data, leading to incorrect results.
🛡️ Proposed fix
 bool isPinnedPtr(const void* data) const override {
   orPointerAttributes attr{};
-  orPointerGetAttributes(&attr, data);
-
-  return attr.type == orMemoryTypeHost;
+  auto err = orPointerGetAttributes(&attr, data);
+  return err == orSuccess && attr.type == orMemoryTypeHost;
 }
source/3rdparty/torch_openreg/torch_openreg/openreg/__init__.py-32-33 (1)
32-33: ⚠️ Potential issue | 🟡 Minor
`is_available()` always returns `True`, ignoring actual device availability.
This function unconditionally returns `True`, which doesn't reflect actual device availability. The C++ `OpenRegHooksInterface::isAvailable()` correctly checks `device_count() > 0`. Consider aligning the Python implementation.
🔧 Proposed fix
 def is_available():
-    return True
+    return device_count() > 0
source/3rdparty/torch_openreg/third_party/openreg/csrc/memory.cpp-127-127 (1)
127-127: ⚠️ Potential issue | 🟡 Minor
Typo: extra space in `std ::lock_guard`.
There's an extraneous space between `std` and `::`.
✏️ Proposed fix
- std ::lock_guard<std::mutex> lock(m_mutex);
+ std::lock_guard<std::mutex> lock(m_mutex);
source/3rdparty/torch_openreg/third_party/openreg/csrc/memory.cpp-171-183 (1)
171-183: ⚠️ Potential issue | 🟡 Minor
`refcount` can underflow if `protectNoLock` is called without matching `unprotectNoLock`.
The refcount is decremented unconditionally when `refcount >= 1`, but there's no guard against calling `protectNoLock` more times than `unprotectNoLock`. This could lead to negative refcount and incorrect protection state.
🛡️ Proposed fix to add underflow check
 orError_t protectNoLock(Block* info) {
   if (info && info->type == orMemoryType::orMemoryTypeDevice) {
+    if (info->refcount <= 0) {
+      // Already protected or invalid state
+      return orSuccess;
+    }
     if (info->refcount == 1) {
       if (openreg::mprotect(info->pointer, info->size, F_PROT_NONE) != 0) {
         return orErrorUnknown;
       }
     }
     info->refcount--;
   }
   return orSuccess;
 }
source/3rdparty/torch_openreg/third_party/openreg/csrc/memory.cpp-82-83 (1)
82-83: ⚠️ Potential issue | 🟡 Minor
Zero-size `memcpy` returns error but is typically a valid no-op.
Returning `orErrorUnknown` for `count == 0` deviates from standard `memcpy` semantics where zero-size copies are valid no-ops. This could cause unexpected failures.
🔧 Proposed fix
 orError_t memcpy(
     void* dst, const void* src, size_t count, orMemcpyKind kind) {
-  if (!dst || !src || count == 0)
+  if (!dst || !src)
     return orErrorUnknown;
+  if (count == 0)
+    return orSuccess;
source/3rdparty/torch_openreg/csrc/aten/OpenRegMinimal.cpp-76-78 (1)
76-78: ⚠️ Potential issue | 🟡 Minor
Typo in function name: `densor` should be `dense`.
The function `wrapper__local_scalar_densor` appears to have a typo. It should likely be `wrapper__local_scalar_dense` to match the operation name `_local_scalar_dense` registered at line 123.
✏️ Proposed fix
-at::Scalar wrapper__local_scalar_densor(const at::Tensor& self) {
+at::Scalar wrapper__local_scalar_dense(const at::Tensor& self) {
   return at::native::openreg::_local_scalar_dense(self);
 }
And update the registration at line 123:
- m.impl("_local_scalar_dense", wrapper__local_scalar_densor);
+ m.impl("_local_scalar_dense", wrapper__local_scalar_dense);
source/3rdparty/torch_openreg/csrc/aten/OpenRegMinimal.cpp-90-98 (1)
90-98: ⚠️ Potential issue | 🟡 Minor
Typo in function name: extra "set" in `storage_offsetset_`.
The function `wrapper_set_source_Storage_storage_offsetset_` has a redundant "set" suffix. It should be `wrapper_set_source_Storage_storage_offset_` to match the operation name at line 127.
✏️ Proposed fix
-at::Tensor& wrapper_set_source_Storage_storage_offsetset_(
+at::Tensor& wrapper_set_source_Storage_storage_offset_(
     at::Tensor& result,
     at::Storage storage,
     int64_t storage_offset,
     c10::IntArrayRef size,
     c10::IntArrayRef stride) {
   return at::native::openreg::set_source_Storage_storage_offset_(
       result, storage, storage_offset, size, stride);
 }
And update the registration at lines 126-128:
 m.impl(
     "set_.source_Storage_storage_offset",
-    wrapper_set_source_Storage_storage_offsetset_);
+    wrapper_set_source_Storage_storage_offset_);
source/3rdparty/torch_openreg/tests/test_profiler.py-291-303 (1)
291-303: ⚠️ Potential issue | 🟡 Minor
Potentially uninitialized `prof` variable.
If the `autograd_profile` context manager raises an exception during `__enter__` (before the `as prof` binding completes), the `prof` variable will be uninitialized, and accessing `prof.function_events` at line 302 will raise a `NameError`.
🛡️ Proposed fix to handle the edge case
 @skipIfTorchDynamo()
 def test_profiler_exception_handling(self):
     """Test that profiler handles exceptions gracefully."""
+    prof = None
     try:
         with autograd_profile(use_device="openreg") as prof:
             x = torch.randn(10, 10, device="openreg")  # noqa: F841
             raise RuntimeError("Test exception")
     except RuntimeError as e:
         self.assertEqual(str(e), "Test exception")
     # Profiler should still be usable
-    events = prof.function_events
-    self.assertIsInstance(events, list)
+    if prof is not None:
+        events = prof.function_events
+        self.assertIsInstance(events, list)
source/3rdparty/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp-194-201 (1)
194-201: ⚠️ Potential issue | 🟡 Minor
Handle zero-length copies before calling orMemcpy.
`orMemcpy` returns an error when count is 0, causing empty-tensor copies to fail. Add an early return and null pointer validation.
🛠️ Proposed fix
 void OpenRegDeviceAllocator::copy_data(
     void* dest, const void* src, std::size_t count) const {
+  if (count == 0) {
+    return;
+  }
+  TORCH_CHECK(dest && src, "Null pointer passed to OpenReg copy_data");
   auto ret = orMemcpy(dest, src, count, orMemcpyDeviceToDevice);
   TORCH_CHECK(
       ret == orSuccess, "Failed to copy ", count, " bytes on openreg device");
 }
Note: The same issue exists in `OpenRegHostAllocator::copy_data`.
🧹 Nitpick comments (21)
source/3rdparty/torch_openreg/include/Macros.h (1)
3-7: Make Windows export/import selectable for consumers.
If this header is used outside the DLL build, `__declspec(dllexport)` should be `__declspec(dllimport)` for consumers; otherwise you can get warnings or suboptimal linkage. Consider a build flag to switch.
♻️ Suggested pattern
 #ifdef _WIN32
-#define OPENREG_EXPORT __declspec(dllexport)
+#ifdef OPENREG_BUILD
+#define OPENREG_EXPORT __declspec(dllexport)
+#else
+#define OPENREG_EXPORT __declspec(dllimport)
+#endif
 #else
 #define OPENREG_EXPORT __attribute__((visibility("default")))
 #endif
source/3rdparty/torch_openreg/setup.py (1)
30-39: Add a clear error if PyTorch isn’t installed before build.
`get_pytorch_dir()` will throw a raw `ImportError` if torch isn’t available. A more explicit message improves build UX and aligns with the expected dependency ordering.
🔧 Suggested fix
 def get_pytorch_dir():
     # Disable autoload of the accelerator
@@
     os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
-    import torch
+    try:
+        import torch
+    except ImportError as exc:
+        raise RuntimeError(
+            "PyTorch must be installed before building torch_openreg."
+        ) from exc
     return os.path.dirname(os.path.realpath(torch.__file__))
Based on learnings: Always install backend dependencies (TensorFlow/PyTorch) before building C++ components.
source/3rdparty/torch_openreg/tests/test_utils.py (1)
42-43: Fix typo in skip reason. Minor spelling issue in the skip message.
🔧 Suggested fix
-@unittest.skip("Abs kernel only supports float type when assertEuqal")
+@unittest.skip("Abs kernel only supports float type when assertEqual")
source/3rdparty/torch_openreg/CMakeLists.txt (1)
18-20: Avoid forcing `CMAKE_BUILD_TYPE` to `Release`. For multi-config generators or dev builds, this overrides user intent. Prefer a default only when unset.
🔧 Suggested adjustment
-set(CMAKE_BUILD_TYPE
-    Release
-    CACHE STRING "Build type" FORCE)
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE
+      Release
+      CACHE STRING "Build type")
+endif()
source/3rdparty/torch_openreg/tests/test_misc.py (1)
27-32: Prefer assertEqual for clearer failures.
`assertTrue(a == b)` loses detail on mismatch; `assertEqual` is more informative.
♻️ Suggested change
- self.assertTrue(
-     torch.utils.backend_registration._get_custom_mod_func("device_count")() == 2
- )
+ self.assertEqual(
+     torch.utils.backend_registration._get_custom_mod_func("device_count")(), 2
+ )
@@
- self.assertTrue(x.type() == str)
+ self.assertEqual(x.type(), str)
Also applies to: 117-119
source/3rdparty/torch_openreg/third_party/openreg/CMakeLists.txt (1)
34-37: Post-build test execution may slow development iterations.
The `add_custom_command` runs tests automatically after every build of `ortests`. This is convenient for CI but may slow down local development when making iterative changes. Consider making this behavior optional or relying solely on `ctest` invocation.
💡 Optional: Make post-build test execution configurable
+option(RUN_TESTS_POSTBUILD "Run tests after building" OFF)
+
 if(USE_TEST)
   enable_testing()
   # ...
   add_test(NAME alltests COMMAND ${LIBRARY_TEST})
+  if(RUN_TESTS_POSTBUILD)
   add_custom_command(
     TARGET ${LIBRARY_TEST}
     POST_BUILD
     COMMAND ${CMAKE_CTEST_COMMAND} -C Release --output-on-failure --verbose)
+  endif()
 endif()
source/3rdparty/torch_openreg/csrc/runtime/OpenRegHostAllocator.h (1)
14-19: Consider checking `orFreeHost` return value in `ReportAndDelete`.
Based on the relevant code snippet showing `orFreeHost` returns `orError_t`, silently ignoring the return value could mask memory management issues. At minimum, consider logging a warning on failure.
💡 Optional: Add error logging for free failures
 static void ReportAndDelete(void* ptr) {
   if (!ptr) {
     return;
   }
-  orFreeHost(ptr);
+  orError_t err = orFreeHost(ptr);
+  TORCH_WARN_IF(err != orSuccess, "orFreeHost failed with error: ", err);
 }
source/3rdparty/torch_openreg/tests/test_streams.py (1)
73-79: Use `assertIn` for clearer failure messages.
The CodeQL hint is valid: `assertIn(a, b)` provides more informative error messages than `assertTrue(a in b)` when assertions fail.
♻️ Proposed fix
 @skipIfTorchDynamo()
 def test_stream_repr(self):
     """Test stream string representation"""
     stream = torch.Stream(device="openreg:1")
-    self.assertTrue(
-        "torch.Stream device_type=openreg, device_index=1" in repr(stream)
-    )
+    self.assertIn(
+        "torch.Stream device_type=openreg, device_index=1", repr(stream)
+    )
source/3rdparty/torch_openreg/torch_openreg/openreg/random.py (2)
17-38: Consider extracting device resolution logic to reduce duplication.
Both `get_rng_state` and `set_rng_state` contain identical device parsing logic (lines 17-24 and 29-36). Extracting this to a helper function would improve maintainability.
💡 Optional: Extract device resolution helper
def _get_device_index(device="openreg"):
    """Resolve device argument to a device index."""
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("openreg", device)
    idx = device.index
    if idx is None:
        idx = current_device()
    return idx

def get_rng_state(device="openreg"):
    idx = _get_device_index(device)
    default_generator = torch_openreg._C._get_default_generator(idx)
    return default_generator.get_state()

def set_rng_state(new_state, device="openreg"):
    idx = _get_device_index(device)
    default_generator = torch_openreg._C._get_default_generator(idx)
    default_generator.set_state(new_state)
8-14: `__all__` is not sorted (style nit).
Per Ruff RUF022, applying isort-style sorting to `__all__` improves consistency. This is a minor style issue.
💡 Optional: Sort __all__
 __all__ = [
     "get_rng_state",
+    "initial_seed",
+    "manual_seed",
+    "manual_seed_all",
     "set_rng_state",
-    "manual_seed",
-    "manual_seed_all",
-    "initial_seed",
 ]
source/3rdparty/torch_openreg/third_party/openreg/csrc/memory.h (1)
35-41: Return value of `munmap` is discarded on POSIX.
On POSIX, `::munmap` returns 0 on success and -1 on error. The current implementation discards this return value, making it impossible for callers to detect failures. Consider returning an error code or at least documenting that errors are silently ignored.
♻️ Proposed fix to return error status
-void munmap(void* addr, size_t size) {
+inline int munmap(void* addr, size_t size) {
 #if defined(_WIN32)
-  VirtualFree(addr, 0, MEM_RELEASE);
+  return VirtualFree(addr, 0, MEM_RELEASE) ? 0 : -1;
 #else
-  ::munmap(addr, size);
+  return ::munmap(addr, size);
 #endif
 }
source/3rdparty/torch_openreg/tests/test_storage.py (1)
151-153: Use `assertIs` for identity comparison.
Static analysis correctly identifies that `assertTrue(a is b)` should use `assertIs(a, b)` for better error messages when the assertion fails.
♻️ Proposed fix
- self.assertTrue(
-     rebuild_func is torch._utils._rebuild_device_tensor_from_cpu_tensor
- )
+ self.assertIs(
+     rebuild_func, torch._utils._rebuild_device_tensor_from_cpu_tensor
+ )
source/3rdparty/torch_openreg/torch_openreg/openreg/__init__.py (2)
69-86: Star import makes `__all__` entries fragile and triggers static analysis errors.
The star import `from .random import *` combined with explicitly listing those names in `__all__` causes static analysis to flag undefined names (F405). Consider explicit imports for names exported in `__all__`.
♻️ Proposed fix: Use explicit imports
-from .random import *  # noqa: F403
+from .random import (
+    get_rng_state,
+    initial_seed,
+    manual_seed,
+    manual_seed_all,
+    random,
+    set_rng_state,
+)
45-47: `set_device` doesn't validate input type before comparison.
If `device` is not an integer (e.g., a string or `torch.device`), the comparison `device >= 0` may raise a `TypeError` or behave unexpectedly. Consider using `torch.accelerator._get_device_index` like the `device` context manager does.
♻️ Proposed fix
 def set_device(device) -> None:
-    if device >= 0:
-        torch_openreg._C._set_device(device)
+    idx = torch.accelerator._get_device_index(device, optional=True)
+    if idx is not None and idx >= 0:
+        torch_openreg._C._set_device(idx)
source/3rdparty/torch_openreg/tests/test_profiler.py (2)
169-173: Consider moving `import os` to the top of the file.
The `os` module import is inside the `finally` block. While this works, it's more conventional to place imports at the module level per PEP 8.
✨ Proposed fix
Add `import os` at line 5 with the other imports:
 import json
+import os
 import tempfile
Then remove the import from the `finally` block:
 finally:
-    import os
-
     if os.path.exists(trace_file):
         os.remove(trace_file)
519-521: Use `assertGreater` for clearer failure messages.
Per static analysis, `assertTrue(a > b)` provides less informative messages on failure compared to `assertGreater(a, b)`.
✨ Proposed fix
 # Should capture various operation types
 event_names = [e.name for e in events]
- self.assertTrue(len([n for n in event_names if "aten::" in n]) > 0)
+ self.assertGreater(len([n for n in event_names if "aten::" in n]), 0)
source/3rdparty/torch_openreg/tests/test_memory.py (3)
184-189: Use `assertEqual` for device type comparison.
Per static analysis, `assertTrue(a == b)` provides less informative failure messages. Use `assertEqual` for better diagnostics.
✨ Proposed fix
 # Verify storage is on correct device
- self.assertTrue(storage.device.type == "openreg")
+ self.assertEqual(storage.device.type, "openreg")
125-137: `empty_cache` test doesn't exercise OpenReg-specific cache clearing.
The test calls `torch.cuda.empty_cache()` conditionally on CUDA availability, but doesn't actually test OpenReg's cache clearing. Consider using `torch.openreg.empty_cache()` if such an API exists, or document that OpenReg's `emptyCache` is intentionally a no-op.
413-424: Test setup assumes exactly 2 devices—consider making it more flexible.
The `setUp` method uses `assertEqual` to assert exactly 2 devices, which will fail if the environment has a different count. Consider using `skipTest` for environments with fewer than 2 devices.
✨ Proposed fix
 def setUp(self):
     self.device_count = torch.openreg.device_count()
-    self.assertEqual(self.device_count, 2, "This test requires 2 OpenReg devices")
+    if self.device_count < 2:
+        self.skipTest("This test requires at least 2 OpenReg devices")
     gc.collect()
source/3rdparty/torch_openreg/csrc/aten/native/Extra.cpp (1)
26-30: `quantize_tensor_per_tensor_affine_stub` is an empty implementation.
This function is a no-op. If this is intentional for stub testing purposes (as mentioned in the comment at lines 146-152), consider adding a brief inline comment explaining this.
📝 Proposed documentation
 void quantize_tensor_per_tensor_affine_stub(
     const at::Tensor& rtensor,
     at::Tensor& qtensor,
     double scale,
-    int64_t zero_point) {}
+    int64_t zero_point) {
+  // Intentionally empty - stub for dispatch registration testing
+}
source/3rdparty/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp (1)
247-258: Consider bounds checks for device index in stats APIs.
Defensive validation avoids out-of-range access if callers pass an invalid device id.
♻️ Suggested hardening
 c10::CachingDeviceAllocator::DeviceStats OpenRegDeviceAllocator::
     getDeviceStats(c10::DeviceIndex device) {
+  TORCH_CHECK(
+      device >= 0 &&
+          device < static_cast<c10::DeviceIndex>(device_allocators_.size()),
+      "Invalid OpenReg device index: ",
+      device);
   return device_allocators_[device]->getStats();
 }
 void OpenRegDeviceAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
+  TORCH_CHECK(
+      device >= 0 &&
+          device < static_cast<c10::DeviceIndex>(device_allocators_.size()),
+      "Invalid OpenReg device index: ",
+      device);
   device_allocators_[device]->resetAccumulatedStats();
 }
 void OpenRegDeviceAllocator::resetPeakStats(c10::DeviceIndex device) {
+  TORCH_CHECK(
+      device >= 0 &&
+          device < static_cast<c10::DeviceIndex>(device_allocators_.size()),
+      "Invalid OpenReg device index: ",
+      device);
   device_allocators_[device]->resetPeakStats();
 }
@coderabbitai /resolve

@coderabbitai resolve

✅ Actions performed: Comments resolved.

✅ Actions performed: Comments resolved.
Codecov Report

❌ Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
## master #5188 +/- ##
==========================================
- Coverage 81.95% 81.94% -0.01%
==========================================
Files 714 714
Lines 73441 73445 +4
Branches 3616 3617 +1
==========================================
+ Hits 60187 60188 +1
- Misses 12091 12096 +5
+ Partials 1163 1161 -2

☔ View full report in Codecov by Sentry.
See https://pytorch.org/blog/openreg-a-self-contained-pytorch-accelerator-simulator/ for an introduction to OpenReg. With it, we can run CUDA-like PyTorch tests on CPUs (as long as customized OPs are not used), which may replace tests on expensive GPUs -- we couldn't find such a solution two years ago, but now it has finally arrived.
For example,
if a device is missing somewhere in the code, OpenReg surfaces the error:
RuntimeError: Expected a 'cpu' device type for generator but found 'openreg'

Summary by CodeRabbit
New Features
User-visible APIs
Tests
Chores