diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py
index bcd5ea91a9..33e443a127 100644
--- a/monai/data/test_time_augmentation.py
+++ b/monai/data/test_time_augmentation.py
@@ -16,7 +16,6 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any
 
-import numpy as np
 import torch
 
 from monai.config.type_definitions import NdarrayOrTensor
@@ -68,7 +67,7 @@ class TestTimeAugmentation:
     Args:
         transform: transform (or composed) to be applied to each realization. At least one transform must be of type
         `RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
-            . All random transforms must be of type `InvertibleTransform`.
+            All random transforms must be of type `InvertibleTransform`.
         batch_size: number of realizations to infer at once.
         num_workers: how many subprocesses to use for data.
         inferrer_fn: function to use to perform inference.
@@ -92,6 +91,11 @@ class TestTimeAugmentation:
             will return the full data. Dimensions will be same size as when passing a single image through
             `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
         progress: whether to display a progress bar.
+        apply_inverse_to_pred: whether to apply inverse transformations to the predictions.
+            If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions
+            back to the original spatial reference.
+            If the prediction is non-spatial (e.g. classification label or score), this should be `False` to
+            aggregate the raw predictions directly. Defaults to `True`.
 
     Example:
         .. code-block:: python
@@ -125,6 +129,7 @@ def __init__(
         post_func: Callable = _identity,
         return_full_data: bool = False,
         progress: bool = True,
+        apply_inverse_to_pred: bool = True,
     ) -> None:
         self.transform = transform
         self.batch_size = batch_size
@@ -134,6 +139,7 @@ def __init__(
         self.image_key = image_key
         self.return_full_data = return_full_data
         self.progress = progress
+        self.apply_inverse_to_pred = apply_inverse_to_pred
         self._pred_key = CommonKeys.PRED
         self.inverter = Invertd(
             keys=self._pred_key,
@@ -152,20 +158,28 @@ def __init__(
 
     def _check_transforms(self):
-        """Should be at least 1 random transform, and all random transforms should be invertible."""
-        ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
-        randoms = np.array([isinstance(t, Randomizable) for t in ts])
-        invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts])
-        # check at least 1 random
-        if sum(randoms) == 0:
-            warnings.warn(
-                "TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms."
-            )
-        # check that whenever randoms is True, invertibles is also true
-        for r, i in zip(randoms, invertibles):
-            if r and not i:
-                warnings.warn(
-                    f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}"
-                )
+        """
+        Warn if the given transform(s) look unsuitable for test time augmentation.
+
+        An issue is recorded when no transform is `Randomizable`, and, only if `apply_inverse_to_pred` is `True`,
+        for each `Randomizable` transform that is not an `InvertibleTransform`. All issues go into one warning.
+        """
+        transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
+        warns = []
+        randoms = []
+
+        for idx, t in enumerate(transforms):
+            if isinstance(t, Randomizable):
+                randoms.append(t)
+                # invertibility only matters when predictions are mapped back through the inverse transforms
+                if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
+                    warns.append(f"Transform #{idx} (type {type(t).__name__}) is random but not invertible.")
+
+        if not randoms:
+            warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")
+
+        if warns:
+            # separate the header from the first issue so that every issue starts on its own line
+            warnings.warn("TTA has encountered issues with the given transforms:\n " + "\n ".join(warns))
 
     def __call__(
         self, data: dict[str, Any], num_examples: int = 10
@@ -199,7 +213,12 @@ def __call__(
         for b in tqdm(dl) if has_tqdm and self.progress else dl:
             # do model forward pass
             b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
-            outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
+            if self.apply_inverse_to_pred:
+                # spatial predictions: apply the inverse transforms to each realization before aggregating
+                outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
+            else:
+                # non-spatial predictions: aggregate the raw model outputs directly
+                outs.extend([i[self._pred_key] for i in decollate_batch(b)])
 
         output: NdarrayOrTensor = stack(outs, 0)
 
diff --git a/tests/integration/test_testtimeaugmentation.py b/tests/integration/test_testtimeaugmentation.py
index 62e4b46282..84da7c9c15 100644
--- a/tests/integration/test_testtimeaugmentation.py
+++ b/tests/integration/test_testtimeaugmentation.py
@@ -104,7 +104,7 @@ def test_test_time_augmentation(self):
         # output might be different size, so pad so that they match
         train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)
 
-        model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
+        model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device)
         loss_function = DiceLoss(sigmoid=True)
         optimizer = torch.optim.Adam(model.parameters(), 1e-3)
 
@@ -181,6 +181,43 @@ def test_image_no_label(self):
         tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
         tta(self.get_data(1, (20, 20), include_label=False))
 
+    def test_non_spatial_output(self):
+        """
+        Test TTA for non-spatial output (e.g., classification scores).
+        Verifies that setting `apply_inverse_to_pred=False` correctly aggregates
+        predictions without attempting spatial inversion.
+        """
+        input_size = (20, 20)
+        data = {"image": np.random.rand(1, *input_size).astype(np.float32)}
+
+        transforms = Compose(
+            [EnsureChannelFirstd("image", channel_dim="no_channel"), RandFlipd("image", prob=1.0, spatial_axis=0)]
+        )
+
+        def mock_classifier(x):
+            batch_size = x.shape[0]
+            return torch.tensor([[0.2, 0.8]] * batch_size, dtype=torch.float32, device=x.device)
+
+        tt_aug = TestTimeAugmentation(
+            transform=transforms,
+            batch_size=2,
+            num_workers=0,
+            inferrer_fn=mock_classifier,
+            device="cpu",
+            orig_key="image",
+            apply_inverse_to_pred=False,
+            return_full_data=False,
+        )
+        mode, mean, std, vvc = tt_aug(data, num_examples=4)
+
+        self.assertEqual(mean.shape, (2,))
+        np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6)
+        np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6)
+
+        tt_aug.return_full_data = True
+        full_output = tt_aug(data, num_examples=4)
+        self.assertEqual(full_output.shape, (4, 2))
+
 
 if __name__ == "__main__":
     unittest.main()