diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 3c79f5c41..6c92cb6e4 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -8,6 +8,9 @@ from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients from pyhealth.interpret.methods.shap import ShapExplainer from pyhealth.interpret.methods.lime import LimeExplainer +from pyhealth.interpret.methods.ensemble_crh import CrhEnsemble +from pyhealth.interpret.methods.ensemble_avg import AvgEnsemble +from pyhealth.interpret.methods.ensemble_var import VarEnsemble __all__ = [ "BaseInterpreter", @@ -19,5 +22,8 @@ "BasicGradientSaliencyMaps", "RandomBaseline", "ShapExplainer", - "LimeExplainer" + "LimeExplainer", + "CrhEnsemble", + "AvgEnsemble", + "VarEnsemble" ] diff --git a/pyhealth/interpret/methods/base_ensemble.py b/pyhealth/interpret/methods/base_ensemble.py new file mode 100644 index 000000000..691b1f28a --- /dev/null +++ b/pyhealth/interpret/methods/base_ensemble.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from typing import Dict, Optional + +import torch + +from pyhealth.models import BaseModel +from .base_interpreter import BaseInterpreter + + +class BaseInterpreterEnsemble(BaseInterpreter): + """Abstract base class for ensemble interpreters. + + Provides the shared workflow for ensemble-based attribution: + + 1. Each expert interpreter independently computes attributions. + 2. The per-expert attribution maps are flattened, then normalized to + a common [0, 1] scale via competitive ranking. + 3. The normalized attributions are passed to :meth:`_ensemble`, which + concrete subclasses must override to implement a specific + aggregation strategy (e.g., CRH truth discovery, simple averaging, + majority voting). + 4. The aggregated result is unflattened back to the original tensor + shapes. + + Subclasses only need to implement :meth:`_ensemble`. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + ): + super().__init__(model) + assert len(experts) >= 3, "Ensemble must contain at least three interpreters for majority voting" + self.experts = experts + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def attribute( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...], + ) -> Dict[str, torch.Tensor]: + """Compute consensus attributions by ensembling all expert interpreters. + + Each expert's ``attribute`` method is called with the same inputs. + The resulting attribution maps are flattened, competitively ranked + to a common [0, 1] scale, and aggregated via the subclass-defined + :meth:`_ensemble` strategy. + + Args: + **kwargs: Input data dictionary from a dataloader batch. + Should contain feature tensors (or tuples of tensors) + keyed by the model's feature keys, plus optional label + or metadata tensors (which are forwarded to experts). + + Returns: + Dictionary mapping each feature key to a consensus attribution + tensor whose shape matches the corresponding input tensor. 
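+
+        Example (sketch only; ``model``, ``experts`` and ``batch`` are
+        assumed to be defined elsewhere, ``AvgEnsemble`` is one concrete
+        subclass, and ``"conditions"`` is a hypothetical feature key that
+        maps to a plain tensor):
+
+            >>> ensemble = AvgEnsemble(model, experts)
+            >>> attrs = ensemble.attribute(**batch)
+            >>> attrs["conditions"].shape == batch["conditions"].shape
+            True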
+ """ + out_shape: dict[str, torch.Size] | None = None + attr_lst: list[torch.Tensor] = [] + for expert in self.experts: + attr = expert.attribute(**kwargs) + + # record the output shape from the first interpreter, + # since all interpreters should produce the same shape + if out_shape is None: + out_shape = {k: v.shape for k, v in attr.items()} + + flat_attr = self._flatten_attributions(attr) # shape (B, M) + attr_lst.append(flat_attr) + + # Combine the flattened attributions from all interpreters + attributions = torch.stack(attr_lst, dim=1) # shape (B, I, M) + # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) + attributions = self._competitive_ranking_normalize(attributions) # shape (B, I, M) + + # Resolve conflicts and aggregate across interpreters using CRH + consensus = self._ensemble(attributions) # shape (B, M) + assert out_shape is not None, "Output shape should have been determined from the first interpreter" + return self._unflatten_attributions(consensus, out_shape) # dict of tensors with original shapes + + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate normalized expert attributions into a single consensus. + + Subclasses must override this method to define the aggregation + strategy (e.g., iterative truth discovery, simple averaging). + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Aggregated tensor of shape ``(B, M)`` with values in [0, 1]. + """ + raise NotImplementedError("Subclasses must implement their ensemble aggregation strategy in this method") + + # ------------------------------------------------------------------ + # Private helper methods + # ------------------------------------------------------------------ + @staticmethod + def _flatten_attributions( + values: dict[str, torch.Tensor], + ) -> torch.Tensor: + """Flatten values dictionary to a single tensor. + + Takes a dictionary of tensors with shape (B, *) and flattens each to (B, M_i), + then concatenates them along the feature dimension to get (B, M). + + Args: + values: Dictionary mapping feature keys to tensors of shape (B, *). + + Returns: + Flattened tensor of shape (B, M) where M is the sum of all flattened dimensions. + """ + flattened_list = [] + for key in sorted(values.keys()): # Sort for consistency + tensor = values[key] + batch_size = tensor.shape[0] + # Flatten all dimensions except batch + flattened = tensor.reshape(batch_size, -1) + flattened_list.append(flattened) + + # Concatenate along feature dimension + return torch.cat(flattened_list, dim=1) + + @staticmethod + def _unflatten_attributions( + flattened: torch.Tensor, + shapes: dict[str, torch.Size], + ) -> dict[str, torch.Tensor]: + """Unflatten tensor back to values dictionary. + + Takes a flattened tensor of shape (B, M) and original shapes, + and reconstructs the original dictionary of tensors. + + Args: + flattened: Flattened tensor of shape (B, M). + shapes: Dictionary mapping feature keys to original tensor shapes. + + Returns: + Dictionary mapping feature keys to tensors with original shapes. 
+ """ + values = {} + offset = 0 + + for key in sorted(shapes.keys()): # Must match the order in _flatten_values + shape = shapes[key] + batch_size = shape[0] + + # Calculate the size of the flattened feature + feature_size = 1 + for s in shape[1:]: + feature_size *= s + + # Extract the relevant portion and reshape + values[key] = flattened[:, offset : offset + feature_size].reshape(shape) + offset += feature_size + + return values + + + @staticmethod + def _competitive_ranking_normalize(x: torch.Tensor) -> torch.Tensor: + """Normalize a tensor via competitive (standard competition) ranking. + + For each (batch, expert) slice, items are ranked ascendingly from + 0 to ``total_item - 1``. Tied scores receive the same rank — the + smallest position index among the tied group (standard competition / + "1224" ranking). The ranks are then divided by ``total_item - 1`` + so that the output lies in [0, 1]. + + Args: + x: Tensor of shape ``(B, I, M)`` + containing unbounded floating-point scores. + + Returns: + Tensor of the same shape with values in [0, 1]. + """ + batch_size, num_experts, num_items = x.shape + + if num_items <= 1: + # With a single item the rank is 0 and 0/0 is undefined; + # return zeros as a safe default. + return torch.zeros_like(x) + + # 1. Sort ascending along the item dimension + sorted_vals, sort_indices = x.sort(dim=-1) + + # 2. Build a mask that is True at positions where the value changes + # from the previous position (i.e. the start of a new rank group). + change_mask = torch.ones(batch_size, num_experts, num_items, dtype=torch.bool, device=x.device) + change_mask[..., 1:] = sorted_vals[..., 1:] != sorted_vals[..., :-1] + + # 3. Assign competitive ranks in sorted order. + # At change positions the rank equals the position index; + # at tie positions we propagate the rank of the first occurrence + # via cummax (all non-change positions are set to -1 so cummax + # naturally carries forward the last "real" rank). + positions = torch.arange(num_items, device=x.device, dtype=torch.long).expand(batch_size, num_experts, num_items) + ranks_sorted = torch.where( + change_mask, + positions, + torch.full_like(positions, -1), + ) + ranks_sorted, _ = ranks_sorted.cummax(dim=-1) + + # 4. Scatter the ranks back to the original (unsorted) order + ranks = torch.zeros_like(x) + ranks.scatter_(-1, sort_indices, ranks_sorted.to(x.dtype)) + + # 5. Normalize to [0, 1] + return ranks / (num_items - 1) diff --git a/pyhealth/interpret/methods/ensemble_avg.py b/pyhealth/interpret/methods/ensemble_avg.py new file mode 100644 index 000000000..7f66f14eb --- /dev/null +++ b/pyhealth/interpret/methods/ensemble_avg.py @@ -0,0 +1,77 @@ +"""Average ensemble interpreter. + +This module implements the AGGMean ensemble strategy, which aggregates +attributions from multiple interpretability experts by taking the uniform +average of their competitively-ranked importance scores. +""" + +from __future__ import annotations + +import torch + +from pyhealth.models import BaseModel +from .base_ensemble import BaseInterpreterEnsemble +from .base_interpreter import BaseInterpreter + + +class AvgEnsemble(BaseInterpreterEnsemble): + """Ensemble interpreter using uniform averaging (AGGMean / Borda). + + Computes the consensus attribution as the simple arithmetic mean + of the competitively-ranked attributions from all expert interpreters. + This is the simplest ensemble strategy — every expert contributes + equally regardless of its agreement with the others. 
+ + Because the inputs are already competitively ranked, averaging is + equivalent (up to a constant factor) to the Borda count, which sums + the ranks instead. The two methods therefore produce identical + feature orderings. + + Implements the AGGMean method from: + + Rieger, L. and Hansen, L. K. "Aggregating Explanation Methods + for Stable and Robust Explainability." arXiv preprint + arXiv:1903.00519, 2019. + + See also the Borda aggregation in: + + Chen, Y., Mancini, M., Zhu, X., and Akata, Z. "Ensemble + Interpretation: A Unified Method for Interpretable Machine + Learning." arXiv preprint arXiv:2312.06255, 2023. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. + + Example: + >>> from pyhealth.interpret.methods import IntegratedGradients, DeepLift, LimeExplainer + >>> experts = [IntegratedGradients(model), DeepLift(model), LimeExplainer(model)] + >>> ensemble = AvgEnsemble(model, experts) + >>> attrs = ensemble.attribute(**batch) + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + ): + super().__init__(model, experts) + + # ------------------------------------------------------------------ + # Ensemble implementation + # ------------------------------------------------------------------ + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate expert attributions by uniform averaging. + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Consensus tensor of shape ``(B, M)`` with values in [0, 1]. + """ + return torch.mean(attributions, dim=1) \ No newline at end of file diff --git a/pyhealth/interpret/methods/ensemble_crh.py b/pyhealth/interpret/methods/ensemble_crh.py new file mode 100644 index 000000000..2a901eb9b --- /dev/null +++ b/pyhealth/interpret/methods/ensemble_crh.py @@ -0,0 +1,134 @@ +"""CRH ensemble interpreter. + +This module implements the Conflict Resolution on Heterogeneous data (CRH) +ensemble strategy for aggregating attributions from multiple interpretability +experts into a single consensus attribution. +""" + +from __future__ import annotations + +import torch + +from pyhealth.models import BaseModel +from .base_ensemble import BaseInterpreterEnsemble +from .base_interpreter import BaseInterpreter + + +class CrhEnsemble(BaseInterpreterEnsemble): + """Ensemble interpreter using Conflict Resolution on Heterogeneous data (CRH). + + Iteratively estimates a consensus attribution by reweighting experts + according to their agreement with the current consensus estimate. + Experts whose attributions are closer to the consensus receive higher + weights, which in turn pulls the consensus toward more reliable + experts. + + This implements the truth-discovery algorithm from: + + Li, Q., Li, Y., Gao, J., Zhao, B., Fan, W., and Han, J. + "Resolving Conflicts in Heterogeneous Data by Truth Discovery + and Source Reliability Estimation." In *Proceedings of the 2014 + ACM SIGMOD International Conference on Management of Data* + (SIGMOD'14), pp. 1187–1198, 2014. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. 
+ n_iter: Maximum number of CRH refinement iterations. Higher values + allow more precise convergence at the cost of computation. + low_confidence_threshold: If set, batches where the standard + deviation of the final expert weights falls below this value + are considered low-confidence, and their consensus is replaced + by a simple uniform average of all experts. + early_stopping_threshold: If set, the CRH loop terminates early + when the maximum absolute change in the consensus vector + between successive iterations is below this value. + + Example: + >>> from pyhealth.interpret.methods import GradientShap, IntegratedGradients, Saliency + >>> experts = [GradientShap(model), IntegratedGradients(model), Saliency(model)] + >>> ensemble = CrhEnsemble(model, experts, n_iter=30) + >>> attrs = ensemble.attribute(**batch) + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + n_iter: int = 20, + low_confidence_threshold: float | None = None, + early_stopping_threshold: float | None = None, + ): + super().__init__(model, experts) + self.n_iter = n_iter + self.low_confidence_threshold = low_confidence_threshold + self.early_stopping_threshold = early_stopping_threshold + + # ------------------------------------------------------------------ + # Ensemble implementation + # ------------------------------------------------------------------ + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate expert attributions via the CRH truth-discovery algorithm. + + The consensus is initialised as the median across experts and then + iteratively refined: on each iteration, expert weights are set + inversely proportional to their mean squared error against the + current consensus, and the consensus is updated as the weighted + average of all experts. + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Consensus tensor of shape ``(B, M)`` with values in [0, 1]. 
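+
+        Example (toy numbers for a single batch with two features and three
+        experts; purely illustrative):
+
+            Experts report ``[0.0, 1.0]``, ``[0.5, 1.0]`` and ``[1.0, 0.0]``.
+            The per-feature median initialises ``t = [0.5, 1.0]``. The mean
+            squared errors against ``t`` are 0.125, 0.0 and 0.625, so the
+            second expert receives by far the largest weight and the updated
+            consensus stays close to ``[0.5, 1.0]``.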
+ """ + # Step 1: Initialize truth as median across experts (B, M) + t = torch.median(attributions, dim=1).values # (B, M) + + # Iterative refinement + eps = 1e-6 + + for _ in range(self.n_iter): + t_old = t.clone() + + # Step 2: Compute expert reliability per batch + # errors: (B, I) - mean squared error per expert per batch + errors = torch.mean((attributions - t.unsqueeze(1)) ** 2, dim=2) # (B, I) + + # weights: (B, I) + w = 1.0 / (eps + errors) # (B, I) + w = w / w.sum(dim=1, keepdim=True) # normalize per batch + + # Step 3: Update truth as weighted average + # t: (B, M) = sum over experts of w * attributions + t = torch.sum(w.unsqueeze(2) * attributions, dim=1) # (B, M) + + # Early stopping: check convergence per batch + if self.early_stopping_threshold is not None: + if torch.allclose(t, t_old, atol=self.early_stopping_threshold): + break + + if self.low_confidence_threshold is None: + # If no low confidence threshold is set, just return the CRH result + return t + + # Detect low-confidence batches where all experts are equally weighted + # If std(w) is very low, it means no expert is clearly better + w_std = torch.std(w, dim=1) # type: ignore[assignment] (B,) + low_confidence = w_std < self.low_confidence_threshold # (B,) + + # For low-confidence batches, fall back to uniform weighting (mean) + if low_confidence.any(): + uniform_consensus = torch.mean(attributions, dim=1) # (B, M) + t = torch.where( + low_confidence.unsqueeze(1), # (B, 1) + uniform_consensus, + t + ) + + return t \ No newline at end of file diff --git a/pyhealth/interpret/methods/ensemble_var.py b/pyhealth/interpret/methods/ensemble_var.py new file mode 100644 index 000000000..629adf26c --- /dev/null +++ b/pyhealth/interpret/methods/ensemble_var.py @@ -0,0 +1,73 @@ +"""Variance-weighted ensemble interpreter. + +This module implements the AGGVar ensemble strategy, which aggregates +attributions from multiple interpretability experts by dividing the mean +attribution by the standard deviation, penalising features where experts +disagree. +""" + +from __future__ import annotations + +import torch + +from pyhealth.models import BaseModel +from .base_ensemble import BaseInterpreterEnsemble +from .base_interpreter import BaseInterpreter + + +class VarEnsemble(BaseInterpreterEnsemble): + """Ensemble interpreter using variance-weighted averaging (AGGVar). + + Computes the consensus attribution by dividing the mean of the + competitively-ranked expert attributions by their standard deviation + (plus a small constant ε for numerical stability). Features that + all experts agree are important receive high scores, while features + with high inter-expert disagreement are suppressed. + + Implements the AGGVar method from: + + Rieger, L. and Hansen, L. K. "Aggregating Explanation Methods + for Stable and Robust Explainability." arXiv preprint + arXiv:1903.00519, 2019. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. 
+ + Example: + >>> from pyhealth.interpret.methods import IntegratedGradients, DeepLift, LimeExplainer + >>> experts = [IntegratedGradients(model), DeepLift(model), LimeExplainer(model)] + >>> ensemble = VarEnsemble(model, experts) + >>> attrs = ensemble.attribute(**batch) + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + ): + super().__init__(model, experts) + + # ------------------------------------------------------------------ + # Ensemble implementation + # ------------------------------------------------------------------ + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate expert attributions via variance-weighted averaging. + + Computes ``mean / (std + ε)`` across experts for each feature, + rewarding consensus and penalising disagreement. + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Consensus tensor of shape ``(B, M)``. + """ + mean = torch.mean(attributions, dim=1) # (B, M) + std = torch.std(attributions, dim=1) # (B, M) + return mean / (std + 1e-6) \ No newline at end of file diff --git a/tests/core/test_interpret_ensemble.py b/tests/core/test_interpret_ensemble.py new file mode 100644 index 000000000..43655db39 --- /dev/null +++ b/tests/core/test_interpret_ensemble.py @@ -0,0 +1,381 @@ +""" +Test suite for Ensemble interpreter implementation. +""" +import unittest + +import torch +import torch.nn as nn + +from pyhealth.models import BaseModel +from pyhealth.interpret.methods.base_ensemble import BaseInterpreterEnsemble +from pyhealth.interpret.methods.base_interpreter import BaseInterpreter + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + +class TestEnsembleFlattenValues(unittest.TestCase): + """Tests for the _flatten_values method.""" + + def test_flatten_single_feature_1d(self): + """Test flattening a single 1D feature.""" + values = { + "feature": torch.randn(2, 3), # (B=2, M=3) + } + flattened = BaseInterpreterEnsemble._flatten_attributions(values) + + self.assertEqual(flattened.shape, torch.Size([2, 3])) + torch.testing.assert_close(flattened, values["feature"]) + + def test_flatten_single_feature_2d(self): + """Test flattening a single 2D feature.""" + values = { + "feature": torch.randn(2, 4, 5), # (B=2, 4, 5) -> (B=2, 20) + } + flattened = BaseInterpreterEnsemble._flatten_attributions(values) + + self.assertEqual(flattened.shape, torch.Size([2, 20])) + expected = values["feature"].reshape(2, -1) + torch.testing.assert_close(flattened, expected) + + def test_flatten_multiple_features(self): + """Test flattening multiple features with different shapes.""" + values = { + "feature_a": torch.randn(3, 5), # (B=3, 5) -> (B=3, 5) + "feature_b": torch.randn(3, 2, 3), # (B=3, 2, 3) -> (B=3, 6) + "feature_c": torch.randn(3, 4), # (B=3, 4) -> (B=3, 4) + } + flattened = BaseInterpreterEnsemble._flatten_attributions(values) + + # Total flattened size = 5 + 6 + 4 = 15 + self.assertEqual(flattened.shape, torch.Size([3, 15])) + + def test_flatten_preserves_batch_size(self): + """Test that flattened output has correct batch size.""" + for batch_size in [1, 2, 5, 16]: + values = { + "feature_a": torch.randn(batch_size, 3), + "feature_b": torch.randn(batch_size, 2, 4), + } + flattened = 
BaseInterpreterEnsemble._flatten_attributions(values) + + self.assertEqual(flattened.shape[0], batch_size) + + def test_flatten_consistency_with_sorted_keys(self): + """Test that flattening is consistent with sorted key ordering.""" + values = { + "zebra": torch.randn(2, 3), + "apple": torch.randn(2, 4), + "banana": torch.randn(2, 2), + } + + flattened = BaseInterpreterEnsemble._flatten_attributions(values) + + # Should be ordered alphabetically: apple (4), banana (2), zebra (3) + self.assertEqual(flattened.shape, torch.Size([2, 9])) + + # Verify order by checking slices + apple_slice = flattened[:, :4] + banana_slice = flattened[:, 4:6] + zebra_slice = flattened[:, 6:9] + + torch.testing.assert_close(apple_slice, values["apple"].reshape(2, -1)) + torch.testing.assert_close(banana_slice, values["banana"].reshape(2, -1)) + torch.testing.assert_close(zebra_slice, values["zebra"].reshape(2, -1)) + + +class TestEnsembleUnflattenValues(unittest.TestCase): + """Tests for the _unflatten_values method.""" + + def test_unflatten_single_feature_1d(self): + """Test unflattening a single 1D feature.""" + shapes = {"feature": torch.Size([2, 3])} + flattened = torch.randn(2, 3) + + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + self.assertIn("feature", unflattened) + self.assertEqual(unflattened["feature"].shape, torch.Size([2, 3])) + torch.testing.assert_close(unflattened["feature"], flattened) + + def test_unflatten_single_feature_2d(self): + """Test unflattening a single 2D feature.""" + original_shape = torch.Size([2, 4, 5]) + shapes = {"feature": original_shape} + flattened = torch.randn(2, 20) + + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + self.assertEqual(unflattened["feature"].shape, original_shape) + + def test_unflatten_multiple_features(self): + """Test unflattening multiple features.""" + shapes = { + "feature_a": torch.Size([3, 5]), + "feature_b": torch.Size([3, 2, 3]), + "feature_c": torch.Size([3, 4]), + } + flattened = torch.randn(3, 15) + + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + self.assertEqual(len(unflattened), 3) + self.assertEqual(unflattened["feature_a"].shape, torch.Size([3, 5])) + self.assertEqual(unflattened["feature_b"].shape, torch.Size([3, 2, 3])) + self.assertEqual(unflattened["feature_c"].shape, torch.Size([3, 4])) + + def test_unflatten_preserves_batch_size(self): + """Test that unflattening preserves batch dimension.""" + for batch_size in [1, 2, 5, 16]: + shapes = { + "feature_a": torch.Size([batch_size, 3]), + "feature_b": torch.Size([batch_size, 2, 4]), + } + flattened = torch.randn(batch_size, 11) + + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + self.assertEqual(unflattened["feature_a"].shape[0], batch_size) + self.assertEqual(unflattened["feature_b"].shape[0], batch_size) + + +class TestEnsembleRoundtrip(unittest.TestCase): + """Tests for flatten/unflatten roundtrip consistency.""" + + def test_roundtrip_single_feature(self): + """Test that flatten->unflatten recovers original single feature.""" + original = { + "feature": torch.randn(4, 6), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + torch.testing.assert_close(unflattened["feature"], original["feature"]) + + def test_roundtrip_multiple_features(self): + """Test that flatten->unflatten 
recovers original with multiple features.""" + original = { + "feature_a": torch.randn(5, 3), + "feature_b": torch.randn(5, 2, 4), + "feature_c": torch.randn(5, 7), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + for key in original: + torch.testing.assert_close(unflattened[key], original[key]) + + def test_roundtrip_high_dimensional(self): + """Test roundtrip with high-dimensional tensors.""" + original = { + "feature_a": torch.randn(2, 3, 4, 5), + "feature_b": torch.randn(2, 2, 3), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + for key in original: + torch.testing.assert_close(unflattened[key], original[key]) + + def test_roundtrip_maintains_device(self): + """Test that roundtrip maintains tensor device.""" + original = { + "feature_a": torch.randn(2, 3), + "feature_b": torch.randn(2, 4, 5), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + for key in original: + self.assertEqual(unflattened[key].device, original[key].device) + + def test_roundtrip_with_gradients(self): + """Test that roundtrip works with gradient-tracking tensors.""" + original = { + "feature_a": torch.randn(2, 3, requires_grad=True), + "feature_b": torch.randn(2, 4, 5, requires_grad=True), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) + + for key in original: + torch.testing.assert_close(unflattened[key].detach(), original[key].detach()) + + +class TestCompetitiveRankingNormalize(unittest.TestCase): + """Tests for competitive_ranking_noramlize static method.""" + + def test_all_distinct_values(self): + """When all values are distinct, ranks should be 0..M-1 (normalized).""" + # (B=1, I=1, M=5) with distinct values + x = torch.tensor([[[3.0, 1.0, 4.0, 1.5, 2.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + + # Sorted ascending: 1.0(idx1)=0, 1.5(idx3)=1, 2.0(idx4)=2, 3.0(idx0)=3, 4.0(idx2)=4 + # Normalized by (M-1)=4 + expected = torch.tensor([[[3/4, 0/4, 4/4, 1/4, 2/4]]]) + torch.testing.assert_close(result, expected) + + def test_output_shape(self): + """Output shape must match input shape.""" + for B, I, M in [(2, 3, 5), (1, 1, 10), (4, 2, 7)]: + x = torch.randn(B, I, M) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + self.assertEqual(result.shape, x.shape) + + def test_output_range(self): + """All output values must lie in [0, 1].""" + x = torch.randn(4, 3, 20) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + self.assertTrue((result >= 0).all()) + self.assertTrue((result <= 1).all()) + + def test_tied_scores_get_same_rank(self): + """Tied scores must receive the same (minimum) rank — '1224' ranking.""" + # [1, 2, 2, 4] -> ranks [0, 1, 1, 3], normalized by 3 + x = torch.tensor([[[1.0, 2.0, 2.0, 4.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) + torch.testing.assert_close(result, expected) + + def 
test_all_tied(self): + """When every item has the same score, all ranks should be 0.""" + x = torch.tensor([[[5.0, 5.0, 5.0, 5.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + expected = torch.zeros_like(x) + torch.testing.assert_close(result, expected) + + def test_single_item(self): + """M=1 edge case: should return zeros.""" + x = torch.tensor([[[7.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + expected = torch.zeros_like(x) + torch.testing.assert_close(result, expected) + + def test_two_items_distinct(self): + """M=2, distinct values.""" + x = torch.tensor([[[3.0, 1.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + # 1.0 -> rank 0, 3.0 -> rank 1; normalized by 1 + expected = torch.tensor([[[1.0, 0.0]]]) + torch.testing.assert_close(result, expected) + + def test_two_items_tied(self): + """M=2, tied values.""" + x = torch.tensor([[[3.0, 3.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + expected = torch.zeros_like(x) + torch.testing.assert_close(result, expected) + + def test_multiple_tie_groups(self): + """Multiple distinct tie groups: [1,1,3,3,5].""" + x = torch.tensor([[[1.0, 1.0, 3.0, 3.0, 5.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + # ranks: [0, 0, 2, 2, 4], normalized by 4 + expected = torch.tensor([[[0/4, 0/4, 2/4, 2/4, 4/4]]]) + torch.testing.assert_close(result, expected) + + def test_batch_and_expert_independence(self): + """Each (batch, expert) slice must be ranked independently.""" + x = torch.tensor([ + [[3.0, 1.0, 2.0], # batch 0, expert 0 + [1.0, 3.0, 2.0]], # batch 0, expert 1 + [[2.0, 2.0, 1.0], # batch 1, expert 0 + [5.0, 5.0, 5.0]], # batch 1, expert 1 + ]) # (B=2, I=2, M=3) + + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + + # batch0, expert0: [3,1,2] -> ranks [2,0,1] / 2 + torch.testing.assert_close(result[0, 0], torch.tensor([2/2, 0/2, 1/2])) + # batch0, expert1: [1,3,2] -> ranks [0,2,1] / 2 + torch.testing.assert_close(result[0, 1], torch.tensor([0/2, 2/2, 1/2])) + # batch1, expert0: [2,2,1] -> ranks [1,1,0] / 2 + torch.testing.assert_close(result[1, 0], torch.tensor([1/2, 1/2, 0/2])) + # batch1, expert1: [5,5,5] -> ranks [0,0,0] / 2 + torch.testing.assert_close(result[1, 1], torch.tensor([0.0, 0.0, 0.0])) + + def test_negative_values(self): + """Negative values should be ranked correctly.""" + x = torch.tensor([[[-3.0, -1.0, -2.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + # Ascending: -3(0), -2(1), -1(2); normalized by 2 + expected = torch.tensor([[[0/2, 2/2, 1/2]]]) + torch.testing.assert_close(result, expected) + + def test_already_sorted_ascending(self): + """Input already sorted ascending.""" + x = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + expected = torch.tensor([[[0/4, 1/4, 2/4, 3/4, 4/4]]]) + torch.testing.assert_close(result, expected) + + def test_sorted_descending(self): + """Input sorted descending.""" + x = torch.tensor([[[5.0, 4.0, 3.0, 2.0, 1.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + expected = torch.tensor([[[4/4, 3/4, 2/4, 1/4, 0/4]]]) + torch.testing.assert_close(result, expected) + + def test_competitive_not_dense_ranking(self): + """Verify it's competitive (1224) NOT dense (1223) ranking. + + For [10, 20, 20, 40]: competitive ranks are [0,1,1,3], not [0,1,1,2]. 
+ """ + x = torch.tensor([[[10.0, 20.0, 20.0, 40.0]]]) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + # Competitive: [0, 1, 1, 3] / 3 + expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) + torch.testing.assert_close(result, expected) + # Dense would give [0, 1, 1, 2] / 3 — assert that's NOT what we get + dense = torch.tensor([[[0/3, 1/3, 1/3, 2/3]]]) + self.assertFalse(torch.allclose(result, dense)) + + def test_preserves_device(self): + """Output should be on the same device as input.""" + x = torch.randn(2, 3, 5) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + self.assertEqual(result.device, x.device) + + def test_dtype_float32(self): + """Works with float32 input.""" + x = torch.randn(2, 2, 6, dtype=torch.float32) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + self.assertEqual(result.dtype, torch.float32) + + def test_dtype_float64(self): + """Works with float64 input.""" + x = torch.randn(2, 2, 6, dtype=torch.float64) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + self.assertEqual(result.dtype, torch.float64) + + def test_large_tensor(self): + """Smoke test on a larger tensor.""" + x = torch.randn(8, 5, 100) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + self.assertEqual(result.shape, x.shape) + self.assertTrue((result >= 0).all()) + self.assertTrue((result <= 1).all()) + + def test_max_is_one_min_is_zero_when_all_distinct(self): + """When all items are distinct, min rank = 0, max rank = 1.""" + x = torch.randn(3, 2, 10) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) + for b in range(3): + for i in range(2): + self.assertAlmostEqual(result[b, i].min().item(), 0.0, places=6) + self.assertAlmostEqual(result[b, i].max().item(), 1.0, places=6) + + +if __name__ == "__main__": + unittest.main()