Rotation method inconsistent with paper in some scripts #1

@ekhouu

Description

Sorry if this is intentional in some way or I'm missing something; I'm pretty new to this stuff. It might be worth running tests to confirm, but I haven't yet.

Issue

Rotation is implemented inconsistently across files. In the paper, the rotation is $\hat{h} = U^{\top} h$, which in code (assuming the paper uses the column-vector convention) should be h @ U for row-vector activations.
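To make the convention concrete, here is a minimal standalone check (NumPy for brevity; U stands in for any orthogonal basis matrix with basis vectors as columns, the layout torch.linalg.eigh and np.linalg.eigh both use):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4

# Build a random orthogonal U, as eigh would return it (basis vectors as columns).
U, _ = np.linalg.qr(rng.standard_normal((d, d)))
h = rng.standard_normal(d)

# Paper (column-vector convention): h_rot = U^T h.
col = U.T @ h
# Row-vector code: h @ U gives the same coordinates...
row_ok = h @ U
# ...while h @ U.T does not (unless U happens to be symmetric).
row_bad = h @ U.T

assert np.allclose(col, row_ok)
assert not np.allclose(col, row_bad)
```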

spectral_rotation.py and phase2_integration.py implement h @ U, but run_v3_perplexity_crossarch.py and engine.py implement h @ U.T. When the bit allocation is uniform across dimensions this happens not to matter, but in any other case we run into a problem in the next few lines

I tried to reason through this to see if it was just a notation mismatch, but I don't think it is. By default, torch.linalg.eigh returns the eigenvectors as the columns of U, so h @ U projects onto the eigenvectors, while h @ U.T mixes coordinates along the rows of U, which are not eigenvectors; the result is just scrambled.

Because of this, the split at $d_{\text{eff}}$ that is meant to separate signal from noise does not actually do so.
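To see the downstream effect concretely, here's a standalone sketch (NumPy, hypothetical data, not the repo's code): with h @ U the per-coordinate variances of the rotated activations come out in descending order, so truncating at any d_eff separates high-variance from low-variance directions; with h @ U.T they don't.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n = 6, 20000

# Latent directions with very different variances.
scales = np.array([10.0, 5.0, 2.0, 1.0, 0.5, 0.1])
X = rng.standard_normal((n, d)) * scales
# Mix them with a random orthogonal matrix so the covariance is not diagonal.
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
H = X @ Q.T

cov = np.cov(H, rowvar=False)
ev, evec = np.linalg.eigh(cov)      # ascending, eigenvectors as columns
ev, evec = ev[::-1], evec[:, ::-1]  # flip both to descending

var_good = (H @ evec).var(axis=0)   # h @ U: variances sorted descending
var_bad = (H @ evec.T).var(axis=0)  # h @ U.T: scrambled, no such ordering

# Correct rotation yields monotonically non-increasing per-coordinate variance,
# so a split at any d_eff cleanly separates signal from noise.
assert np.all(np.diff(var_good) <= 1e-6)
```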

Tracing back the code

This issue could exist in more places, but once you know it exists it's pretty easy to find and fix.

spectral_rotation.py (Correct rotation)

First, the eigenvalues and eigenvectors come from calibration.py (that's where they're computed):

# ascending col-vec
eigenvalues, eigenvectors = torch.linalg.eigh(cov)
# descending eigenvalues
eigenvalues = eigenvalues.flip(0)
# make columns of eigenvectors descending too
eigenvectors = eigenvectors.flip(1)
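As a side note, this flip bookkeeping is itself consistent: flipping the eigenvalues along dim 0 and the eigenvector columns along dim 1 keeps the pairs aligned. A standalone NumPy check (np.linalg.eigh orders ascending with column eigenvectors, just like torch.linalg.eigh):

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((5, 5))
cov = A @ A.T  # symmetric PSD

ev, evec = np.linalg.eigh(cov)  # ascending, eigenvectors as columns
ev = ev[::-1]                   # eigenvalues descending
evec = evec[:, ::-1]            # flip columns to match

# Each column is still a genuine eigenpair: cov @ u_i == lambda_i * u_i.
assert np.allclose(cov @ evec, evec * ev)
```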

$U$ is defined at lines 215-217:

V = hcd.eigenvectors.float()
Vt = V.T.contiguous()
self._cache[Key] = (V, Vt)

So we have our $U$ with columns as eigenvectors. But then at the rotate function:

    def rotate(
        self,
        x: torch.Tensor,
        layer_idx: int,
        head_idx: int,
    ) -> torch.Tensor:
        """Project ``x`` into the spectral basis: :math:`\\hat{x} = V^\\top x`.

        Parameters
        ----------
        x:
            Tensor of shape ``(..., head_dim)``.
        layer_idx:
            Transformer layer index.
        head_idx:
            Attention head index.

        Returns
        -------
        torch.Tensor
            Spectrally rotated tensor, same shape as ``x``.
        """
        _, Vt = self._get_matrices(layer_idx, head_idx)
        Vt = Vt.to(x.device)
        # x: (..., head_dim), Vt: (head_dim, head_dim)
        # Result: (..., head_dim) = x @ V (column-wise: each row of x multiplied by V)
        # Equivalently: (V^T x^T)^T = x @ V
        return x @ Vt.T  # x @ V == (V^T x^T)^T

So this correctly implements h @ U.

phase2_integration.py (Correct rotation)

Here the eigenvectors come in through the init, giving a V that plays the role of our $U$:

self.V = torch.from_numpy(eigenvectors).float()  # [head_dim, head_dim]

Then the rotate function is simple:

    def rotate(self, x: torch.Tensor) -> torch.Tensor:
        """
        Spectral rotation: x_rot = V^T @ (x - mean)
        x: [..., head_dim]
        Returns: [..., head_dim]
        """
        V = self.V.to(x.device)
        mean = self.mean.to(x.device)
        return (x - mean) @ V  # [..., head_dim]  (equivalent to V^T @ (x - mean) row-wise)

run_v3_perplexity_crossarch.py (Incorrect)

$U$ is evec here (the annotations in the snippet are mine), at lines 221-223:

# ascending, columns
ev, evec = torch.linalg.eigh(C)
# eigenvalues descending
ev = ev.flip(0).clamp(min=0)
# eigenvectors descending
evec = evec.flip(1)

The rotation happens at lines 77-81:

k_n = torch.norm(K_f, dim=-1, keepdim=True)
K_rot = (K_f / (k_n + 1e-8)) @ VT
v_n = torch.norm(V_f, dim=-1, keepdim=True)
V_rot = (V_f / (v_n + 1e-8)) @ VT

So this is a case of h @ U.T.

engine.py (Incorrect)

We get $U$ as V at line 228:

V = eigenvectors.to(device).float()  # [head_dim, head_dim]

(line 183 does say these are descending, with eigenvectors as columns)

But then at line 234, V is saved as Pi:

self.Pi = V                       # [head_dim, head_dim]
self.PiT = V.T.contiguous()       # [head_dim, head_dim]

And the rotation block from 434-437:

# --- Spectral rotation ---
rotated = K_normed @ self.PiT.float()                        # (seq_k, head_dim)
rotated_high = rotated[:, :self.d_eff]                       # semantic regime
rotated_low  = rotated[:, self.d_eff:]                       # tail regime

K_normed @ self.PiT.float() is h @ U.T, i.e. the wrong direction.
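Assuming the diagnosis is right, the fix in both files is to multiply by the eigenvector matrix itself rather than its transpose (@ evec / @ self.Pi instead of @ VT / @ self.PiT). A minimal self-contained sketch of the corrected rotation and the d_eff split (names like rotate and d_eff are illustrative, mirroring engine.py):

```python
import numpy as np

def rotate(h, U):
    # Row-vector form of the paper's h_rot = U^T h: note @ U, not @ U.T.
    return h @ U

rng = np.random.default_rng(3)
A = rng.standard_normal((8, 8))
cov = A @ A.T                    # symmetric PSD stand-in for a calibration covariance

ev, U = np.linalg.eigh(cov)      # ascending, eigenvectors as columns
ev, U = ev[::-1], U[:, ::-1]     # descending, columns reordered to match

h = rng.standard_normal((3, 8))  # (seq, head_dim)
rotated = rotate(h, U)

d_eff = 4
rotated_high = rotated[:, :d_eff]  # top-d_eff (signal) coordinates
rotated_low = rotated[:, d_eff:]   # tail (noise) coordinates

# Round-trip sanity check: U is orthogonal, so rotating back recovers h.
assert np.allclose(rotated @ U.T, h)
```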
