@@ -21,6 +21,10 @@ class ForwardOptions(TypedDict, total=False):
2121 in_cache_state : Optional [Any ]
2222 out_cache_state : Optional [Any ]
2323 last_valid_token_pos : Optional [torch .LongTensor ]
24+ # YOCO (You Only Cache Once): shared K/V from a donor layer.
25+ # When provided, the attention layer skips its own K/V projection
26+ # and reuses the donor's K/V instead.
27+ shared_kv : Optional [Tuple [torch .Tensor , torch .Tensor ]]
2428
2529
2630class Attention (nn .Module , ABC ):
@@ -314,6 +318,28 @@ def update(
314318 return self .k_cache , self .v_cache
315319
316320
def _create_projection(
    args: ModelArgs,
    in_dim: int,
    out_dim: int,
    target_names: Tuple[str, ...],
    bias: bool = False,
) -> nn.Module:
    """Build a single projection layer, LoRA-wrapped when it is a configured target.

    Args:
        args: Model configuration. ``target_modules`` is always read; ``r`` and
            ``lora_alpha`` are read only when a LoRA layer is built.
        in_dim: Input feature size of the projection.
        out_dim: Output feature size of the projection.
        target_names: Aliases of this projection. A match of any alias in
            ``args.target_modules`` selects the LoRA variant.
        bias: Whether the created layer has a bias term.

    Returns:
        A ``LoRALinear`` if ``args.target_modules`` is set and contains one of
        ``target_names``; otherwise a plain ``nn.Linear``.
    """
    lora_targets = args.target_modules
    if lora_targets is None:
        return nn.Linear(in_dim, out_dim, bias=bias)

    for alias in target_names:
        if alias in lora_targets:
            # Dropout is pinned to 0.0 for every LoRA projection built here.
            return LoRALinear(
                in_dim=in_dim,
                out_dim=out_dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=bias,
            )

    # No alias matched: fall back to a regular linear layer.
    return nn.Linear(in_dim, out_dim, bias=bias)
341+
342+
317343@register_attention ("mha" )
318344class AttentionMHA (Attention ):
319345 def __init__ (
@@ -351,78 +377,22 @@ def __init__(
351377 self .enable_dynamic_shape = args .enable_dynamic_shape
352378 q_out_dim = self .n_heads * self .head_dim * (2 if self .use_q_gate else 1 )
353379
354- if self .use_qk_norm :
355- q_norm_dim = self .head_dim
356- k_norm_dim = self .head_dim
357- self .q_norm_fn = RMSNorm (
358- q_norm_dim ,
359- eps = args .norm_eps ,
360- add_unit_offset = args .rms_norm_add_unit_offset ,
361- )
362- self .k_norm_fn = RMSNorm (
363- k_norm_dim ,
364- eps = args .norm_eps ,
365- add_unit_offset = args .rms_norm_add_unit_offset ,
366- )
380+ # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
381+ num_kv_shared = args .num_kv_shared_layers
382+ n_layers = args .n_layers
383+ if num_kv_shared > 0 :
384+ first_shared = n_layers - num_kv_shared
385+ self .is_kv_shared_layer = layer_id >= first_shared and first_shared > 0
386+ else :
387+ self .is_kv_shared_layer = False
367388
368- self .wq = (
369- LoRALinear (
370- in_dim = args .dim ,
371- out_dim = q_out_dim ,
372- rank = args .r ,
373- alpha = args .lora_alpha ,
374- dropout = 0.0 ,
375- use_bias = args .attention_qkv_bias ,
376- )
377- if args .target_modules is not None and "q_proj" in args .target_modules
378- else nn .Linear (self .dim , q_out_dim , bias = self .attention_qkv_bias )
379- )
380- self .wk = (
381- LoRALinear (
382- in_dim = args .dim ,
383- out_dim = args .n_kv_heads * args .head_dim ,
384- rank = args .r ,
385- alpha = args .lora_alpha ,
386- dropout = 0.0 ,
387- use_bias = args .attention_qkv_bias ,
388- )
389- if args .target_modules is not None and "k_proj" in args .target_modules
390- else nn .Linear (
391- self .dim , self .n_kv_heads * self .head_dim , bias = self .attention_qkv_bias
392- )
393- )
394- self .wv = (
395- LoRALinear (
396- in_dim = args .dim ,
397- out_dim = args .n_kv_heads * args .head_dim ,
398- rank = args .r ,
399- alpha = args .lora_alpha ,
400- dropout = 0.0 ,
401- use_bias = args .attention_qkv_bias ,
402- )
403- if args .target_modules is not None and "v_proj" in args .target_modules
404- else nn .Linear (
405- self .dim , self .n_kv_heads * self .head_dim , bias = self .attention_qkv_bias
406- )
407- )
408- self .wo = (
409- LoRALinear (
410- in_dim = args .n_heads * args .head_dim ,
411- out_dim = args .dim ,
412- rank = args .r ,
413- alpha = args .lora_alpha ,
414- dropout = 0.0 ,
415- use_bias = args .attention_qkv_bias ,
416- )
417- if args .target_modules is not None
418- and (
419- "output_proj" in args .target_modules or "o_proj" in args .target_modules
420- )
421- else nn .Linear (self .n_heads * self .head_dim , self .dim , bias = False )
422- )
389+ self .num_kv_shared_layers = num_kv_shared
390+ self .has_kv_weights = not self .is_kv_shared_layer
423391
424- self .layer_id = layer_id
392+ self ._init_norms (args )
393+ self ._init_projections (args , q_out_dim )
425394
395+ self .layer_id = layer_id
426396 self .rope = rope
427397
428398 causal_mask = torch .tril (
@@ -436,40 +406,102 @@ def __init__(
436406 self .register_buffer ("mask" , causal_mask , persistent = False )
437407
438408 if self .use_kv_cache :
409+ self ._init_kv_cache (args )
410+ self .SDPA = SDPA (
411+ dim = self .n_local_heads * self .head_dim ,
412+ head_dim = self .head_dim ,
413+ n_rep = self .n_rep ,
414+ max_context_len = self .max_context_len ,
415+ )
416+
def _init_norms(self, args: ModelArgs) -> None:
    """Create the RMSNorm modules used for QK normalization.

    Nothing is created unless ``self.use_qk_norm`` is set. When it is,
    ``q_norm_fn`` is always built, while ``k_norm_fn`` is built only for
    layers that own K/V weights (``self.has_kv_weights``).

    Args:
        args: Model configuration; ``norm_eps`` and
            ``rms_norm_add_unit_offset`` are forwarded to ``RMSNorm``.
    """
    if not self.use_qk_norm:
        return

    # Query norm first, then key norm, so submodule order stays q -> k.
    self.q_norm_fn = RMSNorm(
        self.head_dim,
        eps=args.norm_eps,
        add_unit_offset=args.rms_norm_add_unit_offset,
    )

    if not self.has_kv_weights:
        # Layers without their own K projection get no key norm.
        return

    self.k_norm_fn = RMSNorm(
        self.head_dim,
        eps=args.norm_eps,
        add_unit_offset=args.rms_norm_add_unit_offset,
    )
431+
def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
    """Create the Q, K, V and output projections for this attention layer.

    ``wq`` and ``wo`` are always created. ``wk`` and ``wv`` are created only
    when ``self.has_kv_weights`` is true; otherwise both are set to ``None``.

    Args:
        args: Model configuration passed through to ``_create_projection``;
            also supplies ``dim``, ``n_heads`` and ``head_dim`` for sizing.
        q_out_dim: Output width of the query projection.
    """
    qkv_bias = self.attention_qkv_bias

    self.wq = _create_projection(
        args, args.dim, q_out_dim, ("q_proj",), bias=qkv_bias
    )

    if not self.has_kv_weights:
        # No local K/V projection on this layer.
        self.wk = None
        self.wv = None
    else:
        kv_out_dim = self.n_kv_heads * self.head_dim
        self.wk = _create_projection(
            args, args.dim, kv_out_dim, ("k_proj",), bias=qkv_bias
        )
        self.wv = _create_projection(
            args, args.dim, kv_out_dim, ("v_proj",), bias=qkv_bias
        )

    # The output projection accepts either alias and never carries a bias.
    attn_out_dim = args.n_heads * args.head_dim
    self.wo = _create_projection(
        args,
        attn_out_dim,
        args.dim,
        ("output_proj", "o_proj"),
        bias=False,
    )
455+
def _init_kv_cache(self, args: ModelArgs) -> None:
    """Attach a KV cache to this layer, or ``None`` if it owns no K/V weights.

    Args:
        args: Model configuration; supplies ``max_batch_size``,
            ``max_context_len`` and ``enable_dynamic_shape`` for the cache.
    """
    if not self.has_kv_weights:
        # Layers without K/V weights keep the attribute but hold no cache.
        self.kv_cache = None
        return

    self.kv_cache = KVCache(
        args.max_batch_size,
        args.max_context_len,
        self.n_kv_heads,
        self.head_dim,
        args.enable_dynamic_shape,
    )
452468
453- def forward (
def _prepare_qkv_shared(
    self,
    q: torch.Tensor,
    shared_kv: Tuple[torch.Tensor, torch.Tensor],
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare Q/K/V when K and V are supplied by a donor layer (YOCO).

    Only the query is processed: optional QK-norm, RoPE, then a swap of
    dims 1 and 2. The donor's K and V are returned untouched.

    Args:
        q: Query tensor; ``forward`` passes it viewed as
            (bsz, seqlen, n_local_heads, head_dim).
        shared_kv: ``(k, v)`` pair from the donor layer. K is expected to
            already carry RoPE from the donor (assumption inherited from
            the caller's contract; not checked here).
        freqs_cos: Cosine table handed to ``self.rope``.
        freqs_sin: Sine table handed to ``self.rope``.

    Returns:
        ``(q, k, v)`` with ``q`` processed and ``k``/``v`` passed through.
    """
    donor_k, donor_v = shared_kv

    norm_enabled = self.use_qk_norm
    norm_pre_rope = norm_enabled and self.qk_norm_before_rope
    norm_post_rope = norm_enabled and not self.qk_norm_before_rope

    if norm_pre_rope:
        q = self.q_norm_fn(q)

    # The rope call takes a (q, k) pair; q fills the key slot and that
    # second result is discarded, since the donor's K is used as-is.
    rotated_q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
    q = rotated_q.transpose(1, 2)

    if norm_post_rope:
        q = self.q_norm_fn(q)

    return q, donor_k, donor_v
472490
491+ def _prepare_qkv (
492+ self ,
493+ q : torch .Tensor ,
494+ x : torch .Tensor ,
495+ bsz : int ,
496+ seqlen : int ,
497+ freqs_cos : torch .Tensor ,
498+ freqs_sin : torch .Tensor ,
499+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
500+ """Prepare Q/K/V with standard projection (non-YOCO path)."""
501+ assert self .wk is not None and self .wv is not None , (
502+ "wk/wv projections are required when shared_kv is not provided. "
503+ "This layer may be a YOCO shared layer that requires shared_kv from a donor."
504+ )
473505 k , v = self .wk (x ), self .wv (x )
474506 k = k .view (bsz , seqlen , self .n_local_kv_heads , self .head_dim )
475507 v = v .view (bsz , seqlen , self .n_local_kv_heads , self .head_dim )
@@ -478,7 +510,6 @@ def forward(
478510 q = self .q_norm_fn (q )
479511 k = self .k_norm_fn (k )
480512
481- # RoPE relative positional embeddings
482513 q , k = self .rope .forward (q , k , freqs_cos , freqs_sin )
483514
484515 q = q .transpose (1 , 2 ) # (bs, n_local_heads, seqlen, head_dim)
@@ -489,6 +520,34 @@ def forward(
489520 q = self .q_norm_fn (q )
490521 k = self .k_norm_fn (k )
491522
523+ return q , k , v
524+
525+ def forward (
526+ self ,
527+ x : torch .Tensor ,
528+ freqs_cos : torch .Tensor ,
529+ freqs_sin : torch .Tensor ,
530+ ** kwargs : ForwardOptions ,
531+ ) -> Tuple [torch .Tensor , Optional [Any ]]:
532+ input_pos = kwargs .get ("input_pos" )
533+ shared_kv = kwargs .get ("shared_kv" )
534+ bsz , seqlen , _ = x .shape
535+
536+ if self .use_q_gate :
537+ q_and_gate = self .wq (x ).view (
538+ bsz , seqlen , self .n_local_heads , self .head_dim * 2
539+ )
540+ q , gate = torch .chunk (q_and_gate , 2 , dim = - 1 )
541+ gate = gate .reshape (bsz , seqlen , - 1 )
542+ else :
543+ q = self .wq (x ).view (bsz , seqlen , self .n_local_heads , self .head_dim )
544+ gate = None
545+
546+ if shared_kv is not None :
547+ q , k , v = self ._prepare_qkv_shared (q , shared_kv , freqs_cos , freqs_sin )
548+ else :
549+ q , k , v = self ._prepare_qkv (q , x , bsz , seqlen , freqs_cos , freqs_sin )
550+
492551 if self .use_kv_cache :
493552 assert input_pos is not None
494553 if self .enable_dynamic_shape :
@@ -501,15 +560,29 @@ def forward(
501560 else :
502561 # mask is always 2D
503562 attn_mask = self .mask [input_pos ]
504- k , v = self .kv_cache .update (input_pos , k , v )
563+
564+ # Only update KV cache for non-shared layers
565+ if shared_kv is None :
566+ assert self .kv_cache is not None , (
567+ "kv_cache is required when shared_kv is not provided. "
568+ "This layer may be a YOCO shared layer that requires shared_kv from a donor."
569+ )
570+ k , v = self .kv_cache .update (input_pos , k , v )
571+
505572 if getattr (self .kv_cache , "is_ring_buffer" , False ):
506573 attn_mask = self .kv_cache .create_causal_mask_for_ring_buffer (
507574 input_pos [0 ].item (), seqlen
508575 )
576+
509577 output = self .SDPA (input_pos , q , k , v , bsz , seqlen , attn_mask )
510578 if gate is not None :
511579 output = output * torch .sigmoid (gate )
512- return self .wo (output ), None
580+
581+ if shared_kv is None and self .num_kv_shared_layers > 0 :
582+ update = {"kv_to_share" : (k , v )}
583+ else :
584+ update = None
585+ return self .wo (output ), update
513586
514587 # grouped multiquery attention: expand out keys and values
515588 k = k .repeat_interleave (self .n_rep , dim = 1 )
0 commit comments