Commit ffa82e6

Merge branch 'main' into export-D97202714
2 parents: 4eb4f95 + 8d43d97

File tree: 6 files changed (+977, -104 lines)


examples/models/llama/attention.py

Lines changed: 166 additions & 93 deletions
@@ -21,6 +21,10 @@ class ForwardOptions(TypedDict, total=False):
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
     last_valid_token_pos: Optional[torch.LongTensor]
+    # YOCO (You Only Cache Once): shared K/V from a donor layer.
+    # When provided, the attention layer skips its own K/V projection
+    # and reuses the donor's K/V instead.
+    shared_kv: Optional[Tuple[torch.Tensor, torch.Tensor]]
 
 
 class Attention(nn.Module, ABC):
@@ -314,6 +318,28 @@ def update(
         return self.k_cache, self.v_cache
 
 
+def _create_projection(
+    args: ModelArgs,
+    in_dim: int,
+    out_dim: int,
+    target_names: Tuple[str, ...],
+    bias: bool = False,
+) -> nn.Module:
+    """Create a Linear or LoRALinear projection based on target_modules config."""
+    if args.target_modules is not None and any(
+        n in args.target_modules for n in target_names
+    ):
+        return LoRALinear(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            rank=args.r,
+            alpha=args.lora_alpha,
+            dropout=0.0,
+            use_bias=bias,
+        )
+    return nn.Linear(in_dim, out_dim, bias=bias)
+
+
 @register_attention("mha")
 class AttentionMHA(Attention):
     def __init__(
@@ -351,78 +377,22 @@ def __init__(
         self.enable_dynamic_shape = args.enable_dynamic_shape
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)
 
-        if self.use_qk_norm:
-            q_norm_dim = self.head_dim
-            k_norm_dim = self.head_dim
-            self.q_norm_fn = RMSNorm(
-                q_norm_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
-            self.k_norm_fn = RMSNorm(
-                k_norm_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
+        # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
+        num_kv_shared = args.num_kv_shared_layers
+        n_layers = args.n_layers
+        if num_kv_shared > 0:
+            first_shared = n_layers - num_kv_shared
+            self.is_kv_shared_layer = layer_id >= first_shared and first_shared > 0
+        else:
+            self.is_kv_shared_layer = False
 
-        self.wq = (
-            LoRALinear(
-                in_dim=args.dim,
-                out_dim=q_out_dim,
-                rank=args.r,
-                alpha=args.lora_alpha,
-                dropout=0.0,
-                use_bias=args.attention_qkv_bias,
-            )
-            if args.target_modules is not None and "q_proj" in args.target_modules
-            else nn.Linear(self.dim, q_out_dim, bias=self.attention_qkv_bias)
-        )
-        self.wk = (
-            LoRALinear(
-                in_dim=args.dim,
-                out_dim=args.n_kv_heads * args.head_dim,
-                rank=args.r,
-                alpha=args.lora_alpha,
-                dropout=0.0,
-                use_bias=args.attention_qkv_bias,
-            )
-            if args.target_modules is not None and "k_proj" in args.target_modules
-            else nn.Linear(
-                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
-            )
-        )
-        self.wv = (
-            LoRALinear(
-                in_dim=args.dim,
-                out_dim=args.n_kv_heads * args.head_dim,
-                rank=args.r,
-                alpha=args.lora_alpha,
-                dropout=0.0,
-                use_bias=args.attention_qkv_bias,
-            )
-            if args.target_modules is not None and "v_proj" in args.target_modules
-            else nn.Linear(
-                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
-            )
-        )
-        self.wo = (
-            LoRALinear(
-                in_dim=args.n_heads * args.head_dim,
-                out_dim=args.dim,
-                rank=args.r,
-                alpha=args.lora_alpha,
-                dropout=0.0,
-                use_bias=args.attention_qkv_bias,
-            )
-            if args.target_modules is not None
-            and (
-                "output_proj" in args.target_modules or "o_proj" in args.target_modules
-            )
-            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
-        )
+        self.num_kv_shared_layers = num_kv_shared
+        self.has_kv_weights = not self.is_kv_shared_layer
 
-        self.layer_id = layer_id
+        self._init_norms(args)
+        self._init_projections(args, q_out_dim)
 
+        self.layer_id = layer_id
         self.rope = rope
 
         causal_mask = torch.tril(
@@ -436,40 +406,102 @@ def __init__(
         self.register_buffer("mask", causal_mask, persistent=False)
 
         if self.use_kv_cache:
+            self._init_kv_cache(args)
+        self.SDPA = SDPA(
+            dim=self.n_local_heads * self.head_dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
+            max_context_len=self.max_context_len,
+        )
+
+    def _init_norms(self, args: ModelArgs) -> None:
+        """Initialize QK normalization layers."""
+        if self.use_qk_norm:
+            self.q_norm_fn = RMSNorm(
+                self.head_dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )
+            if self.has_kv_weights:
+                self.k_norm_fn = RMSNorm(
+                    self.head_dim,
+                    eps=args.norm_eps,
+                    add_unit_offset=args.rms_norm_add_unit_offset,
+                )
+
+    def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
+        """Initialize Q/K/V/O projection layers."""
+        self.wq = _create_projection(
+            args, args.dim, q_out_dim, ("q_proj",), bias=self.attention_qkv_bias
+        )
+        if self.has_kv_weights:
+            kv_dim = self.n_kv_heads * self.head_dim
+            self.wk = _create_projection(
+                args, args.dim, kv_dim, ("k_proj",), bias=self.attention_qkv_bias
+            )
+            self.wv = _create_projection(
+                args, args.dim, kv_dim, ("v_proj",), bias=self.attention_qkv_bias
+            )
+        else:
+            self.wk = None
+            self.wv = None
+        self.wo = _create_projection(
+            args,
+            args.n_heads * args.head_dim,
+            args.dim,
+            ("output_proj", "o_proj"),
+            bias=False,
+        )
+
+    def _init_kv_cache(self, args: ModelArgs) -> None:
+        """Initialize KV cache (only for non-shared layers)."""
+        if self.has_kv_weights:
             self.kv_cache = KVCache(
                 args.max_batch_size,
                 args.max_context_len,
                 self.n_kv_heads,
                 self.head_dim,
                 args.enable_dynamic_shape,
             )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_context_len=self.max_context_len,
-            )
+        else:
+            self.kv_cache = None
 
-    def forward(
+    def _prepare_qkv_shared(
         self,
-        x: torch.Tensor,
+        q: torch.Tensor,
+        shared_kv: Tuple[torch.Tensor, torch.Tensor],
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        **kwargs: ForwardOptions,
-    ) -> Tuple[torch.Tensor, Optional[Any]]:
-        input_pos = kwargs.get("input_pos")
-        bsz, seqlen, _ = x.shape
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Prepare Q/K/V when using shared KV from a donor layer (YOCO)."""
+        k, v = shared_kv
 
-        if self.use_q_gate:
-            q_and_gate = self.wq(x).view(
-                bsz, seqlen, self.n_local_heads, self.head_dim * 2
-            )
-            q, gate = torch.chunk(q_and_gate, 2, dim=-1)
-            gate = gate.reshape(bsz, seqlen, -1)
-        else:
-            q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
-            gate = None
+        if self.use_qk_norm and self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+
+        # Apply RoPE to Q only (K already has RoPE from donor layer)
+        q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
+        q = q.transpose(1, 2)
+
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+
+        return q, k, v
 
+    def _prepare_qkv(
+        self,
+        q: torch.Tensor,
+        x: torch.Tensor,
+        bsz: int,
+        seqlen: int,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Prepare Q/K/V with standard projection (non-YOCO path)."""
+        assert self.wk is not None and self.wv is not None, (
+            "wk/wv projections are required when shared_kv is not provided. "
+            "This layer may be a YOCO shared layer that requires shared_kv from a donor."
+        )
         k, v = self.wk(x), self.wv(x)
         k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -478,7 +510,6 @@ def forward(
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
 
-        # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
@@ -489,6 +520,34 @@ def forward(
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
 
+        return q, k, v
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        **kwargs: ForwardOptions,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        input_pos = kwargs.get("input_pos")
+        shared_kv = kwargs.get("shared_kv")
+        bsz, seqlen, _ = x.shape
+
+        if self.use_q_gate:
+            q_and_gate = self.wq(x).view(
+                bsz, seqlen, self.n_local_heads, self.head_dim * 2
+            )
+            q, gate = torch.chunk(q_and_gate, 2, dim=-1)
+            gate = gate.reshape(bsz, seqlen, -1)
+        else:
+            q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            gate = None
+
+        if shared_kv is not None:
+            q, k, v = self._prepare_qkv_shared(q, shared_kv, freqs_cos, freqs_sin)
+        else:
+            q, k, v = self._prepare_qkv(q, x, bsz, seqlen, freqs_cos, freqs_sin)
+
         if self.use_kv_cache:
             assert input_pos is not None
             if self.enable_dynamic_shape:
@@ -501,15 +560,29 @@ def forward(
             else:
                 # mask is always 2D
                 attn_mask = self.mask[input_pos]
-            k, v = self.kv_cache.update(input_pos, k, v)
+
+            # Only update KV cache for non-shared layers
+            if shared_kv is None:
+                assert self.kv_cache is not None, (
+                    "kv_cache is required when shared_kv is not provided. "
+                    "This layer may be a YOCO shared layer that requires shared_kv from a donor."
+                )
+                k, v = self.kv_cache.update(input_pos, k, v)
+
             if getattr(self.kv_cache, "is_ring_buffer", False):
                 attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
                     input_pos[0].item(), seqlen
                 )
+
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
             if gate is not None:
                 output = output * torch.sigmoid(gate)
-            return self.wo(output), None
+
+            if shared_kv is None and self.num_kv_shared_layers > 0:
+                update = {"kv_to_share": (k, v)}
+            else:
+                update = None
+            return self.wo(output), update
 
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
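
The diff above only changes the attention module itself; the layer-level wiring that forwards a donor layer's K/V into the shared layers lives in the other changed files, which are not part of this excerpt. For orientation only, here is a minimal sketch of how a caller could thread the returned `kv_to_share` update into later layers' `shared_kv` option. The `run_attention_stack` helper and its arguments are illustrative assumptions, not code from this commit.

```python
from typing import Optional, Tuple

import torch


def run_attention_stack(layers, x, freqs_cos, freqs_sin, input_pos=None):
    """Hypothetical driver: thread the donor layer's K/V into YOCO shared layers.

    `layers` is assumed to be a sequence of AttentionMHA-like modules whose
    forward() matches the signature in the diff above; residuals and FFN are
    omitted for brevity.
    """
    shared_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    for layer in layers:
        # Non-shared layers return {"kv_to_share": (k, v)} when KV sharing is
        # enabled; shared layers receive shared_kv and skip their own K/V path.
        x, update = layer(
            x, freqs_cos, freqs_sin, input_pos=input_pos, shared_kv=shared_kv
        )
        if update is not None and "kv_to_share" in update:
            shared_kv = update["kv_to_share"]
    return x
```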

examples/models/llama/export_llama_lib.py

Lines changed: 7 additions & 0 deletions
@@ -1486,6 +1486,7 @@ def _load_llama_model_metadata(
     n_layers: int,
     vocab_size: int,
     metadata_str: Optional[str] = None,
+    num_kv_shared_layers: int = 0,
 ):
     metadata = {
         "get_max_seq_len": max_seq_len,
@@ -1496,6 +1497,9 @@ def _load_llama_model_metadata(
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,
     }
+    # YOCO (You Only Cache Once) KV sharing metadata
+    if num_kv_shared_layers > 0:
+        metadata["get_num_kv_shared_layers"] = num_kv_shared_layers
     if metadata_str:
         try:
             extra = json.loads(metadata_str)
@@ -1575,6 +1579,9 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
             # Module]`.
             model.vocab_size,
             llm_config.base.metadata,
+            # pyre-fixme[6]: For 10th argument expected `int` but got `Union[Tensor,
+            # Module]`.
+            num_kv_shared_layers=getattr(model, "num_kv_shared_layers", 0),
         ),
     )
 
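As a sanity check on the numbers involved, here is a small worked example of the layer partition implied by `num_kv_shared_layers` and the metadata key added above. The concrete values (16 layers, 4 shared) are made up for illustration and simply mirror the `first_shared` logic from the attention.py diff.

```python
# Illustrative values only; mirrors the is_kv_shared_layer computation above.
n_layers = 16
num_kv_shared_layers = 4

first_shared = n_layers - num_kv_shared_layers  # 12
is_kv_shared_layer = [
    num_kv_shared_layers > 0 and first_shared > 0 and layer_id >= first_shared
    for layer_id in range(n_layers)
]
# Layers 0..11 keep their own K/V projections and caches; layers 12..15 run
# with shared_kv and no KV cache. Under the threading sketch earlier, layer 11
# is effectively the donor whose K/V the shared layers reuse.
assert is_kv_shared_layer == [False] * 12 + [True] * 4

# The exported metadata then advertises the count to the runtime:
metadata = {"get_num_kv_shared_layers": num_kv_shared_layers}
```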