Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 270 additions & 14 deletions gptqmodel/looper/awq_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import threading
import time
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from torch import nn
Expand All @@ -28,6 +28,7 @@
from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name
from ..quantization.awq.utils.utils import get_best_device
from ..quantization.config import FORMAT, METHOD, QuantizeConfig, resolve_quant_format
from ..utils.attn_mask import normalize_seq_mask
from ..utils.ctx import ctx
from ..utils.device import get_device
from ..utils.fallback import normalize_fallback
Expand Down Expand Up @@ -357,10 +358,239 @@ def _record_input_feature(self, module_name: str, feature: torch.Tensor) -> None
with self.lock:
entry = self.tasks.get(module_name)
if entry is None:
entry = {"inputs": []}
entry = {"inputs": [], "batch_indices": []}
self.tasks[module_name] = entry
inputs_list = entry.setdefault("inputs", [])
inputs_list.append(feature)
entry.setdefault("batch_indices", []).append(self.current_batch_index())

@staticmethod
def _can_concat_batch_tensors(tensors: List[torch.Tensor]) -> bool:
"""Return whether cached tensors share the same non-batch shape."""

if not tensors:
return False

first = tensors[0]
return all(
tensor.dim() == first.dim() and tensor.shape[1:] == first.shape[1:]
for tensor in tensors
)

@staticmethod
def _concat_batch_metadata(values: List[Any]) -> Any:
"""Merge cached per-batch metadata while preserving feature alignment.

When every tensor has the same trailing shape, we concatenate on dim 0
so the metadata still lines up with a concatenated feature tensor.

When sequence lengths differ across batches, the cached metadata no
longer shares a common shape. Real AWQ repros hit cases like
`attention_mask` `[1, 1, 423, 423]` plus `[1, 1, 36, 36]` and
`position_ids` `[1, 423]` plus `[1, 36]`. Those values cannot be
concatenated into one batch-aligned tensor, so we return the most
recent non-None value. `_layer_input_features()` makes the same choice
for the corresponding feature tensor by keeping `tensors[-1]`.
"""

non_none = [value for value in values if value is not None]
if not non_none:
return None

first = non_none[0]
if torch.is_tensor(first):
tensors = [value for value in values if torch.is_tensor(value)]
if len(tensors) != len(non_none):
return non_none[-1]
if AWQProcessor._can_concat_batch_tensors(tensors):
return torch.cat(tensors, dim=0)
return non_none[-1]

return non_none[-1]

def _feature_kwargs_from_batch_indices(self, batch_indices: List[Optional[int]]) -> Dict[str, Any]:
"""Build kwargs aligned to the captured feature batches.

This helper mirrors how `_layer_input_features()` collapses cached
activations. If all selected metadata tensors share the same trailing
dimensions we concatenate them. If sequence lengths differ across
batches, we keep the most recent metadata entry so it stays aligned
with the most recent feature tensor fallback.
"""

feature_kwargs: Dict[str, Any] = dict(getattr(self, "_module_forward_kwargs", {}))
cache = getattr(self, "inputs_cache", None)
if cache is None:
return feature_kwargs

valid_indices = [int(idx) for idx in batch_indices if idx is not None and idx >= 0]
if not valid_indices:
return feature_kwargs

attention_masks = getattr(cache, "attention_masks", None) or []
if attention_masks:
selected_masks = [
attention_masks[idx]
for idx in valid_indices
if idx < len(attention_masks)
]
merged_mask = self._concat_batch_metadata(selected_masks)
if merged_mask is not None:
feature_kwargs["attention_mask"] = merged_mask

position_ids = getattr(cache, "position_ids", None) or []
if position_ids:
selected_pos = [
position_ids[idx]
for idx in valid_indices
if idx < len(position_ids)
]
merged_pos = self._concat_batch_metadata(selected_pos)
if merged_pos is not None:
feature_kwargs["position_ids"] = merged_pos

