[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell #2584
Conversation
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
Greptile Summary
This PR enables deterministic FP16/BF16 attention on Blackwell GPUs (sm100+) with cuDNN >= 9.18.1 by threading a deterministic flag from the Python layer (JAX/PyTorch) through the C++ extensions into backend selection.
Key changes:
The implementation correctly handles the asymmetry where forward passes are always deterministic while backward passes support both modes.
Confidence Score: 5/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant Python as Python Layer<br/>(JAX/PyTorch)
participant CPP as C++ Extension<br/>(attention.cpp)
participant Backend as Backend Selection<br/>(fused_attn.cpp)
participant cuDNN
Note over User,cuDNN: Backward Pass (Training)
User->>Python: Set NVTE_ALLOW_NONDETERMINISTIC_ALGO
Python->>Python: Read env var & determine<br/>deterministic flag
Python->>CPP: Call fused_attn_bwd with<br/>deterministic parameter
CPP->>Backend: nvte_get_fused_attn_backend(...,<br/>deterministic)
alt Blackwell (sm100+) && training
alt deterministic=false (non-deterministic)
Backend->>Backend: Check cuDNN >= 9.7.0
Backend->>Backend: Require (dropout=0 OR bias=NO_BIAS)
Backend->>Backend: Enable arbitrary_seqlen backend
else deterministic=true
Backend->>Backend: Check cuDNN >= 9.18.1
Backend->>Backend: Require (dropout=0 AND bias=NO_BIAS)
Backend->>Backend: Enable arbitrary_seqlen backend
end
else Non-Blackwell or inference
Backend->>Backend: Standard backend selection
end
Backend-->>CPP: Return selected backend
CPP->>cuDNN: Execute attention backward
cuDNN-->>CPP: Gradients
CPP-->>Python: Return gradients
Note over User,cuDNN: Forward Pass (Always Deterministic)
Python->>CPP: Call fused_attn_fwd
CPP->>Backend: nvte_get_fused_attn_backend(...,<br/>deterministic=false)
Backend->>Backend: Forward is always deterministic<br/>(hardcoded false)
Backend-->>CPP: Return selected backend
CPP->>cuDNN: Execute attention forward
cuDNN-->>CPP: Output + aux tensors
CPP-->>Python: Return output
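For illustration, the backward-pass conditions in the diagram can be restated as the short Python sketch below. This is not the actual Transformer Engine code (the real check lives in the C++ backend-selection layer); the function name and the tuple/string encodings of the cuDNN version and bias type are hypothetical, while the conditions themselves are taken from the summary above.
# Hypothetical sketch of the Blackwell (sm100+) backward-pass rules described above.
# The real logic lives in the C++ backend selection; names here are illustrative only.
def blackwell_bwd_uses_arbitrary_seqlen(cudnn_version, deterministic, dropout_p, bias_type):
    if deterministic:
        # Deterministic backward: cuDNN >= 9.18.1, and no dropout AND no bias.
        return cudnn_version >= (9, 18, 1) and dropout_p == 0.0 and bias_type == "NO_BIAS"
    # Non-deterministic backward: cuDNN >= 9.7.0, and no dropout OR no bias.
    return cudnn_version >= (9, 7, 0) and (dropout_p == 0.0 or bias_type == "NO_BIAS")
# Example: deterministic backward on cuDNN 9.18.1 with no dropout and no bias -> True
assert blackwell_bwd_uses_arbitrary_seqlen((9, 18, 1), True, 0.0, "NO_BIAS")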
1 file reviewed, 1 comment
Greptile Summary
Overview
This PR enables determinism for FusedAttention on Blackwell GPUs (SM 100) with cuDNN version 9.18.0 or higher. The implementation moves determinism-checking logic from Python to the C++ backend selection layer.
Key Changes
Architecture
The change follows a layered approach:
The implementation correctly restricts deterministic FusedAttention to cases where cuDNN guarantees deterministic behavior, avoiding silent non-determinism.
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User as User/Test
participant PyAPI as Python API
participant Utils as utils.py
participant CppExt as C++ Extensions
participant Backend as Backend Selection
participant cuDNN as cuDNN Library
User->>PyAPI: Call attention with deterministic=True
PyAPI->>Utils: get_attention_backend(params)
Utils->>Utils: Extract deterministic from params
Utils->>CppExt: get_fused_attn_backend(..., deterministic)
CppExt->>Backend: nvte_get_fused_attn_backend(..., deterministic)
alt Blackwell (sm_arch >= 100) & Training & Deterministic
Backend->>Backend: Check cuDNN version >= 9.18.0
Backend->>Backend: Check bias_type == NO_BIAS
Backend->>Backend: Check dropout == 0.0
alt All checks pass
Backend-->>CppExt: F16_arbitrary_seqlen backend
else Any check fails
Backend-->>CppExt: No_Backend (disabled)
end
else Other architectures or inference
Backend->>Backend: Apply standard backend selection
Backend-->>CppExt: Selected backend
end
CppExt-->>Utils: Backend choice
Utils-->>PyAPI: Backend configuration
alt Forward Pass
PyAPI->>CppExt: nvte_fused_attn_fwd(..., deterministic=true)
Note over PyAPI,CppExt: Forward always uses deterministic=true
else Backward Pass
PyAPI->>CppExt: nvte_fused_attn_bwd(..., deterministic)
Note over PyAPI,CppExt: Backward respects user's deterministic flag
end
CppExt->>cuDNN: Execute attention operation
cuDNN-->>CppExt: Results
CppExt-->>PyAPI: Output tensors
PyAPI-->>User: Attention output
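As a rough sketch of the Python side of this flow (only the NVTE_ALLOW_NONDETERMINISTIC_ALGO variable appears in this PR; the helper name and the assumed default of "1" are illustrative), the frontend derives the deterministic flag from the environment:
import os

def deterministic_requested() -> bool:
    # NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 requests deterministic attention;
    # any other value (assumed default: "1") allows non-deterministic kernels.
    return os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0"
The resulting flag is what the diagram shows being passed as deterministic into get_fused_attn_backend and nvte_fused_attn_bwd, while the forward pass does not depend on it.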
2 files reviewed, 2 comments
make .xml file specific to deterministic tests in qa/ Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
1 file reviewed, 1 comment
Signed-off-by: Charlene Yang <[email protected]>
No files reviewed, no comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
1 file reviewed, 1 comment
1 file reviewed, 1 comment
fix typo Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
fix indentation Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
1 file reviewed, 1 comment
1 file reviewed, 1 comment
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
3 files reviewed, 3 comments
2 files reviewed, 2 comments
3 files reviewed, 3 comments
/te-ci jax L0
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L0
/te-ci L1
Signed-off-by: Charlene Yang <[email protected]>
/te-ci L1
for more information, see https://pre-commit.ci
13 files reviewed, 3 comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
fix and/or logic Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
/te-ci L1
Cool, we are currently suffering from this issue.
KshitijLakhani
left a comment
Left a few comments - some suggested changes and some questions.
Looks good to me otherwise. Approving so as not to block the merge, if this is urgent.
It would be helpful to have a table in the PR description of what's supported for cuDNN < 9.18, cuDNN >= 9.18, < sm100, sm100+, dropout, dbias, etc.
I would also suggest looking into the number of tests being run and the timing (you can compare your PR's L0 jax and L0 pyt timings to the timings in TE 2.11 or in TE main CI) - we would not want to go overboard with our timing budget, for sure. If you can report the timing in the PR, that would be helpful as well.
Worst case, if urgent, we can merge this PR and address the QA bit (which runs in the CI) in a separate PR subsequently.
Lastly, this might take some effort but would ensure correctness: since the code for skipping tests in the TE JAX tests has been modified, it would be good to compare the test count before and after this PR to check whether tests that should not be skipped are incorrectly being skipped.
qa/L0_jax_unittest/test.sh
Outdated
mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py"
It seems like this will first run the non-deterministic fused attn tests as part of L31, which runs all non-distributed tests, followed by the deterministic fused attn tests as part of L32.
Is that the intention - to run fused attn 2x, with and without determinism?
That will greatly increase our test time and might be unnecessary. The last pipeline launched was for L1, so I am unsure whether I can track the effect this change will have on timing, as this is an L0 change. Could you report that in the PR, please?
Thanks!
Maybe we could come up with an approach that runs half the fused attn tests deterministically and the other half non-deterministically?
Or run them all deterministically only?
Yes, this extra line tests test_fused_attn.py with determinism, while the line before tests everything with non-determinism. The extra test_fused_attn.py run takes ~20 minutes on Blackwell:
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_backward | 5040x | 1336.28s | avg: 0.27s
================================================================================
TOTAL RUNTIME | | 1336.28s |
================================================================================
Now with cd5bcf3, the extra determinism tests should really take no time at all (there are only 20 tests added).
  float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
  size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
- int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+ int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?
I felt there was a bit of a distinction when I was implementing it: is_training is a description of the state we are in, while deterministic is more of a request from the user (that they want to run in deterministic mode). Not a lot of difference, to be honest - just the feel of the words. I kind of did this when I introduced deterministic as a parameter in AttentionParams, so I just followed along with it in this PR. Any strong objections?
/te-ci L0 L1
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L1
Pipeline 42017245 for CI with updated cuDNN.
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L1
Pipeline 42067766 for 9.18.1 tests.
Description
This PR enables determinism for FP16/BF16 attention on Blackwell. It requires cuDNN >= 9.18.1.
To run in deterministic mode, please set export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.
Support matrix for FP16/BF16 on Blackwell:
Type of change
Changes
Please see Description.
Checklist: