Conversation

@Brooooooklyn

Summary

Implements a fused backward pass (VJP) for scaled_dot_product_attention on the Metal GPU backend. This enables efficient gradient computation during training without falling back to the unfused (decomposed) attention path.
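
As a quick illustration of what this enables, here is a hedged usage sketch (hypothetical shapes and loss, using the public Python API): gradients are taken through mx.fast.scaled_dot_product_attention, and with this change they come from the fused Metal VJP kernels instead of the decomposed fallback.

```python
import mlx.core as mx

# Hypothetical shapes for illustration: batch=2, heads=8, seq_len=128, head_dim=64
q = mx.random.normal((2, 8, 128, 64))
k = mx.random.normal((2, 8, 128, 64))
v = mx.random.normal((2, 8, 128, 64))

def loss_fn(q, k, v):
    # Fused SDPA forward; scale is the usual 1/sqrt(head_dim)
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=64**-0.5)
    return out.sum()

# With this PR, dq/dk/dv are produced by the fused Metal VJP kernels
dq, dk, dv = mx.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)
mx.eval(dq, dk, dv)
```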

Changes

New Files

  • mlx/backend/metal/kernels/sdpa_vector_vjp.h - Vector VJP kernel for short sequences (L ≤ 8)
  • mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h - STEEL dQ gradient kernel
  • mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h - STEEL dK/dV gradient kernel

Modified Files

  • mlx/backend/metal/scaled_dot_product_attention.cpp - VJP dispatch logic (+840 lines)
  • mlx/fast.cpp / mlx/fast_primitives.h - Logsumexp caching, VJP routing
  • python/tests/test_fast_sdpa.py - Comprehensive VJP tests (+220 lines)

Implementation Notes

Uses a two-kernel approach to avoid atomic operations (a reference sketch of the gradient math these kernels compute follows the list below):

  1. dQ kernel (steel_attention_vjp_dq.h):

    • Computes query gradients via outer loop over KV blocks
    • Uses log2 domain for numerical stability
    • Proper clamping to prevent overflow (exp2 arg clamped to [-88, 0])
  2. dK/dV kernel (steel_attention_vjp_dkv.h):

    • Uses a K-row ownership model in which each simdgroup owns an exclusive set of K rows
    • Eliminates race conditions in GQA where multiple query heads share KV
    • No atomic operations needed
  3. Vector VJP (sdpa_vector_vjp.h):

    • Optimized path for short sequences (L ≤ 8)
    • Uses float32 accumulators for half/bfloat16 precision
    • Shared memory reduction for efficiency
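
For reference, here is a minimal NumPy sketch of the gradient math the two kernels implement: the standard softmax-attention backward for a single head, with the probabilities recomputed from the cached logsumexp. This is purely illustrative (the function name is made up), not the Metal kernel code.

```python
import numpy as np

def sdpa_vjp_reference(q, k, v, out, d_out, lse, scale):
    """Single-head SDPA backward reference (illustrative only).

    q: (L, D), k/v: (Lk, D), out/d_out: (L, D),
    lse: (L,) logsumexp of the scaled scores cached from the forward pass.
    """
    s = scale * (q @ k.T)                 # scaled attention scores, (L, Lk)
    p = np.exp(s - lse[:, None])          # softmax recomputed from cached logsumexp
    dv = p.T @ d_out                      # (Lk, D), accumulated by the dK/dV kernel
    dp = d_out @ v.T                      # (L, Lk)
    delta = np.sum(d_out * out, axis=-1)  # row-wise correction term, (L,)
    ds = p * (dp - delta[:, None])        # gradient w.r.t. the scores
    dq = scale * (ds @ k)                 # (L, D), accumulated by the dQ kernel
    dk = scale * (ds.T @ q)               # (Lk, D), accumulated by the dK/dV kernel
    return dq, dk, dv
```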

Key Features

  • Float32 accumulators for half/bfloat16 precision
  • Logsumexp caching from forward pass for VJP reuse
  • Proper GQA (grouped query attention) support (see the example after this list)
  • Causal mask support
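
For the GQA case specifically, a hypothetical 4:1 configuration (the situation the K-row ownership model addresses, since several query heads contribute to each KV head's gradient) might look like this:

```python
import mlx.core as mx

# Hypothetical 4:1 GQA setup: 8 query heads share 2 KV heads
q = mx.random.normal((1, 8, 64, 64))
k = mx.random.normal((1, 2, 64, 64))
v = mx.random.normal((1, 2, 64, 64))

def loss_fn(q, k, v):
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=64**-0.5, mask="causal"
    )
    return out.sum()

# dk/dv must sum contributions from the 4 query heads mapped to each KV head;
# the dK/dV kernel does this without atomics via K-row ownership.
dq, dk, dv = mx.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)
```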

Limitations

  • Falls back to unfused attention for mask/sinks gradients (per existing design; see the example after this list)
  • Requires logsumexp from forward pass (training mode only)
  • Head dimension D=256 not supported in vector VJP (32KB threadgroup memory limit)
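
On a hedged reading of the first limitation: supplying an explicit additive mask array still yields gradients, but they may route through the decomposed (unfused) path rather than the fused VJP kernels. A hypothetical sketch, not code from this PR:

```python
import mlx.core as mx

q = mx.random.normal((1, 4, 32, 64))
k = mx.random.normal((1, 4, 32, 64))
v = mx.random.normal((1, 4, 32, 64))
# Explicit additive mask array (e.g. a padding mask); gradients w.r.t. q/k/v
# still flow, but may take the decomposed (unfused) attention fallback.
mask = mx.zeros((1, 1, 32, 32))

def loss_fn(q, k, v):
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=64**-0.5, mask=mask)
    return out.sum()

dq, dk, dv = mx.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)
```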

Test Plan

  • Existing test_sdpa_grad passes
  • New comprehensive VJP tests added:
    • test_sdpa_grad_vector_path - short sequences (L=1,4,7,8)
    • test_sdpa_grad_steel_path - longer sequences (L=16,32,128,256)
    • test_sdpa_grad_head_dims - head dimensions (D=32,64,96,128)
    • test_sdpa_grad_gqa - GQA configurations (4:1, 8:1, 16:1, MHA)
    • test_sdpa_grad_dtypes - float16, bfloat16, float32
    • test_sdpa_grad_edge_cases - L=1, non-power-of-2, large batch, qL≠kvL

All 21 SDPA tests pass (1 skipped for unrelated disabled feature).
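
Roughly, these checks compare gradients from the fused path against gradients of a decomposed reference attention. A hypothetical sketch of that pattern (not the actual code in test_fast_sdpa.py):

```python
import mlx.core as mx

def sdpa_reference(q, k, v, scale):
    # Decomposed attention used as the gradient reference
    s = (q * scale) @ mx.swapaxes(k, -1, -2)
    p = mx.softmax(s, axis=-1)
    return p @ v

def check_grads(B=2, H=4, L=32, D=64, dtype=mx.float16):
    q = mx.random.normal((B, H, L, D)).astype(dtype)
    k = mx.random.normal((B, H, L, D)).astype(dtype)
    v = mx.random.normal((B, H, L, D)).astype(dtype)
    scale = D**-0.5

    fused_loss = lambda q, k, v: mx.fast.scaled_dot_product_attention(
        q, k, v, scale=scale
    ).sum()
    ref_loss = lambda q, k, v: sdpa_reference(q, k, v, scale).sum()

    grads_fused = mx.grad(fused_loss, argnums=(0, 1, 2))(q, k, v)
    grads_ref = mx.grad(ref_loss, argnums=(0, 1, 2))(q, k, v)
    for g_f, g_r in zip(grads_fused, grads_ref):
        assert mx.allclose(g_f, g_r, atol=1e-2, rtol=1e-2)

check_grads()
```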

Copilot AI review requested due to automatic review settings January 14, 2026 03:01
@Brooooooklyn
Author

Notes: I'm working on https://github.com/mlx-node/mlx-node and trying to port some features from trl.
This pull request was generated by Claude Code. I am trying to reduce the computation and memory usage of GRPO training by making use of the full flash attention feature.


Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@Brooooooklyn Brooooooklyn marked this pull request as draft January 14, 2026 04:27
@Brooooooklyn Brooooooklyn force-pushed the flash-attn branch 5 times, most recently from 568ff36 to 26b5857 Compare January 14, 2026 09:03
@Brooooooklyn Brooooooklyn marked this pull request as ready for review January 14, 2026 09:06
@Brooooooklyn Brooooooklyn requested a review from Copilot January 14, 2026 09:12

Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@Brooooooklyn Brooooooklyn marked this pull request as draft January 14, 2026 14:17
@Brooooooklyn Brooooooklyn force-pushed the flash-attn branch 7 times, most recently from d9089ef to dd8daf1 Compare January 18, 2026 13:06
Implement fused backward pass (VJP) for scaled_dot_product_attention
on Metal GPU, enabling efficient training without falling back to
unfused attention.

- **dQ Kernel** (steel_attention_vjp_dq.h): Computes query gradients
  - Outer loop over KV blocks, inner accumulation for dQ
  - Uses log2 domain for numerical stability

- **dK/dV Kernel** (steel_attention_vjp_dkv.h): Computes key/value gradients
  - K-row ownership model eliminates atomic operations
  - Each simdgroup owns exclusive K rows to prevent races

- **Vector VJP Kernel** (sdpa_vector_vjp.h): Optimized path for short sequences (L ≤ 8)
  - Uses shared memory for efficient reduction

- Float32 accumulators for half/bfloat16 precision
- Logsumexp caching from forward pass
- Proper GQA (grouped query attention) support
- Causal mask support
- Comprehensive test coverage for all code paths

- No gradient support for mask or attention sinks (falls back to unfused)
- Requires logsumexp from forward pass (training mode only)
- Head dimension D=256 not supported in vector VJP (threadgroup memory)

Co-Authored-By: Claude <[email protected]>
@Brooooooklyn Brooooooklyn marked this pull request as ready for review January 18, 2026 13:29
@Brooooooklyn
Copy link
Author

@awni @zcbenz Do you have any interest in reviewing this PR?
