-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Generalize TestTimeAugmentation to non-spatial predictions #8715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,6 @@ | |
| from copy import deepcopy | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.config.type_definitions import NdarrayOrTensor | ||
|
|
@@ -68,7 +67,7 @@ class TestTimeAugmentation: | |
| Args: | ||
| transform: transform (or composed) to be applied to each realization. At least one transform must be of type | ||
| `RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`). | ||
| . All random transforms must be of type `InvertibleTransform`. | ||
| All random transforms must be of type `InvertibleTransform`. | ||
| batch_size: number of realizations to infer at once. | ||
| num_workers: how many subprocesses to use for data. | ||
| inferrer_fn: function to use to perform inference. | ||
|
|
@@ -92,6 +91,11 @@ class TestTimeAugmentation: | |
| will return the full data. Dimensions will be same size as when passing a single image through | ||
| `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. | ||
| progress: whether to display a progress bar. | ||
| apply_inverse_to_pred: whether to apply inverse transformations to the predictions. | ||
| If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions | ||
| back to the original spatial reference. | ||
| If the prediction is non-spatial (e.g. classification label or score), this should be `False` to | ||
| aggregate the raw predictions directly. Defaults to `True`. | ||
|
|
||
| Example: | ||
| .. code-block:: python | ||
|
|
@@ -125,6 +129,7 @@ def __init__( | |
| post_func: Callable = _identity, | ||
| return_full_data: bool = False, | ||
| progress: bool = True, | ||
| apply_inverse_to_pred: bool = True, | ||
| ) -> None: | ||
| self.transform = transform | ||
| self.batch_size = batch_size | ||
|
|
@@ -134,6 +139,7 @@ def __init__( | |
| self.image_key = image_key | ||
| self.return_full_data = return_full_data | ||
| self.progress = progress | ||
| self.apply_inverse_to_pred = apply_inverse_to_pred | ||
| self._pred_key = CommonKeys.PRED | ||
| self.inverter = Invertd( | ||
| keys=self._pred_key, | ||
|
|
@@ -152,20 +158,21 @@ def __init__( | |
|
|
||
| def _check_transforms(self): | ||
| """Should be at least 1 random transform, and all random transforms should be invertible.""" | ||
| ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms | ||
| randoms = np.array([isinstance(t, Randomizable) for t in ts]) | ||
| invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) | ||
| # check at least 1 random | ||
| if sum(randoms) == 0: | ||
| warnings.warn( | ||
| "TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms." | ||
| ) | ||
| # check that whenever randoms is True, invertibles is also true | ||
| for r, i in zip(randoms, invertibles): | ||
| if r and not i: | ||
| warnings.warn( | ||
| f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}" | ||
| ) | ||
| transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms | ||
| warns = [] | ||
| randoms = [] | ||
|
|
||
| for idx, t in enumerate(transforms): | ||
| if isinstance(t, Randomizable): | ||
| randoms.append(t) | ||
| if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform): | ||
| warns.append(f"Transform #{idx} (type {type(t).__name__}) is random but not invertible.") | ||
|
|
||
| if len(randoms) == 0: | ||
| warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.") | ||
|
|
||
| if len(warns) > 0: | ||
| warnings.warn("TTA has encountered issues with the given transforms:" + "\n ".join(warns)) | ||
|
Comment on lines
+174
to
+175
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Improve warning formatting and stacklevel. Current warning concatenates without a newline and lacks a 🛠️ Proposed fix- warnings.warn("TTA has encountered issues with the given transforms:" + "\n ".join(warns))
+ warnings.warn(
+ "TTA has encountered issues with the given transforms:\n " + "\n ".join(warns),
+ stacklevel=2,
+ )🧰 Tools🪛 Ruff (0.14.13)175-175: No explicit Set (B028) 🤖 Prompt for AI Agents |
||
|
|
||
| def __call__( | ||
| self, data: dict[str, Any], num_examples: int = 10 | ||
|
|
@@ -199,7 +206,10 @@ def __call__( | |
| for b in tqdm(dl) if has_tqdm and self.progress else dl: | ||
| # do model forward pass | ||
| b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device)) | ||
| outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)]) | ||
| if self.apply_inverse_to_pred: | ||
| outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)]) | ||
| else: | ||
| outs.extend([i[self._pred_key] for i in decollate_batch(b)]) | ||
|
|
||
|
Comment on lines
208
to
213
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Step 1: Examine Invertd class and its post-processing behavior
rg -n "class Invertd" monai/transforms/post/dictionary.py -A 30Repository: Project-MONAI/MONAI Length of output: 1743 🏁 Script executed: #!/bin/bash
# Step 2: Check the test_time_augmentation.py file to understand context
head -250 monai/data/test_time_augmentation.py | tail -70Repository: Project-MONAI/MONAI Length of output: 3385 🏁 Script executed: #!/bin/bash
# Step 3: Search for post_func and output_device in Invertd
rg -n "post_func|output_device|to_tensor" monai/transforms/post/dictionary.py | head -30Repository: Project-MONAI/MONAI Length of output: 1179 🏁 Script executed: #!/bin/bash
# Search for tests covering apply_inverse_to_pred behavior
rg -n "apply_inverse_to_pred" --type py -B 2 -A 2Repository: Project-MONAI/MONAI Length of output: 3372 🏁 Script executed: #!/bin/bash
# Check how inverter is initialized - what post_func/to_tensor/output_device values are used
rg -n "self.inverter|Invertd" monai/data/test_time_augmentation.py -B 2 -A 2Repository: Project-MONAI/MONAI Length of output: 961 🏁 Script executed: #!/bin/bash
# Look at the full __init__ of TestTimeAugmentation to understand inverter setup
sed -n '1,207p' monai/data/test_time_augmentation.py | tail -100Repository: Project-MONAI/MONAI Length of output: 4408 Non-inverse branch ignores When 🤖 Prompt for AI Agents |
||
| output: NdarrayOrTensor = stack(outs, 0) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docstring still claims all random transforms must be invertible.
With
apply_inverse_to_pred=False, non‑invertible random transforms are allowed. Update the docstring (and consider widening thetransformtype hint) to match behavior. As per coding guidelines, keep docstrings aligned with behavior.✅ Suggested docstring fix
Also applies to: 115-118
🤖 Prompt for AI Agents