12 changes: 12 additions & 0 deletions examples/models/llama/BUCK
@@ -283,6 +283,18 @@ fbcode_target(_kind = runtime.python_test,
],
)

fbcode_target(_kind = runtime.python_test,
name = "attention_sink_ring_buffer_test",
srcs = [
"source_transformation/test_attention_sink_ring_buffer.py",
],
supports_static_listing = False,
deps = [
"//caffe2:torch",
":export_library",
],
)

fbcode_target(_kind = runtime.python_test,
name = "quantized_sdpa_source_transform_test",
srcs = [
23 changes: 16 additions & 7 deletions examples/models/llama/attention.py
@@ -278,7 +278,7 @@ def __init__(
[0, 1, 2, 3, 4, NA, NA, NA] After cache update we would have
[8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the
current step still has access to [pos - sliding_window_size, pos] tokens.

To make sure we don't over-attend, i.e. we don't have pos = 5
attend to pos = 1, the mask calculation has to account for the sliding window
size.
@@ -486,21 +486,30 @@ def forward(

if self.use_kv_cache:
assert input_pos is not None
if self.enable_dynamic_shape:
is_ring = getattr(self.kv_cache, "is_ring_buffer", False)
if is_ring:
# Ring buffer models: positions can exceed max_context_len.
# The ring buffer handles wrapping via modular arithmetic.
# The causal mask is computed dynamically from cache_positions,
# so we don't use the pre-computed self.mask here.
k, v = self.kv_cache.update(input_pos, k, v)
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
start_pos, seqlen
)
elif self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = self.mask.narrow(0, start_pos, seq_length)
k, v = self.kv_cache.update(input_pos, k, v)
else:
# mask is always 2D
attn_mask = self.mask[input_pos]
k, v = self.kv_cache.update(input_pos, k, v)
if getattr(self.kv_cache, "is_ring_buffer", False):
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
input_pos[0].item(), seqlen
)
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
return self.wo(output), None
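
Note for readers skimming this diff: below is a minimal sketch of the idea behind create_causal_mask_for_ring_buffer, assuming a cache_positions buffer that stores each slot's original token position with -1 marking empty slots (as described in the comments above). The real method's signature and its handling of the sink tokens may differ.

import torch

def ring_buffer_causal_mask_sketch(cache_positions, start_pos, seq_len, window_size):
    # Positions of the query tokens processed in this forward pass.
    q_pos = torch.arange(start_pos, start_pos + seq_len).unsqueeze(1)  # [seq_len, 1]
    # Pre-wrap positions of the tokens currently stored in each cache slot.
    k_pos = cache_positions.unsqueeze(0)                               # [1, cache_len]
    valid = k_pos >= 0                        # -1 marks an empty slot
    causal = k_pos <= q_pos                   # no attending to future tokens
    in_window = k_pos > q_pos - window_size   # sliding-window limit from the docstring
    keep = valid & causal & in_window
    mask = torch.zeros(seq_len, cache_positions.numel())
    return mask.masked_fill(~keep, float("-inf"))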

31 changes: 31 additions & 0 deletions examples/models/llama/config/llama_attention_sink.yaml
@@ -0,0 +1,31 @@
base:
metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
use_sdpa_with_kv_cache: True
use_kv_cache: True
dtype_override: fp32
enable_dynamic_shape: True
# Attention Sink: "sink_size,window_size,eviction_batch_size"
# sink_size=4: Keep first 4 tokens (attention sink)
# window_size=124: Sliding window size
# eviction_batch_size=1: Evict 1 token each time
# KV cache size = sink_size + window_size = 4 + 124 = 128 = max_seq_length
use_attention_sink: "4,124,1"

export:
# max_seq_length = KV cache size = sink + window
max_seq_length: 128
# max_context_length = RoPE position encoding limit
# pos can exceed max_seq_length but not max_context_length
max_context_length: 8192

quantization:
qmode: 8da4w
group_size: 128
embedding_quantize: 4,32

backend:
xnnpack:
enabled: True
extended_ops: True
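
A quick sanity check of the arithmetic in this config (illustrative only, not part of the exported code):

# Parse "sink_size,window_size,eviction_batch_size"
sink_size, window_size, eviction_batch_size = (int(x) for x in "4,124,1".split(","))
assert sink_size + window_size == 128    # matches max_seq_length above
assert sink_size + window_size <= 8192   # within max_context_length, the RoPE limit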
10 changes: 9 additions & 1 deletion examples/models/llama/eval_llama_lib.py
@@ -338,6 +338,8 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
Evaluate the model's perplexity when AttentionSink is enabled.

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py

Updated for the ring-buffer based attention sink implementation.
"""
# Convert args to LlmConfig
from executorch.extension.llm.export.config.llm_config import LlmConfig
@@ -351,7 +353,13 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert llm_config.export.max_seq_length == sink_size + window_size
# For the ring buffer implementation, the cache size is sink_size + window_size * 2
# max_context_length should be >= sink_size + window_size (for RoPE frequencies)
# but can be larger to support extended generation
assert llm_config.export.max_context_length >= sink_size + window_size, (
f"max_context_length ({llm_config.export.max_context_length}) must be >= "
f"sink_size + window_size ({sink_size + window_size})"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
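
To make the sizing comment above concrete, the numbers for the "4,124,1" config work out roughly as follows (a sketch; the window_size * 2 figure is taken from the comment, not re-derived from the cache code):

sink_size, window_size = 4, 124
logical_context = sink_size + window_size       # 128 tokens visible to attention
ring_cache_slots = sink_size + window_size * 2  # 252 physical KV slots, per the comment
max_context_length = 8192                       # RoPE limit; must cover logical_context
assert max_context_length >= logical_context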
130 changes: 109 additions & 21 deletions examples/models/llama/export_llama_lib.py
@@ -40,6 +40,7 @@
get_vulkan_partitioner,
get_xnnpack_partitioner,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.extension.llm.export.quantizer_lib import (
get_coreml_quantizer,
get_ov_quantizer,
@@ -57,6 +58,7 @@
get_model_with_r1_r2,
)
from .source_transformation.attention import replace_attention_to_attention_sha
from .source_transformation.attention_sink import enable_attention_sink
from .source_transformation.custom_kv_cache import (
replace_kv_cache_with_custom_kv_cache,
replace_kv_cache_with_quantized_kv_cache,
@@ -728,9 +730,16 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
calibration_limit=llm_config.quantization.calibration_limit,
calibration_seq_length=llm_config.quantization.calibration_seq_length,
expand_rope_table=llm_config.model.expand_rope_table,
# Attention sink models need attention mask for custom SDPA because:
# 1. The ring buffer creates a dynamic mask based on cache_positions
# 2. Without mask, custom_sdpa uses is_causal=True with start_pos, which
# fails when start_pos exceeds the cache size (positions keep growing)
# 3. With mask, custom_sdpa uses is_causal=False and the mask handles
# all masking logic including sliding window and attention sink
use_custom_sdpa_with_attention_mask=getattr(
llm_config.model, "use_custom_sdpa_with_attention_mask", False
),
)
or bool(llm_config.model.use_attention_sink),
use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
quantize_kv_cache=llm_config.model.quantize_kv_cache,

Copilot AI Feb 6, 2026

This block enables attention-mask mode for custom SDPA when attention sink is on, but attention-sink configs also enable KV-cache quantization in some cases. In the quantized path, SDPACustom is replaced by QuantizedSDPA, which (per sdpa.py) still doesn’t support attention masks and continues to use start_pos for causal masking. That means attention sink + quantized KV cache may still hit the same start_pos >= cache_size validation failure. Consider propagating the use_attention_mask flag into QuantizedSDPA and making its mask path ignore start_pos similarly (or preventing this combination).

Suggested change
quantize_kv_cache=llm_config.model.quantize_kv_cache,
# Quantized KV cache currently does not support attention-mask-based
# custom SDPA (QuantizedSDPA still relies on start_pos for masking).
# To avoid incompatible behavior, disable KV-cache quantization when
# attention sink is enabled.
quantize_kv_cache=(
False
if getattr(llm_config.model, "use_attention_sink", False)
else llm_config.model.quantize_kv_cache
),

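As an aside, the causal-vs-mask distinction discussed in this thread can be illustrated with PyTorch's stock SDPA; the ExecuTorch custom op has a different signature, so treat this purely as an analogy with made-up shapes:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)    # [batch, heads, seq_len, head_dim]
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
attn_mask = torch.zeros(16, 16)  # e.g. a ring-buffer mask with -inf in blocked slots

# Causal mode: the kernel derives the mask from positions, which breaks once
# attention-sink positions keep growing past the cache size.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Mask mode: the explicit attn_mask carries the sink + sliding-window logic,
# so the kernel no longer reasons about start_pos at all.
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
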
use_kv_cache=llm_config.model.use_kv_cache,
@@ -750,13 +759,49 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
local_global_attention=llm_config.model.local_global_attention,
use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,

use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
use_attention_sink=llm_config.model.use_attention_sink,
params_path=llm_config.base.params,
max_context_len=llm_config.export.max_context_length,
)
)

if llm_config.model.use_attention_sink:
print("Refreshing example inputs for Attention Sink...")
if hasattr(edge_manager.model, "get_example_inputs"):
# The model is now patched to return (tokens, attn_options, cache_indices)
new_inputs = edge_manager.model.get_example_inputs()
# We assume these are all positional arguments
edge_manager.example_inputs = new_inputs
# Clear kwargs since we provide everything positionally
edge_manager.example_kwarg_inputs = {}
print(f"Updated inputs: {len(new_inputs)} items")

# Update dynamic shapes if enabled
if edge_manager.enable_dynamic_shape:
existing_shapes = edge_manager.dynamic_shapes
if existing_shapes and len(existing_shapes) == 2:
# Extract the Dim object from the first input (tokens)
# tokens shape dict is {1: Dim(...)}
token_dim = existing_shapes[0][1]

# cache_indices is 1D tensor of size seq_len
# Spec should be {0: token_dim}
indices_spec = {0: token_dim}

# Relax the static constraint on input_pos
# input_pos spec in existing_shapes[1] is {"input_pos": {0: 1}}
# We change it to {"input_pos": {0: token_dim}}
input_pos_spec = {"input_pos": {0: token_dim}}

edge_manager.dynamic_shapes = (existing_shapes[0], input_pos_spec, indices_spec)
print("Updated dynamic_shapes for Attention Sink (patched input_pos)")

return edge_manager
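
For reference, a hedged illustration of the dynamic_shapes tuple this code produces, assuming the token dimension was originally created with torch.export.Dim (the name and bound here are made up):

from torch.export import Dim

token_dim = Dim("token_dim", max=127)            # illustrative name/bound
tokens_spec = {1: token_dim}                     # tokens: [batch, seq_len]
input_pos_spec = {"input_pos": {0: token_dim}}   # input_pos is made dynamic too
indices_spec = {0: token_dim}                    # cache_indices: 1-D, length seq_len
dynamic_shapes = (tokens_spec, input_pos_spec, indices_spec)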



def get_quantizer_and_quant_params(llm_config):
pt2e_quant_params = get_pt2e_quantization_params(
(
@@ -1118,6 +1163,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

# For attention sink models, the cache_positions buffer must be initialized
# to -1 (sentinel for "empty slot"). Without this pass, ExecuTorch only
# serializes shape+dtype for mutated buffers, leaving them uninitialized
# at runtime, which corrupts the attention mask computation.
if llm_config.model.use_attention_sink:
additional_passes.append(
InitializedMutableBufferPass(["cache_positions"])
)
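
For context, a minimal sketch of the kind of buffer this pass preserves, placed in the ring-buffer KV cache's __init__; the class and attribute layout here are assumptions based on the comment above, not the actual cache implementation:

import torch

class RingBufferKVCacheSketch(torch.nn.Module):
    def __init__(self, max_cache_len: int):
        super().__init__()
        # -1 is the "empty slot" sentinel. Without InitializedMutableBufferPass,
        # only the shape and dtype of this mutated buffer would be serialized,
        # so the sentinel values would be lost at runtime.
        self.register_buffer(
            "cache_positions",
            torch.full((max_cache_len,), -1, dtype=torch.long),
        )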

# export_to_edge
builder_exported = _prepare_for_llama_export(llm_config).export()
builder_exported.run_canonical_optimizations()
@@ -1282,6 +1336,28 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
model_class_name,
llm_config=llm_config,
)

# Load the base model metadata
metadata = _load_llama_model_metadata(
llm_config.model.use_kv_cache,
llm_config.model.use_sdpa_with_kv_cache,
llm_config.model.enable_dynamic_shape,
model.max_seq_len,
model.max_context_len,
model.n_layers,
model.vocab_size,
llm_config.base.metadata,
)

# Add attention sink metadata if enabled
if llm_config.model.use_attention_sink:
# Format: sink_size,window_size,eviction_batch_size
sink_params = [int(x) for x in llm_config.model.use_attention_sink.split(",")]
# IOManager expects these methods to exist returning int.
# By adding them to metadata, export_to_edge will generate constant methods.
metadata["get_sink_size"] = sink_params[0]
metadata["get_window_size"] = sink_params[1]
Comment on lines +1356 to +1359

Copilot AI Feb 7, 2026

The metadata added for attention sink uses keys get_sink_size/get_window_size, but the runtime runner expects attention sink metadata under use_attention_sink/attention_sink_size/attention_sink_window_size (see extension/llm/runner/constants.h and llm_runner_helper.cpp). This mismatch prevents the runner from detecting attention sink models. Update the metadata keys (and add an explicit enable flag) to match what the runner reads, or adjust the runner to match these exported keys.

Suggested change
# IOManager expects these methods to exist returning int.
# By adding them to metadata, export_to_edge will generate constant methods.
metadata["get_sink_size"] = sink_params[0]
metadata["get_window_size"] = sink_params[1]
# Runtime runner expects these metadata keys:
# - "use_attention_sink": bool flag to enable attention sink
# - "attention_sink_size": sink size (int)
# - "attention_sink_window_size": window size (int)
metadata["use_attention_sink"] = True
metadata["attention_sink_size"] = sink_params[0]
metadata["attention_sink_window_size"] = sink_params[1]


# Convert dtype override string to actual type.
dtype_override = DType[llm_config.model.dtype_override.value]

@@ -1296,31 +1372,14 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
example_kwarg_inputs=example_kwarg_inputs,
dynamic_shapes=dynamic_shapes,
enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
save_exported_program=llm_config.export.export_only,
calibration_tasks=llm_config.quantization.calibration_tasks,
calibration_limit=llm_config.quantization.calibration_limit,
calibration_seq_length=llm_config.quantization.calibration_seq_length,
calibration_data=llm_config.quantization.calibration_data,
tokenizer_path=llm_config.base.tokenizer_path,
save_exported_program=llm_config.export.export_only,
verbose=llm_config.debug.verbose,
metadata=_load_llama_model_metadata(
llm_config.model.use_kv_cache,
llm_config.model.use_sdpa_with_kv_cache,
llm_config.model.enable_dynamic_shape,
# pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_seq_len,
# pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_context_len,
# pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
# Module]`.
model.n_layers,
# pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor,
# Module]`.
model.vocab_size,
llm_config.base.metadata,
),
metadata=metadata,
)


@@ -1359,6 +1418,9 @@ def _get_source_transforms(  # noqa
use_torchao_kernels_linear: bool = False,
use_torchao_kernels_tied_embedding: bool = False,
quantize_with_hqq: bool = True,
use_attention_sink: Optional[str] = None,
params_path: Optional[str] = None,
max_context_len: Optional[int] = None,
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
"""
Return a list of functions that transform a graph.
@@ -1470,7 +1532,6 @@ def _get_source_transforms(  # noqa
transforms.append(materialze_broadcast_of_rope_freq_cis)

use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

if use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
# todo: do this optionally
@@ -1546,6 +1607,33 @@ def _get_source_transforms(  # noqa
)
)

if use_attention_sink:
sink_params = [int(x) for x in use_attention_sink.split(",")]

# Load ModelArgs for attention sink
if not params_path:
raise ValueError("params_path is required for attention sink")
with open(params_path, "r") as f:
params_dict = json.load(f)

# Ensure use_kv_cache is propagated from config
params_dict["use_kv_cache"] = True # Attention Sink requires KV Cache
# ModelArgs may expect other fields that are usually handled by the Llama2Model init;
# we pass the minimal set needed for RoPE/attention.

model_args = ModelArgs(**params_dict)

transforms.append(
partial(
enable_attention_sink,
params=model_args,
sink_size=sink_params[0],
window_size=sink_params[1],
eviction_batch_size=sink_params[2],
max_context_len=max_context_len,
)
)

return transforms
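
A small sketch of how a transform list like the one returned here is typically consumed; the real call site lives in the export manager, so this is illustrative only:

def apply_source_transforms(model, transforms):
    # Each entry is a Callable[[torch.nn.Module], torch.nn.Module]; apply them in order.
    for transform in transforms:
        model = transform(model)
    return model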


17 changes: 16 additions & 1 deletion examples/models/llama/model.py
@@ -218,7 +218,22 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

assert self.llm_config.export.max_context_length == sink_size + window_size
# max_context_length must be >= sink_size + window_size to have enough RoPE frequencies
# A larger max_context_length is allowed (and recommended) to support generation beyond
# the sliding window size.
assert self.llm_config.export.max_context_length >= sink_size + window_size, (
f"max_context_length ({self.llm_config.export.max_context_length}) must be >= "
f"sink_size + window_size ({sink_size + window_size})"
)

# IMPORTANT: For attention sink, we need RoPE frequencies for all possible generation
# positions, not just the cache size. Override the model's max_context_len to use
# a larger value that supports extended generation.
# We use model_args.max_context_len which was set from export.max_context_length
# but for RoPE we need the full generation length capability.
# Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
default_rope_length = max(131072, model_args.max_context_len)
model_args.max_context_len = default_rope_length

self.model_ = enable_attention_sink(
module=self.model_,