⚠ This page is served via a proxy. Original site: https://github.com
This service does not collect credentials or authentication data.
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor
Expand Down Expand Up @@ -68,7 +67,7 @@ class TestTimeAugmentation:
Args:
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
. All random transforms must be of type `InvertibleTransform`.
All random transforms must be of type `InvertibleTransform`.
batch_size: number of realizations to infer at once.
Comment on lines 67 to 71
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Docstring still claims all random transforms must be invertible.

With apply_inverse_to_pred=False, non‑invertible random transforms are allowed. Update the docstring (and consider widening the transform type hint) to match behavior. As per coding guidelines, keep docstrings aligned with behavior.

✅ Suggested docstring fix
-        All random transforms must be of type `InvertibleTransform`.
+        When `apply_inverse_to_pred` is True, all random transforms must be of type `InvertibleTransform`.

Also applies to: 115-118

🤖 Prompt for AI Agents
In `@monai/data/test_time_augmentation.py` around lines 67 - 71, The docstring for
test-time augmentation (function/class using parameter names transform,
batch_size, and apply_inverse_to_pred) incorrectly states "All random transforms
must be of type InvertibleTransform"; update the docstring (and the transform
type hint if present) to reflect that non-invertible random transforms are
allowed when apply_inverse_to_pred=False and only need to be invertible when
apply_inverse_to_pred=True; change the wording in both occurrences (the block
around the transform description and the later paragraph at lines ~115-118) to
describe this conditional requirement and, if applicable, broaden the transform
type hint to accept non-invertible Randomizable types when apply_inverse_to_pred
is False.

num_workers: how many subprocesses to use for data.
inferrer_fn: function to use to perform inference.
Expand All @@ -92,6 +91,11 @@ class TestTimeAugmentation:
will return the full data. Dimensions will be same size as when passing a single image through
`inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
progress: whether to display a progress bar.
apply_inverse_to_pred: whether to apply inverse transformations to the predictions.
If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions
back to the original spatial reference.
If the prediction is non-spatial (e.g. classification label or score), this should be `False` to
aggregate the raw predictions directly. Defaults to `True`.

Example:
.. code-block:: python
Expand Down Expand Up @@ -125,6 +129,7 @@ def __init__(
post_func: Callable = _identity,
return_full_data: bool = False,
progress: bool = True,
apply_inverse_to_pred: bool = True,
) -> None:
self.transform = transform
self.batch_size = batch_size
Expand All @@ -134,6 +139,7 @@ def __init__(
self.image_key = image_key
self.return_full_data = return_full_data
self.progress = progress
self.apply_inverse_to_pred = apply_inverse_to_pred
self._pred_key = CommonKeys.PRED
self.inverter = Invertd(
keys=self._pred_key,
Expand All @@ -152,20 +158,21 @@ def __init__(

def _check_transforms(self):
"""Should be at least 1 random transform, and all random transforms should be invertible."""
ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
randoms = np.array([isinstance(t, Randomizable) for t in ts])
invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts])
# check at least 1 random
if sum(randoms) == 0:
warnings.warn(
"TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms."
)
# check that whenever randoms is True, invertibles is also true
for r, i in zip(randoms, invertibles):
if r and not i:
warnings.warn(
f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}"
)
transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
warns = []
randoms = []

for idx, t in enumerate(transforms):
if isinstance(t, Randomizable):
randoms.append(t)
if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
warns.append(f"Transform #{idx} (type {type(t).__name__}) is random but not invertible.")

if len(randoms) == 0:
warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")

if len(warns) > 0:
warnings.warn("TTA has encountered issues with the given transforms:" + "\n ".join(warns))
Comment on lines +174 to +175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Improve warning formatting and stacklevel.

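The current warning appends the first issue directly after the colon with no newline, and it omits `stacklevel`, so the message is hard to read and is attributed to this module rather than to the caller.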

🛠️ Proposed fix
-        warnings.warn("TTA has encountered issues with the given transforms:" + "\n  ".join(warns))
+        warnings.warn(
+            "TTA has encountered issues with the given transforms:\n  " + "\n  ".join(warns),
+            stacklevel=2,
+        )
🧰 Tools
🪛 Ruff (0.14.13)

175-175: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🤖 Prompt for AI Agents
In `@monai/data/test_time_augmentation.py` around lines 174 - 175, The warning
message built from the local variable warns is missing a newline after the colon
and does not set a stacklevel, so update the warnings.warn call to start the list
of issues on its own line (e.g., "TTA has encountered issues with the given transforms:\n  "
+ "\n  ".join(warns)) and pass an appropriate stacklevel (e.g., stacklevel=2) so
the warning is attributed to the caller rather than to this module; locate and
modify the warnings.warn(...) invocation that uses the warns list in
test_time_augmentation.py.


def __call__(
self, data: dict[str, Any], num_examples: int = 10
Expand Down Expand Up @@ -199,7 +206,10 @@ def __call__(
for b in tqdm(dl) if has_tqdm and self.progress else dl:
# do model forward pass
b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
if self.apply_inverse_to_pred:
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
else:
outs.extend([i[self._pred_key] for i in decollate_batch(b)])

Comment on lines 208 to 213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Step 1: Examine Invertd class and its post-processing behavior
rg -n "class Invertd" monai/transforms/post/dictionary.py -A 30

Repository: Project-MONAI/MONAI

Length of output: 1743


🏁 Script executed:

#!/bin/bash
# Step 2: Check the test_time_augmentation.py file to understand context
head -250 monai/data/test_time_augmentation.py | tail -70

Repository: Project-MONAI/MONAI

Length of output: 3385


🏁 Script executed:

#!/bin/bash
# Step 3: Search for post_func and output_device in Invertd
rg -n "post_func|output_device|to_tensor" monai/transforms/post/dictionary.py | head -30

Repository: Project-MONAI/MONAI

Length of output: 1179


🏁 Script executed:

#!/bin/bash
# Search for tests covering apply_inverse_to_pred behavior
rg -n "apply_inverse_to_pred" --type py -B 2 -A 2

Repository: Project-MONAI/MONAI

Length of output: 3372


🏁 Script executed:

#!/bin/bash
# Check how inverter is initialized - what post_func/to_tensor/output_device values are used
rg -n "self.inverter|Invertd" monai/data/test_time_augmentation.py -B 2 -A 2

Repository: Project-MONAI/MONAI

Length of output: 961


🏁 Script executed:

#!/bin/bash
# Look at the full __init__ of TestTimeAugmentation to understand inverter setup
sed -n '1,207p' monai/data/test_time_augmentation.py | tail -100

Repository: Project-MONAI/MONAI

Length of output: 4408


Non-inverse branch ignores to_tensor, output_device, and post_func parameters.

When apply_inverse_to_pred=False, the code skips Invertd entirely (line 212), so these configured post-processing steps are never applied. Users may expect them to work regardless of inversion. Either apply post-processing in both branches, or add a post-processing-only path to Invertd when inversion is not needed.

🤖 Prompt for AI Agents
In `@monai/data/test_time_augmentation.py` around lines 208 - 213, The non-inverse
branch currently returns raw predictions and skips all Invertd post-processing
(to_tensor, output_device, post_func); update the branch so decollated items
still go through the same inverter pipeline (or a post-processing-only path)
before extracting self._pred_key. Concretely, in the else branch replace
outs.extend([i[self._pred_key] for i in decollate_batch(b)]) with code that
calls self.inverter on each PadListDataCollate.inverse(i) (or calls an Invertd
method/flag that runs only to_tensor/output_device/post_func but not spatial
inverse) and then extracts [self._pred_key]; ensure the call honors to_tensor,
output_device and post_func parameters so behavior matches the
apply_inverse_to_pred=True path.

output: NdarrayOrTensor = stack(outs, 0)

Expand Down
39 changes: 38 additions & 1 deletion tests/integration/test_testtimeaugmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_test_time_augmentation(self):
# output might be different size, so pad so that they match
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device)
loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

Expand Down Expand Up @@ -181,6 +181,43 @@ def test_image_no_label(self):
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
tta(self.get_data(1, (20, 20), include_label=False))

def test_non_spatial_output(self):
"""
Test TTA for non-spatial output (e.g., classification scores).
Verifies that setting `apply_inverse_to_pred=False` correctly aggregates
predictions without attempting spatial inversion.
"""
input_size = (20, 20)
data = {"image": np.random.rand(1, *input_size).astype(np.float32)}

transforms = Compose(
[EnsureChannelFirstd("image", channel_dim="no_channel"), RandFlipd("image", prob=1.0, spatial_axis=0)]
)

def mock_classifier(x):
batch_size = x.shape[0]
return torch.tensor([[0.2, 0.8]] * batch_size, dtype=torch.float32, device=x.device)

tt_aug = TestTimeAugmentation(
transform=transforms,
batch_size=2,
num_workers=0,
inferrer_fn=mock_classifier,
device="cpu",
orig_key="image",
apply_inverse_to_pred=False,
return_full_data=False,
)
mode, mean, std, vvc = tt_aug(data, num_examples=4)

self.assertEqual(mean.shape, (2,))
np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6)
np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6)

tt_aug.return_full_data = True
full_output = tt_aug(data, num_examples=4)
self.assertEqual(full_output.shape, (4, 2))


if __name__ == "__main__":
unittest.main()
Loading