Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/nn/conv/test_hypergraph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,27 @@ def test_hypergraph_conv_with_more_edges_than_nodes():
assert out.size() == (num_nodes, out_channels)
out = conv(x, hyperedge_index, hyperedge_weight)
assert out.size() == (num_nodes, out_channels)


def test_hypergraph_conv_weight_effect():
    """Test that non-uniform hyperedge weights produce different results
    from uniform weights, verifying that W is applied correctly per the
    formula X' = D^{-1} H W B^{-1} H^T X Theta.
    """
    # Seed the RNG: the final assertion is a *negative* (outputs must
    # differ), which is only guaranteed reproducible with fixed inputs
    # and fixed layer initialization.
    torch.manual_seed(12345)

    in_channels, out_channels = (16, 32)
    # 4 nodes, 2 hyperedges: edge 0 = {0, 1, 2}, edge 1 = {0, 1, 3}.
    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3], [0, 1, 0, 1, 0, 1]])
    num_nodes = hyperedge_index[0].max().item() + 1
    x = torch.randn((num_nodes, in_channels))

    conv = HypergraphConv(in_channels, out_channels)

    # Uniform weights (all ones) should match default (no weight argument)
    uniform_weight = torch.ones(2)
    out_default = conv(x, hyperedge_index)
    out_uniform = conv(x, hyperedge_index, uniform_weight)
    assert torch.allclose(out_default, out_uniform, atol=1e-6)

    # Non-uniform weights should produce different results
    nonuniform_weight = torch.tensor([1.0, 0.5])
    out_nonuniform = conv(x, hyperedge_index, nonuniform_weight)
    assert not torch.allclose(out_default, out_nonuniform, atol=1e-4)
10 changes: 10 additions & 0 deletions torch_geometric/nn/conv/hypergraph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,16 @@ def forward(self, x: Tensor, hyperedge_index: Tensor,

out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha,
size=(num_nodes, num_edges))

# Apply hyperedge weights W to the intermediate edge representation.
# Per the formula X' = D^{-1} H W B^{-1} H^T X Theta, the diagonal
# weight matrix W must scale edge features between the two message
# passing steps (node-to-edge followed by edge-to-node):
if self.use_attention:
out = out * hyperedge_weight.view(-1, 1, 1)
else:
out = out * hyperedge_weight.view(-1, 1)

out = self.propagate(hyperedge_index.flip([0]), x=out, norm=D,
alpha=alpha, size=(num_edges, num_nodes))

Expand Down