GraniteMoEHybrid Model Calls Invalid Method #45507

Description

@rnowling

System Info

Linux: Ubuntu 24.04.4 LTS / 6.8.0-107-generic-64k / aarch64
Python: 3.12.12
Transformers: 5.5.4
Cuda: 12.9

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the script below with the following command:

$ CUDA_VISIBLE_DEVICES=0 python generate_text.py --model-name ibm-granite/granite-4.0-350m-base --n-samples 256 --output-fl granite-350m-base.feather
import argparse

import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

BATCH_SIZE = 32
MAX_LENGTH = 1024
PREFIX = "Your horoscope is: "

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--output-fl", type=str, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--n-samples", type=int, required=True)

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(args.model_name)

    print(model.generation_config)

    # see tips here: https://huggingface.co/docs/transformers/en/model_doc/llama3
    if tokenizer.pad_token is None:
        print("Needed to add padding token")
        tokenizer.add_special_tokens({"pad_token":"<pad>"})
        model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id

    model.eval()

    # why is this necessary?
    model.to("cuda")

    batch = [PREFIX] * BATCH_SIZE
    model_inputs = tokenizer(batch, return_tensors="pt").to(model.device)

    generated_samples = []
    # Round up so the loop yields at least n_samples generations.
    n_batches = (args.n_samples + BATCH_SIZE - 1) // BATCH_SIZE
    for _ in range(n_batches):
        generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=MAX_LENGTH)
        batch_samples = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        generated_samples.extend(batch_samples)

    if len(generated_samples) > args.n_samples:
        generated_samples = generated_samples[:args.n_samples]

    df = pd.DataFrame({"generated_sample": generated_samples,
                       "model": [args.model_name] * len(generated_samples)})

    print(df.head())

    df.to_feather(args.output_fl, compression="zstd")
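
For triage, the batching and DataFrame logic appear incidental: the traceback below shows the error is raised during the prefill of the very first generate() call. A minimal sketch that should hit the same code path (untested; same model as in the command line above):

from transformers import AutoModelForCausalLM, AutoTokenizer

name = "ibm-granite/granite-4.0-350m-base"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).to("cuda").eval()

# A single prompt, so no padding is involved; generate() is expected to
# raise the same ValueError from _update_mamba_mask during prefill.
inputs = tokenizer("Your horoscope is: ", return_tensors="pt").to(model.device)
model.generate(**inputs, max_new_tokens=8)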

Running the full script produces the following stack trace:

Traceback (most recent call last):
  File "/home/rnowling/Projects/robust-llm-data-generators/sequence_token_probabilities/generate_text.py", line 48, in <module>
    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=MAX_LENGTH)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2543, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2736, in _sample
    outputs = self._prefill(
              ^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 3768, in _prefill
    return self(**model_inputs, return_dict=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 876, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py", line 1365, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 952, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/utils/output_capturing.py", line 248, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py", line 1183, in forward
    mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py", line 1217, in _update_mamba_mask
    if (past_key_values is not None and past_key_values.has_previous_state()) or (
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rnowling/pytorch-venv/lib/python3.12/site-packages/transformers/cache_utils.py", line 1057, in has_previous_state
    raise ValueError(
ValueError: `has_previous_state` can only be called on LinearAttention layers, and the current Cache seem to only contain Attention layers.
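
The error suggests that generate() built a cache containing only self-attention layers, even though GraniteMoEHybrid interleaves mamba (linear-attention) and attention blocks, so _update_mamba_mask fails when it probes the cache for previous mamba state. A defensive diagnostic sketch to confirm the layer mix and the advertised cache implementation (the config attribute names probed below are assumptions, read with getattr so nothing crashes if they differ):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-4.0-350m-base")

# Hybrid configs usually record a per-layer block type; the exact attribute
# name varies between model families, so probe a few candidates.
for attr in ("layer_types", "layers_block_type"):
    layout = getattr(model.config, attr, None)
    if layout is not None:
        print(attr, "->", layout)

# Cache implementation generate() would select, if the config advertises one.
print("cache_implementation:",
      getattr(model.generation_config, "cache_implementation", None))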

Expected behavior

I expect the script to be able to generate text from the model. :)
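
As a possible interim workaround (an assumption, not verified against this version), disabling the KV cache should leave past_key_values unset, so the has_previous_state probe ought to short-circuit, at the cost of recomputing the full prefix at every decoding step:

# Untested workaround sketch: use_cache=False avoids building the hybrid
# cache that _update_mamba_mask probes; generation becomes much slower.
generated_ids = model.generate(
    **model_inputs,              # also forwards the attention_mask
    max_new_tokens=MAX_LENGTH,
    use_cache=False,
)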
