Whisper generation fails on empty transcription after align_special_tokens #45584

@ronansgd

System Info

  • transformers version: 5.6.0.dev0
  • Platform: Linux-6.17.0-1009-gcp-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • Huggingface_hub version: 1.11.0
  • Safetensors version: 0.7.0
  • Accelerate version: 1.13.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.11.0+cu130 (CUDA)
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: no
  • GPU type: NVIDIA A100-SXM4-40GB

Who can help?

@eustlb

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Basic reproduction script

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.trainer_utils import align_special_tokens


def main() -> None:
    model_name = "openai/whisper-tiny"
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)

    # Disable begin-of-sequence suppression and suppress every token except EOS,
    # so that generation on silence emits EOS right away (an empty transcription).
    model.generation_config.begin_suppress_tokens = None
    model.generation_config.suppress_tokens = [
        i for i in range(model.config.vocab_size) if i != model.config.eos_token_id
    ]

    # --- Without align_special_tokens: works fine ---
    features = processor(torch.zeros(16_000 * 10).numpy(), sampling_rate=16_000, return_tensors="pt").input_features

    out = model.generate(features)
    decoded = processor.batch_decode(out, skip_special_tokens=True)
    print(f"Before align_special_tokens: generate() on silence → {decoded!r}  (empty transcription, as expected)")

    # --- After align_special_tokens (called e.g. by Trainer.train()): crashes ---
    align_special_tokens(model, processor)

    print("\nAfter align_special_tokens: generate() on silence →", end=" ")
    try:
        model.generate(features)
        print("no crash")
    except IndexError as e:
        print(f"IndexError: {e}")


if __name__ == "__main__":
    main()

Expected behavior

Generation should not fail after align_special_tokens has been called; it should return an empty transcription, as it does before the call.
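
The expected behavior could be expressed as a small check on the decoded output. This is only a sketch: `assert_empty_transcription` is a hypothetical helper, not part of transformers, and it assumes it receives the list of strings returned by `processor.batch_decode(..., skip_special_tokens=True)` in the script above.

```python
def assert_empty_transcription(decoded: list[str]) -> None:
    """Raise AssertionError unless every decoded string is empty or whitespace.

    Hypothetical helper for illustration; not a transformers API.
    """
    non_empty = [text for text in decoded if text.strip()]
    assert not non_empty, f"expected empty transcription, got {non_empty!r}"
```

In the reproduction script, it would be called on the decoded result of the second `model.generate(features)` call, once that call no longer raises.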