input_kwargs = getattr(cache, "layer_input_kwargs", None) or []
if input_kwargs:
selected_kwargs = [
input_kwargs[idx]
for idx in valid_indices
if idx < len(input_kwargs)
]
if selected_kwargs:
merged_input_kwargs: Dict[str, Any] = {}
all_keys = {
key
for kwargs in selected_kwargs
for key in kwargs.keys()
}
for key in all_keys:
merged_value = self._concat_batch_metadata(
[kwargs.get(key) for kwargs in selected_kwargs]
)
if merged_value is not None:
merged_input_kwargs[key] = merged_value
feature_kwargs.update(merged_input_kwargs)

return feature_kwargs

@staticmethod
def _pack_kept_token_rows(batch_tensor: torch.Tensor, keep_mask: torch.Tensor) -> torch.Tensor:
"""Pack per-batch row tensors into one kept-token row tensor."""

batch = keep_mask.shape[0]
source = batch_tensor
if source.shape[0] == 1 and batch > 1:
source = source.expand(batch, *source.shape[1:])
packed_rows = [source[batch_index, keep_mask[batch_index]] for batch_index in range(batch)]
return torch.cat(packed_rows, dim=0).unsqueeze(0)

@staticmethod
def _pack_square_token_blocks(batch_tensor: torch.Tensor, keep_mask: torch.Tensor) -> torch.Tensor:
"""Pack per-batch square token tensors into one block-diagonal tensor."""

lengths = [int(row.sum().item()) for row in keep_mask]
total_kept = sum(lengths)
fill_value = torch.finfo(batch_tensor.dtype).min if batch_tensor.is_floating_point() else 0
packed = torch.full(
(1, *batch_tensor.shape[1:-2], total_kept, total_kept),
fill_value=fill_value,
dtype=batch_tensor.dtype,
device=batch_tensor.device,
)
offset = 0
for batch_index, length in enumerate(lengths):
if length <= 0:
continue
keep = keep_mask[batch_index].to(device=batch_tensor.device, dtype=torch.bool)
block = batch_tensor[batch_index:batch_index + 1][..., keep, :]
block = block[..., keep]
packed[..., offset:offset + length, offset:offset + length] = block
offset += length
return packed

@staticmethod
def _pack_attention_mask_for_feature(
    attention_mask: torch.Tensor,
    keep_mask: torch.Tensor,
) -> torch.Tensor:
    """Repack a per-sample mask to match a flattened kept-token feature tensor.

    Square/broadcast masks (`[B, S, S]` or `[B, 1, S, S]`) become one
    block-diagonal tensor; row masks (`[B, S]` or `[1, S]`) are flattened
    along the kept tokens. Any other layout is returned untouched so the
    caller can decide whether to drop it.
    """
    batch_count = keep_mask.shape[0]

    if attention_mask.ndim in (3, 4) and attention_mask.shape[0] == batch_count:
        return AWQProcessor._pack_square_token_blocks(attention_mask, keep_mask)

    if attention_mask.ndim == 2 and attention_mask.shape[1] == keep_mask.shape[1]:
        return AWQProcessor._pack_kept_token_rows(attention_mask, keep_mask)

    return attention_mask

@staticmethod
def _pack_position_ids_for_feature(
    position_ids: torch.Tensor,
    keep_mask: torch.Tensor,
) -> torch.Tensor:
    """Repack per-sample position ids to match flattened kept-token features.

    Only `[B, S]` or `[1, S]` layouts whose sequence length matches the keep
    mask are flattened; every other shape is returned unchanged.
    """
    batch_count = keep_mask.shape[0]

    if position_ids.ndim != 2 or position_ids.shape[1] != keep_mask.shape[1]:
        return position_ids
    if position_ids.shape[0] not in (1, batch_count):
        return position_ids

    return AWQProcessor._pack_kept_token_rows(position_ids, keep_mask)

