-
Notifications
You must be signed in to change notification settings - Fork 830
[TEST] Attention sink 2 #17271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[TEST] Attention sink 2 #17271
Changes from all commits
c954f27
34a937f
83f437a
97f9910
075d521
fe66e74
04fadb5
97ba715
e4ccab4
bbca983
fec5eb9
f3a0487
7bf2250
f04f211
dc2695a
5ae4922
42e5e7a
246a9b0
60875e3
7fd294f
5f03110
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| base: | ||
| metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' | ||
|
|
||
| model: | ||
| use_sdpa_with_kv_cache: True | ||
| use_kv_cache: True | ||
| dtype_override: fp32 | ||
| enable_dynamic_shape: True | ||
| # Attention Sink: "sink_size,window_size,eviction_batch_size" | ||
| # sink_size=4: Keep first 4 tokens (attention sink) | ||
| # window_size=124: Sliding window size | ||
| # eviction_batch_size=1: Evict 1 token each time | ||
| # KV cache size = sink_size + window_size = 4 + 124 = 128 = max_seq_length | ||
| use_attention_sink: "4,124,1" | ||
|
|
||
| export: | ||
| # max_seq_length = KV cache size = sink + window | ||
| max_seq_length: 128 | ||
| # max_context_length = RoPE position encoding limit | ||
| # pos_ can exceed max_seq_length but not max_context_length | ||
| max_context_length: 8192 | ||
|
|
||
| quantization: | ||
| qmode: 8da4w | ||
| group_size: 128 | ||
| embedding_quantize: 4,32 | ||
|
|
||
| backend: | ||
| xnnpack: | ||
| enabled: True | ||
| extended_ops: True |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -40,6 +40,7 @@ | |||||||||||||||||||||||
| get_vulkan_partitioner, | ||||||||||||||||||||||||
| get_xnnpack_partitioner, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| from executorch.examples.models.llama.model_args import ModelArgs | ||||||||||||||||||||||||
| from executorch.extension.llm.export.quantizer_lib import ( | ||||||||||||||||||||||||
| get_coreml_quantizer, | ||||||||||||||||||||||||
| get_ov_quantizer, | ||||||||||||||||||||||||
|
|
@@ -57,6 +58,7 @@ | |||||||||||||||||||||||
| get_model_with_r1_r2, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| from .source_transformation.attention import replace_attention_to_attention_sha | ||||||||||||||||||||||||
| from .source_transformation.attention_sink import enable_attention_sink | ||||||||||||||||||||||||
| from .source_transformation.custom_kv_cache import ( | ||||||||||||||||||||||||
| replace_kv_cache_with_custom_kv_cache, | ||||||||||||||||||||||||
| replace_kv_cache_with_quantized_kv_cache, | ||||||||||||||||||||||||
|
|
@@ -728,9 +730,16 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: | |||||||||||||||||||||||
| calibration_limit=llm_config.quantization.calibration_limit, | ||||||||||||||||||||||||
| calibration_seq_length=llm_config.quantization.calibration_seq_length, | ||||||||||||||||||||||||
| expand_rope_table=llm_config.model.expand_rope_table, | ||||||||||||||||||||||||
| # Attention sink models need attention mask for custom SDPA because: | ||||||||||||||||||||||||
| # 1. The ring buffer creates a dynamic mask based on cache_positions | ||||||||||||||||||||||||
| # 2. Without mask, custom_sdpa uses is_causal=True with start_pos, which | ||||||||||||||||||||||||
| # fails when start_pos exceeds the cache size (positions keep growing) | ||||||||||||||||||||||||
| # 3. With mask, custom_sdpa uses is_causal=False and the mask handles | ||||||||||||||||||||||||
| # all masking logic including sliding window and attention sink | ||||||||||||||||||||||||
| use_custom_sdpa_with_attention_mask=getattr( | ||||||||||||||||||||||||
| llm_config.model, "use_custom_sdpa_with_attention_mask", False | ||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| or bool(llm_config.model.use_attention_sink), | ||||||||||||||||||||||||
| use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, | ||||||||||||||||||||||||
| quantize_kv_cache=llm_config.model.quantize_kv_cache, | ||||||||||||||||||||||||
| use_kv_cache=llm_config.model.use_kv_cache, | ||||||||||||||||||||||||
|
|
@@ -750,13 +759,49 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: | |||||||||||||||||||||||
| preq_embedding_quantize=llm_config.base.preq_embedding_quantize, | ||||||||||||||||||||||||
| local_global_attention=llm_config.model.local_global_attention, | ||||||||||||||||||||||||
| use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear, | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding, | ||||||||||||||||||||||||
| use_attention_sink=llm_config.model.use_attention_sink, | ||||||||||||||||||||||||
| params_path=llm_config.base.params, | ||||||||||||||||||||||||
| max_context_len=llm_config.export.max_context_length, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if llm_config.model.use_attention_sink: | ||||||||||||||||||||||||
| print("Refreshing example inputs for Attention Sink...") | ||||||||||||||||||||||||
| if hasattr(edge_manager.model, "get_example_inputs"): | ||||||||||||||||||||||||
| # The model is now patched to return (tokens, attn_options, cache_indices) | ||||||||||||||||||||||||
| new_inputs = edge_manager.model.get_example_inputs() | ||||||||||||||||||||||||
| # We assume these are all positional arguments | ||||||||||||||||||||||||
| edge_manager.example_inputs = new_inputs | ||||||||||||||||||||||||
| # Clear kwargs since we provide everything positionally | ||||||||||||||||||||||||
| edge_manager.example_kwarg_inputs = {} | ||||||||||||||||||||||||
| print(f"Updated inputs: {len(new_inputs)} items") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Update dynamic shapes if enabled | ||||||||||||||||||||||||
| if edge_manager.enable_dynamic_shape: | ||||||||||||||||||||||||
| existing_shapes = edge_manager.dynamic_shapes | ||||||||||||||||||||||||
| if existing_shapes and len(existing_shapes) == 2: | ||||||||||||||||||||||||
| # Extract the Dim object from the first input (tokens) | ||||||||||||||||||||||||
| # tokens shape dict is {1: Dim(...)} | ||||||||||||||||||||||||
| token_dim = existing_shapes[0][1] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # cache_indices is 1D tensor of size seq_len | ||||||||||||||||||||||||
| # Spec should be {0: token_dim} | ||||||||||||||||||||||||
| indices_spec = {0: token_dim} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Relieve static constraint on input_pos | ||||||||||||||||||||||||
| # input_pos spec in existing_shapes[1] is {"input_pos": {0: 1}} | ||||||||||||||||||||||||
| # We change it to {"input_pos": {0: token_dim}} | ||||||||||||||||||||||||
| input_pos_spec = {"input_pos": {0: token_dim}} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| edge_manager.dynamic_shapes = (existing_shapes[0], input_pos_spec, indices_spec) | ||||||||||||||||||||||||
| print("Updated dynamic_shapes for Attention Sink (patched input_pos)") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return edge_manager | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_quantizer_and_quant_params(llm_config): | ||||||||||||||||||||||||
| pt2e_quant_params = get_pt2e_quantization_params( | ||||||||||||||||||||||||
| ( | ||||||||||||||||||||||||
|
|
@@ -1118,6 +1163,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 | |||||||||||||||||||||||
| if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: | ||||||||||||||||||||||||
| additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # For attention sink models, the cache_positions buffer must be initialized | ||||||||||||||||||||||||
| # to -1 (sentinel for "empty slot"). Without this pass, ExecuTorch only | ||||||||||||||||||||||||
| # serializes shape+dtype for mutated buffers, leaving them uninitialized | ||||||||||||||||||||||||
| # at runtime, which corrupts the attention mask computation. | ||||||||||||||||||||||||
| if llm_config.model.use_attention_sink: | ||||||||||||||||||||||||
| additional_passes.append( | ||||||||||||||||||||||||
| InitializedMutableBufferPass(["cache_positions"]) | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # export_to_edge | ||||||||||||||||||||||||
| builder_exported = _prepare_for_llama_export(llm_config).export() | ||||||||||||||||||||||||
| builder_exported.run_canonical_optimizations() | ||||||||||||||||||||||||
|
|
@@ -1282,6 +1336,28 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": | |||||||||||||||||||||||
| model_class_name, | ||||||||||||||||||||||||
| llm_config=llm_config, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Add attention sink metadata if enabled | ||||||||||||||||||||||||
| metadata = _load_llama_model_metadata( | ||||||||||||||||||||||||
| llm_config.model.use_kv_cache, | ||||||||||||||||||||||||
| llm_config.model.use_sdpa_with_kv_cache, | ||||||||||||||||||||||||
| llm_config.model.enable_dynamic_shape, | ||||||||||||||||||||||||
| model.max_seq_len, | ||||||||||||||||||||||||
| model.max_context_len, | ||||||||||||||||||||||||
| model.n_layers, | ||||||||||||||||||||||||
| model.vocab_size, | ||||||||||||||||||||||||
| llm_config.base.metadata, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Add attention sink metadata if enabled | ||||||||||||||||||||||||
| if llm_config.model.use_attention_sink: | ||||||||||||||||||||||||
| # Format: sink_size,window_size,eviction_batch_size | ||||||||||||||||||||||||
| sink_params = [int(x) for x in llm_config.model.use_attention_sink.split(",")] | ||||||||||||||||||||||||
| # IOManager expects these methods to exist returning int. | ||||||||||||||||||||||||
| # By adding them to metadata, export_to_edge will generate constant methods. | ||||||||||||||||||||||||
| metadata["get_sink_size"] = sink_params[0] | ||||||||||||||||||||||||
| metadata["get_window_size"] = sink_params[1] | ||||||||||||||||||||||||
|
Comment on lines
+1356
to
+1359
|
||||||||||||||||||||||||
| # IOManager expects these methods to exist returning int. | |
| # By adding them to metadata, export_to_edge will generate constant methods. | |
| metadata["get_sink_size"] = sink_params[0] | |
| metadata["get_window_size"] = sink_params[1] | |
| # Runtime runner expects these metadata keys: | |
| # - "use_attention_sink": bool flag to enable attention sink | |
| # - "attention_sink_size": sink size (int) | |
| # - "attention_sink_window_size": window size (int) | |
| metadata["use_attention_sink"] = True | |
| metadata["attention_sink_size"] = sink_params[0] | |
| metadata["attention_sink_window_size"] = sink_params[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block enables attention-mask mode for custom SDPA when attention sink is on, but attention-sink configs also enable KV-cache quantization in some cases. In the quantized path,
SDPACustomis replaced byQuantizedSDPA, which (per sdpa.py) still doesn’t support attention masks and continues to usestart_posfor causal masking. That means attention sink + quantized KV cache may still hit the samestart_pos >= cache_sizevalidation failure. Consider propagating theuse_attention_maskflag intoQuantizedSDPAand making its mask path ignorestart_possimilarly (or preventing this combination).