Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 68 additions & 80 deletions hhds/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

namespace hhds {

Node_hier::Node_hier(Tid hier_tid_value, std::shared_ptr<std::vector<Gid>> hier_gids_value, Tree_pos hier_pos_value, Nid raw_nid_value)
Node_hier::Node_hier(Tid hier_tid_value, std::shared_ptr<std::vector<Gid>> hier_gids_value, Tree_pos hier_pos_value,
Nid raw_nid_value)
: hier_gids(std::move(hier_gids_value)), hier_tid(hier_tid_value), hier_pos(hier_pos_value), raw_nid(raw_nid_value) {}

auto Node_hier::get_root_gid() const noexcept -> Gid {
Expand Down Expand Up @@ -59,13 +60,13 @@ auto Pin_hier::get_current_gid() const noexcept -> Gid {
return (*hier_gids)[static_cast<size_t>(hier_pos)];
}

auto to_class(const Node_hier& v) -> Node_class { return Node_class(nullptr, v.get_raw_nid() & ~static_cast<Nid>(2)); }
auto to_class(const Node_hier& v) -> Node_class { return Node_class(v.get_raw_nid() & ~static_cast<Nid>(2)); }

auto to_flat(const Node_hier& v) -> Node_flat {
return Node_flat(v.get_root_gid(), v.get_current_gid(), v.get_raw_nid() & ~static_cast<Nid>(2));
}

auto to_class(const Node_flat& v) -> Node_class { return Node_class(nullptr, v.get_raw_nid() & ~static_cast<Nid>(2)); }
auto to_class(const Node_flat& v) -> Node_class { return Node_class(v.get_raw_nid() & ~static_cast<Nid>(2)); }

auto to_flat(const Node_class& v, Gid current_gid, Gid root_gid) -> Node_flat {
if (root_gid == Gid_invalid) {
Expand Down Expand Up @@ -112,9 +113,13 @@ auto to_flat(const Edge_hier& e) -> Edge_flat {

auto to_hier(const Edge_class& e, Tid hier_tid, std::shared_ptr<std::vector<Gid>> hier_gids, Tree_pos hier_pos) -> Edge_hier {
Edge_hier out;
out.driver = Pin_hier(hier_tid, hier_gids, hier_pos, e.driver_pin.get_raw_nid(), e.driver_pin.get_port_id(),
e.driver_pin.get_pin_pid());
out.sink = Pin_hier(hier_tid, std::move(hier_gids), hier_pos, e.sink_pin.get_raw_nid(), e.sink_pin.get_port_id(),
out.driver
= Pin_hier(hier_tid, hier_gids, hier_pos, e.driver_pin.get_raw_nid(), e.driver_pin.get_port_id(), e.driver_pin.get_pin_pid());
out.sink = Pin_hier(hier_tid,
std::move(hier_gids),
hier_pos,
e.sink_pin.get_raw_nid(),
e.sink_pin.get_port_id(),
e.sink_pin.get_pin_pid());
return out;
}
Expand All @@ -123,8 +128,8 @@ auto to_class(const Edge_flat& e) -> Edge_class {
Edge_class out{};
out.driver_pin = to_class(e.driver);
out.sink_pin = to_class(e.sink);
out.driver = Node_class(nullptr, out.driver_pin.get_raw_nid() | static_cast<Nid>(2));
out.sink = Node_class(nullptr, out.sink_pin.get_raw_nid() & ~static_cast<Nid>(2));
out.driver = Node_class(out.driver_pin.get_raw_nid() | static_cast<Nid>(2));
out.sink = Node_class(out.sink_pin.get_raw_nid() & ~static_cast<Nid>(2));
out.type = 2; // p -> p
return out;
}
Expand All @@ -133,8 +138,8 @@ auto to_class(const Edge_hier& e) -> Edge_class {
Edge_class out{};
out.driver_pin = to_class(e.driver);
out.sink_pin = to_class(e.sink);
out.driver = Node_class(nullptr, out.driver_pin.get_raw_nid() | static_cast<Nid>(2));
out.sink = Node_class(nullptr, out.sink_pin.get_raw_nid() & ~static_cast<Nid>(2));
out.driver = Node_class(out.driver_pin.get_raw_nid() | static_cast<Nid>(2));
out.sink = Node_class(out.sink_pin.get_raw_nid() & ~static_cast<Nid>(2));
out.type = 2; // p -> p
return out;
}
Expand Down Expand Up @@ -703,32 +708,20 @@ Graph::Graph() { clear_graph(); }

void Graph::assert_accessible() const noexcept { assert(!deleted_ && "graph is no longer valid"); }

void Node_class::assert_accessible_handle() const noexcept {
if (graph != nullptr) {
graph->assert_accessible();
}
}

void Pin_class::assert_accessible_handle() const noexcept {
if (graph != nullptr) {
graph->assert_accessible();
}
}

void Graph::assert_compatible(const Node_class& node) const noexcept {
void Graph::assert_node_exists(const Node_class& node) const noexcept {
assert_accessible();
if (node.graph != nullptr) {
node.graph->assert_accessible();
assert(node.graph == this && "node handle belongs to a different graph");
}
const Nid raw_nid = node.get_raw_nid();
const Nid actual_id = raw_nid >> 2;
assert((raw_nid & static_cast<Nid>(1)) == 0 && "node handle is not a node");
assert(actual_id > 0 && actual_id < node_table.size() && "node handle is invalid for this graph");
}

void Graph::assert_compatible(const Pin_class& pin) const noexcept {
void Graph::assert_pin_exists(const Pin_class& pin) const noexcept {
assert_accessible();
if (pin.graph != nullptr) {
pin.graph->assert_accessible();
assert(pin.graph == this && "pin handle belongs to a different graph");
}
const Pid raw_pid = pin.get_pin_pid();
const Pid actual_id = raw_pid >> 2;
assert((raw_pid & static_cast<Pid>(1)) == static_cast<Pid>(1) && "pin handle is not a pin");
assert(actual_id > 0 && actual_id < pin_table.size() && "pin handle is invalid for this graph");
Comment on lines +711 to +724
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

These checks no longer reject foreign class handles.

Node_class and Pin_class now carry only graph-local ids. A handle from another live Graph with the same local index will pass these bounds checks and hit this graph's node_table / pin_table entry instead, so mixed-graph calls can silently read or mutate the wrong object. These wrappers still need an owner discriminator (for example a per-graph serial/generation) if graph APIs are going to keep accepting class handles. Downstream, hhds/tests/iterators_impl.cpp:929-949 will also need to change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@hhds/graph.cpp` around lines 711 - 724, Graph::assert_node_exists and
Graph::assert_pin_exists currently only validate local indices (via
Node_class::get_raw_nid and Pin_class::get_pin_pid) so a handle from a different
live Graph with the same local id will pass and index the wrong entry; modify
the handle representation to carry a per-graph owner discriminator (e.g. a
generation/serial field) and update Graph::assert_node_exists and
Graph::assert_pin_exists to validate that the handle's owner discriminator
matches this graph's current discriminator before doing the index/bounds checks
against node_table and pin_table; update any constructors/factories that produce
Node_class/Pin_class and adjust downstream tests (hhds/tests/iterators_impl.cpp
lines ~929-949) to create/expect handles with the correct owner discriminator.

}

void Graph::invalidate_from_library() noexcept {
Expand Down Expand Up @@ -793,7 +786,7 @@ void Graph::rebuild_fast_class_cache() const {
fast_class_cache_.reserve(node_table.size() - 1);
for (size_t i = 1; i < node_table.size(); ++i) {
const Nid raw_nid = static_cast<Nid>(i) << 2;
fast_class_cache_.emplace_back(const_cast<Graph*>(this), raw_nid);
fast_class_cache_.emplace_back(raw_nid);
}
}
fast_class_cache_valid_ = true;
Expand Down Expand Up @@ -905,7 +898,7 @@ void Graph::rebuild_forward_class_cache() const {
}

emitted[node_idx] = true;
forward_class_cache_.emplace_back(const_cast<Graph*>(this), static_cast<Nid>(node_idx) << 2);
forward_class_cache_.emplace_back(static_cast<Nid>(node_idx) << 2);

for (auto sink_idx : adjacency[node_idx]) {
if (indegree[sink_idx] == 0) {
Expand All @@ -921,7 +914,7 @@ void Graph::rebuild_forward_class_cache() const {
// Preserve determinism under cycles by appending unresolved nodes by ID.
for (size_t idx = first_user_node_idx; idx < node_count; ++idx) {
if (!emitted[idx]) {
forward_class_cache_.emplace_back(const_cast<Graph*>(this), static_cast<Nid>(idx) << 2);
forward_class_cache_.emplace_back(static_cast<Nid>(idx) << 2);
}
}

Expand All @@ -947,8 +940,8 @@ void Graph::forward_flat_impl(Gid top_graph, ankerl::unordered_dense::set<Gid>&
}
}

void Graph::fast_hier_impl(std::shared_ptr<Tree> hier_tree, Tid hier_tid, std::shared_ptr<std::vector<Gid>> hier_gids, Tree_pos hier_pos,
ankerl::unordered_dense::set<Gid>& active_graphs, std::vector<Node_hier>& out) const {
void Graph::fast_hier_impl(std::shared_ptr<Tree> hier_tree, Tid hier_tid, std::shared_ptr<std::vector<Gid>> hier_gids,
Tree_pos hier_pos, ankerl::unordered_dense::set<Gid>& active_graphs, std::vector<Node_hier>& out) const {
for (size_t i = 1; i < node_table.size(); ++i) {
const Nid node_id = static_cast<Nid>(i) << 2;
const auto& node = node_table[i];
Expand All @@ -972,8 +965,9 @@ void Graph::fast_hier_impl(std::shared_ptr<Tree> hier_tree, Tid hier_tid, std::s
}
}

void Graph::forward_hier_impl(std::shared_ptr<Tree> hier_tree, Tid hier_tid, std::shared_ptr<std::vector<Gid>> hier_gids, Tree_pos hier_pos,
ankerl::unordered_dense::set<Gid>& active_graphs, std::vector<Node_hier>& out) const {
void Graph::forward_hier_impl(std::shared_ptr<Tree> hier_tree, Tid hier_tid, std::shared_ptr<std::vector<Gid>> hier_gids,
Tree_pos hier_pos, ankerl::unordered_dense::set<Gid>& active_graphs,
std::vector<Node_hier>& out) const {
for (const auto& node : forward_class()) {
const Nid node_nid = node.get_raw_nid();
const auto& node_ref = node_table[static_cast<size_t>(node_nid >> 2)];
Expand All @@ -987,7 +981,8 @@ void Graph::forward_hier_impl(std::shared_ptr<Tree> hier_tree, Tid hier_tid, std
}
(*hier_gids)[static_cast<size_t>(child_hier_pos)] = other_graph_id;
active_graphs.insert(other_graph_id);
owner_lib_->get_graph(other_graph_id)->forward_hier_impl(hier_tree, hier_tid, hier_gids, child_hier_pos, active_graphs, out);
owner_lib_->get_graph(other_graph_id)
->forward_hier_impl(hier_tree, hier_tid, hier_gids, child_hier_pos, active_graphs, out);
active_graphs.erase(other_graph_id);
continue;
}
Expand Down Expand Up @@ -1088,13 +1083,20 @@ auto Graph::create_node() -> Node_class {
node_table.emplace_back(id);
invalidate_traversal_caches();
Nid raw_nid = id << 2 | 0;
return Node_class(this, raw_nid);
return Node_class(raw_nid);
}

auto Graph::create_pin(Node_class node, Port_id port_id) -> Pin_class {
assert_node_exists(node);
const Pid pin_pid = create_pin(node.get_raw_nid(), port_id);
return Pin_class(node.get_raw_nid(), port_id, pin_pid);
}

auto Graph::create_pin(Nid nid, Port_id pid) -> Pid {
assert_accessible();
// nid is << 2 here but port_id is not << 2 id (here but pin id in actual) is also not << 2
nid &= ~static_cast<Nid>(2); // Pin ownership is by node identity, independent of edge role bit.
const Nid actual_nid = nid >> 2;
assert(actual_nid > 0 && actual_nid < node_table.size() && "create_pin: node handle is invalid for this graph");
Pid id = pin_table.size();
Comment on lines 1095 to 1100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Reject pin-encoded ids in the raw create_pin path.

This overload only strips the direction bit. Passing an odd/pin-encoded value will reinterpret its numeric payload as a node index, so create_pin(existing_pin_pid, ...) silently attaches the new pin to whichever node lives at existing_pin_pid >> 2.

Suggested guard
 auto Graph::create_pin(Nid nid, Port_id pid) -> Pid {
   assert_accessible();
+  assert((nid & static_cast<Nid>(1)) == 0 && "create_pin expects a node handle, not a pin id");
   nid &= ~static_cast<Nid>(2);  // Pin ownership is by node identity, independent of edge role bit.
   const Nid actual_nid = nid >> 2;
   assert(actual_nid > 0 && actual_nid < node_table.size() && "create_pin: node handle is invalid for this graph");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@hhds/graph.cpp` around lines 1095 - 1100, The create_pin path currently
strips only the direction bit and will misinterpret pin-encoded ids as node
handles; add a guard at the start of Graph::create_pin to assert or reject that
the incoming nid is a node handle (i.e. its low two bits are zero) before
masking/shifting. Concretely, before `nid &= ~static_cast<Nid>(2);` add a check
like `assert((nid & 3) == 0 && "create_pin: expected node handle, got
pin-encoded id");` (or return/throw an error) so passing an existing pin Pid
cannot be silently reinterpreted as a different node index. Ensure the check
uses the same Nid/Pid types and message references `Graph::create_pin`, `nid`,
and existing masking/shift logic.

assert(id);
pin_table.emplace_back(nid, pid);
Expand All @@ -1105,27 +1107,13 @@ auto Graph::create_pin(Nid nid, Port_id pid) -> Pid {

auto Graph::make_pin_class(Pid pin_pid) const -> Pin_class {
const auto* pin = ref_pin(pin_pid);
return Pin_class(const_cast<Graph*>(this), pin->get_master_nid(), pin->get_port_id(), pin_pid);
return Pin_class(pin->get_master_nid(), pin->get_port_id(), pin_pid);
}

auto Pin_class::get_master_node() const -> Node_class {
assert_accessible_handle();
if (graph != nullptr && pin_pid != 0) {
const auto* pin = graph->ref_pin(pin_pid);
return Node_class(graph, pin->get_master_nid());
}
return Node_class(graph, raw_nid);
}

auto Node_class::create_pin(Port_id port_id_value) const -> Pin_class {
assert_accessible_handle();
assert(graph != nullptr);
const Pid pin_pid = graph->create_pin(raw_nid, port_id_value);
return Pin_class(graph, raw_nid, port_id_value, pin_pid);
}
auto Pin_class::get_master_node() const -> Node_class { return Node_class(raw_nid); }

void Graph::set_subnode(Node_class node, Gid gid) {
assert_compatible(node);
assert_node_exists(node);
set_subnode(node.get_raw_nid(), gid);
}

Expand Down Expand Up @@ -1158,32 +1146,32 @@ void Graph::add_edge(Vid driver_id, Vid sink_id) {

void Graph::del_edge(Node_class node1, Node_class node2) {
assert_accessible();
assert_compatible(node1);
assert_compatible(node2);
assert_node_exists(node1);
assert_node_exists(node2);
del_edge_int(node1.get_raw_nid(), node2.get_raw_nid());
invalidate_traversal_caches();
}

void Graph::del_edge(Node_class node, Pin_class pin) {
assert_accessible();
assert_compatible(node);
assert_compatible(pin);
assert_node_exists(node);
assert_pin_exists(pin);
del_edge_int(node.get_raw_nid(), pin.get_pin_pid());
invalidate_traversal_caches();
}

void Graph::del_edge(Pin_class pin, Node_class node) {
assert_accessible();
assert_compatible(pin);
assert_compatible(node);
assert_pin_exists(pin);
assert_node_exists(node);
del_edge_int(pin.get_pin_pid(), node.get_raw_nid());
invalidate_traversal_caches();
}

void Graph::del_edge(Pin_class pin1, Pin_class pin2) {
assert_accessible();
assert_compatible(pin1);
assert_compatible(pin2);
assert_pin_exists(pin1);
assert_pin_exists(pin2);
del_edge_int(pin1.get_pin_pid(), pin2.get_pin_pid());
invalidate_traversal_caches();
}
Expand Down Expand Up @@ -1256,12 +1244,12 @@ void Graph::del_edge_int(Vid driver_id, Vid sink_id) {

auto Graph::out_edges(Node_class node) -> std::vector<Edge_class> {
assert_accessible();
assert_compatible(node);
assert_node_exists(node);
std::vector<Edge_class> out;
const Nid self_nid = node.get_raw_nid() & ~static_cast<Nid>(2);
auto* self = ref_node(self_nid);
auto edges = self->get_edges(self_nid);
const Node_class self_driver(this, self_nid | static_cast<Nid>(2));
const Node_class self_driver(self_nid | static_cast<Nid>(2));

for (auto vid : edges) {
if (vid & 2) {
Expand All @@ -1283,7 +1271,7 @@ auto Graph::out_edges(Node_class node) -> std::vector<Edge_class> {
Edge_class e{};
e.type = 1; // n -> n
e.driver = self_driver;
e.sink = Node_class(this, sink_nid); // keep your existing ctor usage
e.sink = Node_class(sink_nid);
// e.driver_pin / e.sink_pin remain default
out.push_back(e);
}
Expand All @@ -1293,12 +1281,12 @@ auto Graph::out_edges(Node_class node) -> std::vector<Edge_class> {

auto Graph::inp_edges(Node_class node) -> std::vector<Edge_class> {
assert_accessible();
assert_compatible(node);
assert_node_exists(node);
std::vector<Edge_class> out;
const Nid self_nid = node.get_raw_nid() & ~static_cast<Nid>(2);
auto* self = ref_node(self_nid);
auto edges = self->get_edges(self_nid);
const Node_class self_sink(this, self_nid & ~static_cast<Nid>(2));
const Node_class self_sink(self_nid & ~static_cast<Nid>(2));

for (auto vid : edges) {
if (!(vid & 2)) {
Expand All @@ -1319,7 +1307,7 @@ auto Graph::inp_edges(Node_class node) -> std::vector<Edge_class> {

Edge_class e{};
e.type = 1; // n -> n
e.driver = Node_class(this, driver_nid);
e.driver = Node_class(driver_nid);
e.sink = self_sink;
// e.driver_pin / e.sink_pin remain default
out.push_back(e);
Expand All @@ -1330,7 +1318,7 @@ auto Graph::inp_edges(Node_class node) -> std::vector<Edge_class> {

auto Graph::out_edges(Pin_class pin) -> std::vector<Edge_class> {
assert_accessible();
assert_compatible(pin);
assert_pin_exists(pin);
std::vector<Edge_class> out;
const Pid self_pid = pin.get_pin_pid();
const Pid self_pid_lookup = (self_pid & ~static_cast<Pid>(2)) | static_cast<Pid>(1);
Expand Down Expand Up @@ -1358,7 +1346,7 @@ auto Graph::out_edges(Pin_class pin) -> std::vector<Edge_class> {
Edge_class e{};
e.type = 4; // p -> n
e.driver_pin = self_driver_pin;
e.sink = Node_class(this, sink_nid);
e.sink = Node_class(sink_nid);
out.push_back(e);
}

Expand All @@ -1367,7 +1355,7 @@ auto Graph::out_edges(Pin_class pin) -> std::vector<Edge_class> {

auto Graph::inp_edges(Pin_class pin) -> std::vector<Edge_class> {
assert_accessible();
assert_compatible(pin);
assert_pin_exists(pin);
std::vector<Edge_class> out;
const Pid self_pid = pin.get_pin_pid();
const Pid self_pid_sink = (self_pid & ~static_cast<Pid>(2)) | static_cast<Pid>(1);
Expand All @@ -1394,7 +1382,7 @@ auto Graph::inp_edges(Pin_class pin) -> std::vector<Edge_class> {

Edge_class e{};
e.type = 3; // n -> p
e.driver = Node_class(this, driver_nid);
e.driver = Node_class(driver_nid);
e.sink_pin = self_sink_pin;
out.push_back(e);
}
Expand All @@ -1404,7 +1392,7 @@ auto Graph::inp_edges(Pin_class pin) -> std::vector<Edge_class> {

auto Graph::get_pins(Node_class node) -> std::vector<Pin_class> {
assert_accessible();
assert_compatible(node);
assert_node_exists(node);
std::vector<Pin_class> out;
const Nid self_nid = node.get_raw_nid() & ~static_cast<Nid>(2);
auto* self = ref_node(self_nid);
Expand All @@ -1421,7 +1409,7 @@ auto Graph::get_pins(Node_class node) -> std::vector<Pin_class> {

auto Graph::get_driver_pins(Node_class node) -> std::vector<Pin_class> {
assert_accessible();
assert_compatible(node);
assert_node_exists(node);
std::vector<Pin_class> out;
for (const auto& pin : get_pins(node)) {
const Pid pid_lookup = (pin.get_pin_pid() & ~static_cast<Pid>(2)) | static_cast<Pid>(1);
Expand All @@ -1441,7 +1429,7 @@ auto Graph::get_driver_pins(Node_class node) -> std::vector<Pin_class> {

auto Graph::get_sink_pins(Node_class node) -> std::vector<Pin_class> {
assert_accessible();
assert_compatible(node);
assert_node_exists(node);
std::vector<Pin_class> out;
for (const auto& pin : get_pins(node)) {
const Pid pid_lookup = (pin.get_pin_pid() & ~static_cast<Pid>(2)) | static_cast<Pid>(1);
Expand Down
Loading
Loading