Binary file added regress_lm/__pycache__/__init__.cpython-313.pyc
Binary file not shown.
Binary file added regress_lm/__pycache__/core.cpython-313.pyc
Binary file not shown.
Binary file added regress_lm/__pycache__/tokenizers.cpython-313.pyc
Binary file not shown.
Binary file added regress_lm/__pycache__/vocabs.cpython-313.pyc
Binary file not shown.
1 change: 0 additions & 1 deletion regress_lm/pytorch/model.py
@@ -15,7 +15,6 @@
"""PyTorch implementation of a RegressLM."""

# pytype:disable=attribute-error
from concurrent import futures
import dataclasses
import functools
from typing import Any, Sequence
117 changes: 117 additions & 0 deletions regress_lm/vocabs.py
@@ -22,6 +22,7 @@
import transformers
import sentencepiece as spp
import sentencepiece as spt
import torch


ObjectT = TypeVar("ObjectT")
@@ -106,6 +107,122 @@ def possible_next_token_ids(self, prev_tokens: Sequence[int]) -> list[int]:
    possible_next_tokens = self.tokenizer.possible_next_tokens(new_tokens)
    return [self.stoi[t] for t in possible_next_tokens]

  def get_allowed_tokens_mask(
      self, prev_tokens: torch.Tensor, device: torch.device
  ) -> torch.Tensor:
    """Returns a boolean mask of allowed next tokens for a batch.

    Args:
      prev_tokens: (batch_size, seq_len) tensor of previous token IDs.
      device: The device to create the mask on.

    Returns:
      (batch_size, vocab_size) boolean tensor where True means allowed.
    """
    batch_size, seq_len = prev_tokens.shape
    obj_step = seq_len % self.num_tokens_per_obj

    # 1. Base mask from the position within the object
    #    (0, 1, ..., num_tokens_per_obj - 1), cached to avoid recomputation.
    #    For the standard float tokenizers (P10, IEEE, Normalized) the set of
    #    allowed next tokens depends only on this index, so we build a
    #    representative "normal number" prefix by greedily taking the first
    #    allowed token at each earlier step and query the tokenizer once.
    #    AddSpecialValues delegates such prefixes to its inner tokenizer,
    #    which is exactly the base mask we want; special values are handled
    #    separately below.
    if not hasattr(self, "_cached_pos_masks"):
      self._cached_pos_masks = {}

    if obj_step not in self._cached_pos_masks:
      prefix = []
      for _ in range(obj_step):
        allowed = self.possible_next_token_ids([self.stoi[t] for t in prefix])
        prefix.append(self.itos[allowed[0]])

      allowed_indices = self.possible_next_token_ids(
          [self.stoi[t] for t in prefix]
      )
      mask = torch.zeros(len(self), dtype=torch.bool, device=device)
      mask[allowed_indices] = True
      self._cached_pos_masks[obj_step] = mask

    # Start with the cached mask for this position, broadcast over the batch.
    mask = self._cached_pos_masks[obj_step].to(device).repeat(batch_size, 1)

    # 2. Handle special values (NAN, INF, INVALID). If we are inside an
    #    object (obj_step > 0) and its first token was a special token, the
    #    only allowed continuation is to repeat that token (special values
    #    are encoded as e.g. <NAN><NAN>...<NAN>).
    if obj_step > 0:
      # Unwrap the tokenizer chain to find a wrapper (e.g. AddSpecialValues)
      # that declares `_special_tokens`.
      special_token_ids = set()
      curr = self.tokenizer
      while curr is not None:
        if hasattr(curr, "_special_tokens"):
          special_token_ids.update(
              self.stoi[st] for st in curr._special_tokens if st in self.stoi
          )
          break
        curr = getattr(curr, "tokenizer", getattr(curr, "_tokenizer", None))

      if special_token_ids:
        # The current object starts at index seq_len - obj_step.
        start_token_ids = prev_tokens[:, seq_len - obj_step].to(device)

        for special_id in special_token_ids:
          # Rows whose object started with this special token may only
          # emit that same token again.
          is_special = start_token_ids == special_id
          if is_special.any():
            mask[is_special] = False
            mask[is_special, special_id] = True

    return mask

  @property
  def bos_pad_id(self) -> int:
    """Returns the BOS / PAD id for the decoder."""
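For orientation, here is a minimal sketch (not part of this diff) of how such a mask is typically consumed during constrained decoding: disallowed positions are pushed to -inf before softmax or argmax. The names constrain_logits, vocab, and decoder_logits are hypothetical; the real call sites would live in the PyTorch model code.

import torch

def constrain_logits(
    vocab, decoder_logits: torch.Tensor, prev_tokens: torch.Tensor
) -> torch.Tensor:
  """Zeroes out the probability of disallowed next tokens.

  Assumes `vocab` exposes get_allowed_tokens_mask as added above and that
  `decoder_logits` has shape (batch_size, vocab_size).
  """
  allowed = vocab.get_allowed_tokens_mask(
      prev_tokens, device=decoder_logits.device
  )
  # -inf logits become exactly zero probability after softmax.
  return decoder_logits.masked_fill(~allowed, float("-inf"))

Applying the mask at the logit level keeps greedy decoding and sampling code paths unchanged, since both simply never select a -inf entry.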