def _align_module_kwargs_to_input(
    self,
    inp: torch.Tensor,
    module_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Align masks and position ids with packed feature tensors for AWQ replay.

    Capture hooks can flatten kept tokens from several batch rows into one
    replay tensor shaped `[1, kept_tokens, hidden]`, while the cached kwargs
    still describe the pre-packed layout (`attention_mask` `[B, 1, S, S]`,
    `position_ids` `[B, S]`). Feeding those batched kwargs to replay causes
    shape mismatches in attention.

    When the packed token count equals the keep-mask total derived from the
    cached attention mask, both kwargs are rebuilt to the packed layout.
    Otherwise they are removed entirely — dropping them beats passing
    incompatible shapes into attention replay.
    """
    aligned = dict(module_kwargs)
    mask = aligned.get("attention_mask")

    # Only packed replay tensors with a tensor mask qualify; the guard order
    # matters so `.dim()` is never touched on a non-tensor input.
    if not torch.is_tensor(inp) or inp.dim() < 3 or not torch.is_tensor(mask):
        return aligned
    if inp.shape[0] != 1 or mask.ndim < 2:
        return aligned

    # Best-effort: an unrecognized mask layout leaves the kwargs untouched.
    try:
        keep_mask = normalize_seq_mask(mask)
    except Exception:
        return aligned

    if keep_mask is None or keep_mask.ndim != 2 or keep_mask.shape[0] <= 1:
        return aligned

    kept_total = int(keep_mask.to(dtype=torch.int64).sum().item())
    if kept_total != int(inp.shape[1]):
        aligned.pop("attention_mask", None)
        aligned.pop("position_ids", None)
        return aligned

    aligned["attention_mask"] = self._pack_attention_mask_for_feature(mask, keep_mask)

    pos = aligned.get("position_ids")
    if torch.is_tensor(pos):
        aligned["position_ids"] = self._pack_position_ids_for_feature(pos, keep_mask)

    return aligned

def _capture_previous_subset_scale(self, previous_subset: Optional[Dict[str, NamedModule]]) -> Optional[float]:
"""Estimates the average weight scale of the previous subset for reuse heuristics."""
Expand All @@ -382,22 +612,41 @@ def _capture_previous_subset_scale(self, previous_subset: Optional[Dict[str, Nam
return float(sum(values) / len(values))

def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor]:
"""Collapses per-batch cached inputs into one feature tensor per module."""
"""Collapse cached per-batch inputs into one replay tensor per module.

Most batches can be concatenated on dim 0. Variable-length calibration
batches cannot: for example `[1, 423, H]` and `[1, 36, H]` represent
different sequence lengths after masking or packing. In that case we
keep the most recent feature tensor and rebuild kwargs from the same
batch index so activations, masks, and position ids remain aligned.
"""

features: Dict[str, torch.Tensor] = {}
feature_kwargs: Dict[str, Dict[str, Any]] = {}
root_buckets: Dict[str, List[torch.Tensor]] = {}
# Iterate over a snapshot since quantization may mutate state.modules concurrently
for name in list(state.modules):
entry = self.tasks.get(name) or {}
tensors: List[torch.Tensor] = entry.get("inputs", []) # type: ignore[arg-type]
batch_indices: List[Optional[int]] = entry.get("batch_indices", []) # type: ignore[arg-type]
if not tensors:
features[name] = torch.empty(0)
feature_kwargs[name] = {}
continue
try:
if self._can_concat_batch_tensors(tensors):
features[name] = torch.cat(tensors, dim=0)
feature_kwargs[name] = self._feature_kwargs_from_batch_indices(batch_indices)
entry["inputs"] = [features[name]]
entry["batch_indices"] = [None]
else:
# Variable-length captures such as `[1, 423, H]` and `[1, 36, H]`
# cannot be concatenated on dim 0. Keep the latest capture and
# reuse metadata from the same batch index so replay stays aligned.
features[name] = tensors[-1]
last_batch_index = batch_indices[-1] if batch_indices else None
feature_kwargs[name] = self._feature_kwargs_from_batch_indices([last_batch_index])
entry["inputs"] = [features[name]]
except RuntimeError:
features[name] = tensors[0]
entry["batch_indices"] = [last_batch_index]
root = name.split(".", 1)[0]
root_buckets.setdefault(root, []).extend(tensors)
if features[name] is not None and features[name].numel() > 0:
Expand All @@ -415,6 +664,7 @@ def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor
# features[root] = torch.cat(tensors, dim=0)
# except RuntimeError:
# features[root] = tensors[0]
self._awq_feature_kwargs = feature_kwargs
return features

def _quantize_layer_fallback(
Expand Down Expand Up @@ -604,6 +854,9 @@ def unwrap(m):
}

module_kwargs_global = dict(self._module_forward_kwargs)
module_kwargs_global["_awq_feature_kwargs"] = dict(
getattr(self, "_awq_feature_kwargs", {})
)

setattr(self._scale_context, "layer_index", layer_index)
setattr(self._scale_context, "prev_scale", state.previous_weight_scale)
Expand Down Expand Up @@ -894,6 +1147,7 @@ def _search_best_scale(
global_allowed_kwargs = self._sanitize_kwargs(global_kwargs, module2inspect)
for key, value in global_allowed_kwargs.items():
module_kwargs.setdefault(key, value)
module_kwargs = self._align_module_kwargs_to_input(inp, module_kwargs)

if use_chunked_scale_search:
# Build the FP reference output one micro-batch at a time and move each
Expand Down Expand Up @@ -1362,30 +1616,32 @@ def _iter_module_forward_outputs(
yield module_output
return

def _slice_value(val, length):
full_batch_size = int(x.shape[0]) if x.dim() > 0 else 1

def _slice_value(val, start, length):
"""Slices batch-shaped kwargs to match a micro-batched forward chunk."""

if isinstance(val, torch.Tensor) and val.shape[0] == module_kwargs.get("position_ids", val).shape[0]:
return val[:length]
if isinstance(val, torch.Tensor) and val.shape[0] != length:
return val
if isinstance(val, torch.Tensor):
return val[:length]
if val.ndim > 0 and val.shape[0] == full_batch_size:
return val[start:start + length]
return val
if isinstance(val, (list, tuple)):
sliced = [_slice_value(item, length) for item in val]
sliced = [_slice_value(item, start, length) for item in val]
return type(val)(sliced)
return val

batch_offset = 0
for x_partial in torch.split(x, effective_quant_batch_size, dim=0):
x_forward = x_partial.to(target_device) if target_device is not None and x_partial.device != target_device else x_partial
partial_kwargs = {
key: _slice_value(value, x_forward.shape[0])
key: _slice_value(value, batch_offset, x_forward.shape[0])
for key, value in module_kwargs.items()
}
partial_output = module(x_forward, **partial_kwargs)
if isinstance(partial_output, tuple):
partial_output = partial_output[0]
yield partial_output
batch_offset += x_forward.shape[0]

def _iter_reference_output_chunks(self, x: torch.Tensor, reference_output):
"""Normalizes reference outputs to the same chunking contract as forward outputs."""
Expand Down
16 changes: 16 additions & 0 deletions gptqmodel/looper/stage_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,25 @@ def _should_drain_finalize_futures_synchronously(
than the weight-only paths. Letting its finalizers overlap the next layer
can visibly ratchet active VRAM upward from layer N to N+1, so ParoQuant
always drains per-layer finalizers synchronously.

Any multi-accelerator quantization flow can overlap layer N finalizers with
layer N+1 materialization/replay if we keep the default async drain. That
saves some wall time, but it also broadens the lifetime of device-resident
weights, activations, and packing state across layer boundaries. In
practice, the overlap is not worth the allocator pressure risk, so
multi-device runs drain per-layer finalizers synchronously.
"""
if looper.gptq_model.quantize_config.wait_for_submodule_finalizers:
return True

quant_devices = getattr(looper, "_quant_devices", None) or []
active_accelerators = {
(device.type, device.index)
for device_like in quant_devices
if (device := normalize_device_like(device_like)) is not None and device.type != "cpu"
}
if len(active_accelerators) > 1:
return True
return any(isinstance(process, ParoQuantProcessor) for process, *_ in finalize_tasks)


Expand Down
Loading
Loading