From 73d3ad84b4a82d22cf61247cb54c56475790dad4 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 11:02:32 -0600 Subject: [PATCH 01/13] Half impl for ensemble --- pyhealth/interpret/methods/__init__.py | 4 +- pyhealth/interpret/methods/ensemble.py | 196 ++++++++++ tests/core/test_interpret_ensemble.py | 498 +++++++++++++++++++++++++ 3 files changed, 697 insertions(+), 1 deletion(-) create mode 100644 pyhealth/interpret/methods/ensemble.py create mode 100644 tests/core/test_interpret_ensemble.py diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 3c79f5c41..9c3004ea6 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -8,6 +8,7 @@ from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients from pyhealth.interpret.methods.shap import ShapExplainer from pyhealth.interpret.methods.lime import LimeExplainer +from pyhealth.interpret.methods.ensemble import Ensemble __all__ = [ "BaseInterpreter", @@ -19,5 +20,6 @@ "BasicGradientSaliencyMaps", "RandomBaseline", "ShapExplainer", - "LimeExplainer" + "LimeExplainer", + "Ensemble" ] diff --git a/pyhealth/interpret/methods/ensemble.py b/pyhealth/interpret/methods/ensemble.py new file mode 100644 index 000000000..8da84ecca --- /dev/null +++ b/pyhealth/interpret/methods/ensemble.py @@ -0,0 +1,196 @@ +"""Random baseline attribution method. + +This module implements a simple random attribution method that assigns +uniformly random importance scores to each input feature. It serves as a +baseline for evaluating the quality of more sophisticated interpretability +methods — any useful attribution technique should outperform random +assignments. +""" + +from __future__ import annotations + +from typing import Dict, Optional + +import torch + +from pyhealth.models import BaseModel +from .base_interpreter import BaseInterpreter + + +class Ensemble(BaseInterpreter): + def __init__( + self, + model: BaseModel, + interpreters: list[BaseInterpreter], + ): + super().__init__(model) + assert len(interpreters) >= 3, "Ensemble must contain at least three interpreters for majority voting" + self.interpreters = interpreters + + # ------------------------------------------------------------------ + # Private helper methods + # ------------------------------------------------------------------ + def _flatten_attributions( + self, + values: dict[str, torch.Tensor], + ) -> torch.Tensor: + """Flatten values dictionary to a single tensor. + + Takes a dictionary of tensors with shape (B, *) and flattens each to (B, M_i), + then concatenates them along the feature dimension to get (B, M). + + Args: + values: Dictionary mapping feature keys to tensors of shape (B, *). + + Returns: + Flattened tensor of shape (B, M) where M is the sum of all flattened dimensions. + """ + flattened_list = [] + for key in sorted(values.keys()): # Sort for consistency + tensor = values[key] + batch_size = tensor.shape[0] + # Flatten all dimensions except batch + flattened = tensor.reshape(batch_size, -1) + flattened_list.append(flattened) + + # Concatenate along feature dimension + return torch.cat(flattened_list, dim=1) + + def _unflatten_attributions( + self, + flattened: torch.Tensor, + shapes: dict[str, torch.Size], + ) -> dict[str, torch.Tensor]: + """Unflatten tensor back to values dictionary. + + Takes a flattened tensor of shape (B, M) and original shapes, + and reconstructs the original dictionary of tensors. + + Args: + flattened: Flattened tensor of shape (B, M). 
+ shapes: Dictionary mapping feature keys to original tensor shapes. + + Returns: + Dictionary mapping feature keys to tensors with original shapes. + """ + values = {} + offset = 0 + + for key in sorted(shapes.keys()): # Must match the order in _flatten_values + shape = shapes[key] + batch_size = shape[0] + + # Calculate the size of the flattened feature + feature_size = 1 + for s in shape[1:]: + feature_size *= s + + # Extract the relevant portion and reshape + values[key] = flattened[:, offset : offset + feature_size].reshape(shape) + offset += feature_size + + return values + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def attribute( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...], + ) -> Dict[str, torch.Tensor]: + """Compute random attributions for input features. + + Generates random importance scores with the same shape as each + input feature tensor. No gradients or forward passes are needed. + + Args: + **kwargs: Input data dictionary from a dataloader batch. + Should contain feature tensors (or tuples of tensors) + keyed by the model's feature keys, plus optional label + or metadata tensors (which are ignored). + + Returns: + Dictionary mapping each feature key to a random attribution + tensor whose shape matches the raw input values. + """ + out_shape: dict[str, torch.Size] | None = None + attr_lst: list[torch.Tensor] = [] + for interpreter in self.interpreters: + attr = interpreter.attribute(**kwargs) + + # record the output shape from the first interpreter, + # since all interpreters should produce the same shape + if out_shape is None: + out_shape = {k: v.shape for k, v in attr.items()} + + flat_attr = self._flatten_attributions(attr) # shape (B, M) + attr_lst.append(flat_attr) + + # Combine the flattened attributions from all interpreters + attributions = torch.stack(attr_lst, dim=1) # shape (B, I, M) + # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) + attributions = self.competitive_ranking_noramlize(attributions) # shape (B, I, M) + + + + # normalize the attributions across interpreters (e.g., by ranking) + _, rank = attributions.sort + + + + + + + raise NotImplementedError("Ensemble attribution method is not implemented yet. This is a placeholder for future development.") + + @staticmethod + def competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: + """Normalize a tensor via competitive (standard competition) ranking. + + For each (batch, expert) slice, items are ranked ascendingly from + 0 to ``total_item - 1``. Tied scores receive the same rank — the + smallest position index among the tied group (standard competition / + "1224" ranking). The ranks are then divided by ``total_item - 1`` + so that the output lies in [0, 1]. + + Args: + x: Tensor of shape ``(batch_size, expert_size, total_item)`` + containing unbounded floating-point scores. + + Returns: + Tensor of the same shape with values in [0, 1]. + """ + B, I, M = x.shape + + if M <= 1: + # With a single item the rank is 0 and 0/0 is undefined; + # return zeros as a safe default. + return torch.zeros_like(x) + + # 1. Sort ascending along the item dimension + sorted_vals, sort_indices = x.sort(dim=-1) + + # 2. Build a mask that is True at positions where the value changes + # from the previous position (i.e. the start of a new rank group). 
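+        #    Worked example (illustrative): sorted values [1, 2, 2, 4] give a
+        #    mask of [True, True, False, True]; the tied 2s share rank 1, so
+        #    the final competitive ranks come out as [0, 1, 1, 3].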
+ change_mask = torch.ones(B, I, M, dtype=torch.bool, device=x.device) + change_mask[..., 1:] = sorted_vals[..., 1:] != sorted_vals[..., :-1] + + # 3. Assign competitive ranks in sorted order. + # At change positions the rank equals the position index; + # at tie positions we propagate the rank of the first occurrence + # via cummax (all non-change positions are set to -1 so cummax + # naturally carries forward the last "real" rank). + positions = torch.arange(M, device=x.device, dtype=torch.long).expand(B, I, M) + ranks_sorted = torch.where( + change_mask, + positions, + torch.full_like(positions, -1), + ) + ranks_sorted, _ = ranks_sorted.cummax(dim=-1) + + # 4. Scatter the ranks back to the original (unsorted) order + ranks = torch.zeros_like(x) + ranks.scatter_(-1, sort_indices, ranks_sorted.to(x.dtype)) + + # 5. Normalize to [0, 1] + return ranks / (M - 1) diff --git a/tests/core/test_interpret_ensemble.py b/tests/core/test_interpret_ensemble.py new file mode 100644 index 000000000..f2f0888f4 --- /dev/null +++ b/tests/core/test_interpret_ensemble.py @@ -0,0 +1,498 @@ +""" +Test suite for Ensemble interpreter implementation. +""" +import unittest + +import torch +import torch.nn as nn + +from pyhealth.models import BaseModel +from pyhealth.interpret.methods import Ensemble +from pyhealth.interpret.methods.base_interpreter import BaseInterpreter + + +# --------------------------------------------------------------------------- +# Mock helpers +# --------------------------------------------------------------------------- + +class _MockProcessor: + """Mock feature processor with configurable schema.""" + + def __init__(self, schema_tuple=("value",)): + self._schema = schema_tuple + + def schema(self): + return self._schema + + def is_token(self): + return False + + +class _MockDataset: + """Lightweight stand-in for SampleDataset in unit tests.""" + + def __init__(self, input_schema, output_schema, processors=None): + self.input_schema = input_schema + self.output_schema = output_schema + self.input_processors = processors or { + k: _MockProcessor() for k in input_schema + } + + +# --------------------------------------------------------------------------- +# Test model helpers +# --------------------------------------------------------------------------- + +class _SimpleModel(BaseModel): + """Minimal model for testing Ensemble.""" + + def __init__(self): + dataset = _MockDataset( + input_schema={"x": "tensor"}, + output_schema={"y": "binary"}, + ) + super().__init__(dataset=dataset) + self.linear = nn.Linear(3, 1, bias=True) + + def forward(self, **kwargs) -> dict: + x = kwargs["x"] + if isinstance(x, tuple): + x = x[0] + y = kwargs.get("y", None) + + logit = self.linear(x) + y_prob = torch.sigmoid(logit) + + result = { + "logit": logit, + "y_prob": y_prob, + "loss": torch.zeros((), device=y_prob.device), + } + if y is not None: + result["y_true"] = y.to(y_prob.device) + return result + + def forward_from_embedding(self, **kwargs) -> dict: + return self.forward(**kwargs) + + def get_embedding_model(self): + return None + + +class _DummyInterpreter(BaseInterpreter): + """Dummy interpreter that returns random attributions.""" + + def attribute(self, **kwargs) -> dict: + attributions = {} + for key, value in kwargs.items(): + if isinstance(value, torch.Tensor): + attributions[key] = torch.randn_like(value) + return attributions + + +# --------------------------------------------------------------------------- +# Test classes +# 
--------------------------------------------------------------------------- + +class TestEnsembleFlattenValues(unittest.TestCase): + """Tests for the _flatten_values method.""" + + def setUp(self): + self.model = _SimpleModel() + self.model.eval() + + # Create 3 dummy interpreters + self.interpreters = [ + _DummyInterpreter(self.model), + _DummyInterpreter(self.model), + _DummyInterpreter(self.model), + ] + + self.ensemble = Ensemble(self.model, self.interpreters) + + def test_flatten_single_feature_1d(self): + """Test flattening a single 1D feature.""" + values = { + "feature": torch.randn(2, 3), # (B=2, M=3) + } + flattened = self.ensemble._flatten_attributions(values) + + self.assertEqual(flattened.shape, torch.Size([2, 3])) + torch.testing.assert_close(flattened, values["feature"]) + + def test_flatten_single_feature_2d(self): + """Test flattening a single 2D feature.""" + values = { + "feature": torch.randn(2, 4, 5), # (B=2, 4, 5) -> (B=2, 20) + } + flattened = self.ensemble._flatten_attributions(values) + + self.assertEqual(flattened.shape, torch.Size([2, 20])) + expected = values["feature"].reshape(2, -1) + torch.testing.assert_close(flattened, expected) + + def test_flatten_multiple_features(self): + """Test flattening multiple features with different shapes.""" + values = { + "feature_a": torch.randn(3, 5), # (B=3, 5) -> (B=3, 5) + "feature_b": torch.randn(3, 2, 3), # (B=3, 2, 3) -> (B=3, 6) + "feature_c": torch.randn(3, 4), # (B=3, 4) -> (B=3, 4) + } + flattened = self.ensemble._flatten_attributions(values) + + # Total flattened size = 5 + 6 + 4 = 15 + self.assertEqual(flattened.shape, torch.Size([3, 15])) + + def test_flatten_preserves_batch_size(self): + """Test that flattened output has correct batch size.""" + for batch_size in [1, 2, 5, 16]: + values = { + "feature_a": torch.randn(batch_size, 3), + "feature_b": torch.randn(batch_size, 2, 4), + } + flattened = self.ensemble._flatten_attributions(values) + + self.assertEqual(flattened.shape[0], batch_size) + + def test_flatten_consistency_with_sorted_keys(self): + """Test that flattening is consistent with sorted key ordering.""" + values = { + "zebra": torch.randn(2, 3), + "apple": torch.randn(2, 4), + "banana": torch.randn(2, 2), + } + + flattened = self.ensemble._flatten_attributions(values) + + # Should be ordered alphabetically: apple (4), banana (2), zebra (3) + self.assertEqual(flattened.shape, torch.Size([2, 9])) + + # Verify order by checking slices + apple_slice = flattened[:, :4] + banana_slice = flattened[:, 4:6] + zebra_slice = flattened[:, 6:9] + + torch.testing.assert_close(apple_slice, values["apple"].reshape(2, -1)) + torch.testing.assert_close(banana_slice, values["banana"].reshape(2, -1)) + torch.testing.assert_close(zebra_slice, values["zebra"].reshape(2, -1)) + + +class TestEnsembleUnflattenValues(unittest.TestCase): + """Tests for the _unflatten_values method.""" + + def setUp(self): + self.model = _SimpleModel() + self.model.eval() + + self.interpreters = [ + _DummyInterpreter(self.model), + _DummyInterpreter(self.model), + _DummyInterpreter(self.model), + ] + + self.ensemble = Ensemble(self.model, self.interpreters) + + def test_unflatten_single_feature_1d(self): + """Test unflattening a single 1D feature.""" + shapes = {"feature": torch.Size([2, 3])} + flattened = torch.randn(2, 3) + + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + self.assertIn("feature", unflattened) + self.assertEqual(unflattened["feature"].shape, torch.Size([2, 3])) + 
torch.testing.assert_close(unflattened["feature"], flattened) + + def test_unflatten_single_feature_2d(self): + """Test unflattening a single 2D feature.""" + original_shape = torch.Size([2, 4, 5]) + shapes = {"feature": original_shape} + flattened = torch.randn(2, 20) + + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + self.assertEqual(unflattened["feature"].shape, original_shape) + + def test_unflatten_multiple_features(self): + """Test unflattening multiple features.""" + shapes = { + "feature_a": torch.Size([3, 5]), + "feature_b": torch.Size([3, 2, 3]), + "feature_c": torch.Size([3, 4]), + } + flattened = torch.randn(3, 15) + + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + self.assertEqual(len(unflattened), 3) + self.assertEqual(unflattened["feature_a"].shape, torch.Size([3, 5])) + self.assertEqual(unflattened["feature_b"].shape, torch.Size([3, 2, 3])) + self.assertEqual(unflattened["feature_c"].shape, torch.Size([3, 4])) + + def test_unflatten_preserves_batch_size(self): + """Test that unflattening preserves batch dimension.""" + for batch_size in [1, 2, 5, 16]: + shapes = { + "feature_a": torch.Size([batch_size, 3]), + "feature_b": torch.Size([batch_size, 2, 4]), + } + flattened = torch.randn(batch_size, 11) + + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + self.assertEqual(unflattened["feature_a"].shape[0], batch_size) + self.assertEqual(unflattened["feature_b"].shape[0], batch_size) + + +class TestEnsembleRoundtrip(unittest.TestCase): + """Tests for flatten/unflatten roundtrip consistency.""" + + def setUp(self): + self.model = _SimpleModel() + self.model.eval() + + self.interpreters = [ + _DummyInterpreter(self.model), + _DummyInterpreter(self.model), + _DummyInterpreter(self.model), + ] + + self.ensemble = Ensemble(self.model, self.interpreters) + + def test_roundtrip_single_feature(self): + """Test that flatten->unflatten recovers original single feature.""" + original = { + "feature": torch.randn(4, 6), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = self.ensemble._flatten_attributions(original) + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + torch.testing.assert_close(unflattened["feature"], original["feature"]) + + def test_roundtrip_multiple_features(self): + """Test that flatten->unflatten recovers original with multiple features.""" + original = { + "feature_a": torch.randn(5, 3), + "feature_b": torch.randn(5, 2, 4), + "feature_c": torch.randn(5, 7), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = self.ensemble._flatten_attributions(original) + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + for key in original: + torch.testing.assert_close(unflattened[key], original[key]) + + def test_roundtrip_high_dimensional(self): + """Test roundtrip with high-dimensional tensors.""" + original = { + "feature_a": torch.randn(2, 3, 4, 5), + "feature_b": torch.randn(2, 2, 3), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = self.ensemble._flatten_attributions(original) + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + for key in original: + torch.testing.assert_close(unflattened[key], original[key]) + + def test_roundtrip_maintains_device(self): + """Test that roundtrip maintains tensor device.""" + original = { + "feature_a": torch.randn(2, 3), + "feature_b": torch.randn(2, 4, 5), + } + + shapes = {k: v.shape for k, v in original.items()} + 
flattened = self.ensemble._flatten_attributions(original) + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + for key in original: + self.assertEqual(unflattened[key].device, original[key].device) + + def test_roundtrip_with_gradients(self): + """Test that roundtrip works with gradient-tracking tensors.""" + original = { + "feature_a": torch.randn(2, 3, requires_grad=True), + "feature_b": torch.randn(2, 4, 5, requires_grad=True), + } + + shapes = {k: v.shape for k, v in original.items()} + flattened = self.ensemble._flatten_attributions(original) + unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + + for key in original: + torch.testing.assert_close(unflattened[key].detach(), original[key].detach()) + + +class TestCompetitiveRankingNormalize(unittest.TestCase): + """Tests for competitive_ranking_noramlize static method.""" + + def test_all_distinct_values(self): + """When all values are distinct, ranks should be 0..M-1 (normalized).""" + # (B=1, I=1, M=5) with distinct values + x = torch.tensor([[[3.0, 1.0, 4.0, 1.5, 2.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + + # Sorted ascending: 1.0(idx1)=0, 1.5(idx3)=1, 2.0(idx4)=2, 3.0(idx0)=3, 4.0(idx2)=4 + # Normalized by (M-1)=4 + expected = torch.tensor([[[3/4, 0/4, 4/4, 1/4, 2/4]]]) + torch.testing.assert_close(result, expected) + + def test_output_shape(self): + """Output shape must match input shape.""" + for B, I, M in [(2, 3, 5), (1, 1, 10), (4, 2, 7)]: + x = torch.randn(B, I, M) + result = Ensemble.competitive_ranking_noramlize(x) + self.assertEqual(result.shape, x.shape) + + def test_output_range(self): + """All output values must lie in [0, 1].""" + x = torch.randn(4, 3, 20) + result = Ensemble.competitive_ranking_noramlize(x) + self.assertTrue((result >= 0).all()) + self.assertTrue((result <= 1).all()) + + def test_tied_scores_get_same_rank(self): + """Tied scores must receive the same (minimum) rank — '1224' ranking.""" + # [1, 2, 2, 4] -> ranks [0, 1, 1, 3], normalized by 3 + x = torch.tensor([[[1.0, 2.0, 2.0, 4.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) + torch.testing.assert_close(result, expected) + + def test_all_tied(self): + """When every item has the same score, all ranks should be 0.""" + x = torch.tensor([[[5.0, 5.0, 5.0, 5.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + expected = torch.zeros_like(x) + torch.testing.assert_close(result, expected) + + def test_single_item(self): + """M=1 edge case: should return zeros.""" + x = torch.tensor([[[7.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + expected = torch.zeros_like(x) + torch.testing.assert_close(result, expected) + + def test_two_items_distinct(self): + """M=2, distinct values.""" + x = torch.tensor([[[3.0, 1.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + # 1.0 -> rank 0, 3.0 -> rank 1; normalized by 1 + expected = torch.tensor([[[1.0, 0.0]]]) + torch.testing.assert_close(result, expected) + + def test_two_items_tied(self): + """M=2, tied values.""" + x = torch.tensor([[[3.0, 3.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + expected = torch.zeros_like(x) + torch.testing.assert_close(result, expected) + + def test_multiple_tie_groups(self): + """Multiple distinct tie groups: [1,1,3,3,5].""" + x = torch.tensor([[[1.0, 1.0, 3.0, 3.0, 5.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + # ranks: [0, 0, 2, 2, 4], normalized by 4 + expected = torch.tensor([[[0/4, 0/4, 2/4, 2/4, 
4/4]]]) + torch.testing.assert_close(result, expected) + + def test_batch_and_expert_independence(self): + """Each (batch, expert) slice must be ranked independently.""" + x = torch.tensor([ + [[3.0, 1.0, 2.0], # batch 0, expert 0 + [1.0, 3.0, 2.0]], # batch 0, expert 1 + [[2.0, 2.0, 1.0], # batch 1, expert 0 + [5.0, 5.0, 5.0]], # batch 1, expert 1 + ]) # (B=2, I=2, M=3) + + result = Ensemble.competitive_ranking_noramlize(x) + + # batch0, expert0: [3,1,2] -> ranks [2,0,1] / 2 + torch.testing.assert_close(result[0, 0], torch.tensor([2/2, 0/2, 1/2])) + # batch0, expert1: [1,3,2] -> ranks [0,2,1] / 2 + torch.testing.assert_close(result[0, 1], torch.tensor([0/2, 2/2, 1/2])) + # batch1, expert0: [2,2,1] -> ranks [1,1,0] / 2 + torch.testing.assert_close(result[1, 0], torch.tensor([1/2, 1/2, 0/2])) + # batch1, expert1: [5,5,5] -> ranks [0,0,0] / 2 + torch.testing.assert_close(result[1, 1], torch.tensor([0.0, 0.0, 0.0])) + + def test_negative_values(self): + """Negative values should be ranked correctly.""" + x = torch.tensor([[[-3.0, -1.0, -2.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + # Ascending: -3(0), -2(1), -1(2); normalized by 2 + expected = torch.tensor([[[0/2, 2/2, 1/2]]]) + torch.testing.assert_close(result, expected) + + def test_already_sorted_ascending(self): + """Input already sorted ascending.""" + x = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + expected = torch.tensor([[[0/4, 1/4, 2/4, 3/4, 4/4]]]) + torch.testing.assert_close(result, expected) + + def test_sorted_descending(self): + """Input sorted descending.""" + x = torch.tensor([[[5.0, 4.0, 3.0, 2.0, 1.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + expected = torch.tensor([[[4/4, 3/4, 2/4, 1/4, 0/4]]]) + torch.testing.assert_close(result, expected) + + def test_competitive_not_dense_ranking(self): + """Verify it's competitive (1224) NOT dense (1223) ranking. + + For [10, 20, 20, 40]: competitive ranks are [0,1,1,3], not [0,1,1,2]. 
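        After normalizing by (M-1)=3 this is [0, 1/3, 1/3, 1], whereas dense
        ranking would yield [0, 1/3, 1/3, 2/3].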
+ """ + x = torch.tensor([[[10.0, 20.0, 20.0, 40.0]]]) + result = Ensemble.competitive_ranking_noramlize(x) + # Competitive: [0, 1, 1, 3] / 3 + expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) + torch.testing.assert_close(result, expected) + # Dense would give [0, 1, 1, 2] / 3 — assert that's NOT what we get + dense = torch.tensor([[[0/3, 1/3, 1/3, 2/3]]]) + self.assertFalse(torch.allclose(result, dense)) + + def test_preserves_device(self): + """Output should be on the same device as input.""" + x = torch.randn(2, 3, 5) + result = Ensemble.competitive_ranking_noramlize(x) + self.assertEqual(result.device, x.device) + + def test_dtype_float32(self): + """Works with float32 input.""" + x = torch.randn(2, 2, 6, dtype=torch.float32) + result = Ensemble.competitive_ranking_noramlize(x) + self.assertEqual(result.dtype, torch.float32) + + def test_dtype_float64(self): + """Works with float64 input.""" + x = torch.randn(2, 2, 6, dtype=torch.float64) + result = Ensemble.competitive_ranking_noramlize(x) + self.assertEqual(result.dtype, torch.float64) + + def test_large_tensor(self): + """Smoke test on a larger tensor.""" + x = torch.randn(8, 5, 100) + result = Ensemble.competitive_ranking_noramlize(x) + self.assertEqual(result.shape, x.shape) + self.assertTrue((result >= 0).all()) + self.assertTrue((result <= 1).all()) + + def test_max_is_one_min_is_zero_when_all_distinct(self): + """When all items are distinct, min rank = 0, max rank = 1.""" + x = torch.randn(3, 2, 10) + result = Ensemble.competitive_ranking_noramlize(x) + for b in range(3): + for i in range(2): + self.assertAlmostEqual(result[b, i].min().item(), 0.0, places=6) + self.assertAlmostEqual(result[b, i].max().item(), 1.0, places=6) + + +if __name__ == "__main__": + unittest.main() From 8780b1042a45ec07d2f14378129b32dd83c530c9 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 12:26:27 -0600 Subject: [PATCH 02/13] partial impl --- pyhealth/interpret/methods/ensemble.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pyhealth/interpret/methods/ensemble.py b/pyhealth/interpret/methods/ensemble.py index 8da84ecca..9645baa9d 100644 --- a/pyhealth/interpret/methods/ensemble.py +++ b/pyhealth/interpret/methods/ensemble.py @@ -131,14 +131,6 @@ def attribute( # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) attributions = self.competitive_ranking_noramlize(attributions) # shape (B, I, M) - - - # normalize the attributions across interpreters (e.g., by ranking) - _, rank = attributions.sort - - - - raise NotImplementedError("Ensemble attribution method is not implemented yet. This is a placeholder for future development.") @@ -194,3 +186,16 @@ def competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: # 5. Normalize to [0, 1] return ranks / (M - 1) + + def conflict_resolution(self, attributions: torch.Tensor) -> torch.Tensor: + """Truth discovery using CRH algorithm. This try to estimate the true importance scores + by iteratively reweighting the experts based on their agreement with the current consensus. + + Args: + attributions: A normalized attribution tensor of shape (B, I, M) + where values are in [0, 1]. + + Returns: + Tensor of the same shape with values in [0, 1]. 
+ """ + pas \ No newline at end of file From 014aae2890327a087d033ee8bc194ece34d6e9f2 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 11 Feb 2026 02:09:28 -0600 Subject: [PATCH 03/13] rename variables --- pyhealth/interpret/methods/ensemble.py | 119 +++++++++++++------------ tests/core/test_interpret_ensemble.py | 38 ++++---- 2 files changed, 80 insertions(+), 77 deletions(-) diff --git a/pyhealth/interpret/methods/ensemble.py b/pyhealth/interpret/methods/ensemble.py index 9645baa9d..c218ef1fa 100644 --- a/pyhealth/interpret/methods/ensemble.py +++ b/pyhealth/interpret/methods/ensemble.py @@ -21,17 +21,63 @@ class Ensemble(BaseInterpreter): def __init__( self, model: BaseModel, - interpreters: list[BaseInterpreter], + experts: list[BaseInterpreter], ): super().__init__(model) - assert len(interpreters) >= 3, "Ensemble must contain at least three interpreters for majority voting" - self.interpreters = interpreters + assert len(experts) >= 3, "Ensemble must contain at least three interpreters for majority voting" + self.experts = experts + + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def attribute( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...], + ) -> Dict[str, torch.Tensor]: + """Compute random attributions for input features. + + Generates random importance scores with the same shape as each + input feature tensor. No gradients or forward passes are needed. + + Args: + **kwargs: Input data dictionary from a dataloader batch. + Should contain feature tensors (or tuples of tensors) + keyed by the model's feature keys, plus optional label + or metadata tensors (which are ignored). + + Returns: + Dictionary mapping each feature key to a random attribution + tensor whose shape matches the raw input values. + """ + out_shape: dict[str, torch.Size] | None = None + attr_lst: list[torch.Tensor] = [] + for expert in self.experts: + attr = expert.attribute(**kwargs) + + # record the output shape from the first interpreter, + # since all interpreters should produce the same shape + if out_shape is None: + out_shape = {k: v.shape for k, v in attr.items()} + + flat_attr = self._flatten_attributions(attr) # shape (B, M) + attr_lst.append(flat_attr) + + # Combine the flattened attributions from all interpreters + attributions = torch.stack(attr_lst, dim=1) # shape (B, I, M) + # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) + attributions = self._competitive_ranking_noramlize(attributions) # shape (B, I, M) + + + + raise NotImplementedError("Ensemble attribution method is not implemented yet. This is a placeholder for future development.") + # ------------------------------------------------------------------ # Private helper methods # ------------------------------------------------------------------ + @staticmethod def _flatten_attributions( - self, values: dict[str, torch.Tensor], ) -> torch.Tensor: """Flatten values dictionary to a single tensor. 
@@ -56,8 +102,8 @@ def _flatten_attributions( # Concatenate along feature dimension return torch.cat(flattened_list, dim=1) + @staticmethod def _unflatten_attributions( - self, flattened: torch.Tensor, shapes: dict[str, torch.Size], ) -> dict[str, torch.Tensor]: @@ -91,52 +137,9 @@ def _unflatten_attributions( return values - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - def attribute( - self, - **kwargs: torch.Tensor | tuple[torch.Tensor, ...], - ) -> Dict[str, torch.Tensor]: - """Compute random attributions for input features. - - Generates random importance scores with the same shape as each - input feature tensor. No gradients or forward passes are needed. - - Args: - **kwargs: Input data dictionary from a dataloader batch. - Should contain feature tensors (or tuples of tensors) - keyed by the model's feature keys, plus optional label - or metadata tensors (which are ignored). - - Returns: - Dictionary mapping each feature key to a random attribution - tensor whose shape matches the raw input values. - """ - out_shape: dict[str, torch.Size] | None = None - attr_lst: list[torch.Tensor] = [] - for interpreter in self.interpreters: - attr = interpreter.attribute(**kwargs) - - # record the output shape from the first interpreter, - # since all interpreters should produce the same shape - if out_shape is None: - out_shape = {k: v.shape for k, v in attr.items()} - - flat_attr = self._flatten_attributions(attr) # shape (B, M) - attr_lst.append(flat_attr) - - # Combine the flattened attributions from all interpreters - attributions = torch.stack(attr_lst, dim=1) # shape (B, I, M) - # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) - attributions = self.competitive_ranking_noramlize(attributions) # shape (B, I, M) - - - - raise NotImplementedError("Ensemble attribution method is not implemented yet. This is a placeholder for future development.") @staticmethod - def competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: + def _competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: """Normalize a tensor via competitive (standard competition) ranking. For each (batch, expert) slice, items are ranked ascendingly from @@ -146,15 +149,15 @@ def competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: so that the output lies in [0, 1]. Args: - x: Tensor of shape ``(batch_size, expert_size, total_item)`` + x: Tensor of shape ``(B, I, M)`` containing unbounded floating-point scores. Returns: Tensor of the same shape with values in [0, 1]. """ - B, I, M = x.shape + batch_size, num_experts, num_items = x.shape - if M <= 1: + if num_items <= 1: # With a single item the rank is 0 and 0/0 is undefined; # return zeros as a safe default. return torch.zeros_like(x) @@ -164,7 +167,7 @@ def competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: # 2. Build a mask that is True at positions where the value changes # from the previous position (i.e. the start of a new rank group). - change_mask = torch.ones(B, I, M, dtype=torch.bool, device=x.device) + change_mask = torch.ones(batch_size, num_experts, num_items, dtype=torch.bool, device=x.device) change_mask[..., 1:] = sorted_vals[..., 1:] != sorted_vals[..., :-1] # 3. Assign competitive ranks in sorted order. 
@@ -172,7 +175,7 @@ def competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: # at tie positions we propagate the rank of the first occurrence # via cummax (all non-change positions are set to -1 so cummax # naturally carries forward the last "real" rank). - positions = torch.arange(M, device=x.device, dtype=torch.long).expand(B, I, M) + positions = torch.arange(num_items, device=x.device, dtype=torch.long).expand(batch_size, num_experts, num_items) ranks_sorted = torch.where( change_mask, positions, @@ -185,9 +188,9 @@ def competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: ranks.scatter_(-1, sort_indices, ranks_sorted.to(x.dtype)) # 5. Normalize to [0, 1] - return ranks / (M - 1) + return ranks / (num_items - 1) - def conflict_resolution(self, attributions: torch.Tensor) -> torch.Tensor: + def _conflict_resolution(self, attributions: torch.Tensor) -> torch.Tensor: """Truth discovery using CRH algorithm. This try to estimate the true importance scores by iteratively reweighting the experts based on their agreement with the current consensus. @@ -196,6 +199,6 @@ def conflict_resolution(self, attributions: torch.Tensor) -> torch.Tensor: where values are in [0, 1]. Returns: - Tensor of the same shape with values in [0, 1]. + Tensor of shape (B, M) in [0, 1]. """ - pas \ No newline at end of file + pass \ No newline at end of file diff --git a/tests/core/test_interpret_ensemble.py b/tests/core/test_interpret_ensemble.py index f2f0888f4..7ec7cc737 100644 --- a/tests/core/test_interpret_ensemble.py +++ b/tests/core/test_interpret_ensemble.py @@ -336,7 +336,7 @@ def test_all_distinct_values(self): """When all values are distinct, ranks should be 0..M-1 (normalized).""" # (B=1, I=1, M=5) with distinct values x = torch.tensor([[[3.0, 1.0, 4.0, 1.5, 2.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) # Sorted ascending: 1.0(idx1)=0, 1.5(idx3)=1, 2.0(idx4)=2, 3.0(idx0)=3, 4.0(idx2)=4 # Normalized by (M-1)=4 @@ -347,13 +347,13 @@ def test_output_shape(self): """Output shape must match input shape.""" for B, I, M in [(2, 3, 5), (1, 1, 10), (4, 2, 7)]: x = torch.randn(B, I, M) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) self.assertEqual(result.shape, x.shape) def test_output_range(self): """All output values must lie in [0, 1].""" x = torch.randn(4, 3, 20) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) self.assertTrue((result >= 0).all()) self.assertTrue((result <= 1).all()) @@ -361,28 +361,28 @@ def test_tied_scores_get_same_rank(self): """Tied scores must receive the same (minimum) rank — '1224' ranking.""" # [1, 2, 2, 4] -> ranks [0, 1, 1, 3], normalized by 3 x = torch.tensor([[[1.0, 2.0, 2.0, 4.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) torch.testing.assert_close(result, expected) def test_all_tied(self): """When every item has the same score, all ranks should be 0.""" x = torch.tensor([[[5.0, 5.0, 5.0, 5.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) expected = torch.zeros_like(x) torch.testing.assert_close(result, expected) def test_single_item(self): """M=1 edge case: should return zeros.""" x = torch.tensor([[[7.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = 
Ensemble._competitive_ranking_noramlize(x) expected = torch.zeros_like(x) torch.testing.assert_close(result, expected) def test_two_items_distinct(self): """M=2, distinct values.""" x = torch.tensor([[[3.0, 1.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) # 1.0 -> rank 0, 3.0 -> rank 1; normalized by 1 expected = torch.tensor([[[1.0, 0.0]]]) torch.testing.assert_close(result, expected) @@ -390,14 +390,14 @@ def test_two_items_distinct(self): def test_two_items_tied(self): """M=2, tied values.""" x = torch.tensor([[[3.0, 3.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) expected = torch.zeros_like(x) torch.testing.assert_close(result, expected) def test_multiple_tie_groups(self): """Multiple distinct tie groups: [1,1,3,3,5].""" x = torch.tensor([[[1.0, 1.0, 3.0, 3.0, 5.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) # ranks: [0, 0, 2, 2, 4], normalized by 4 expected = torch.tensor([[[0/4, 0/4, 2/4, 2/4, 4/4]]]) torch.testing.assert_close(result, expected) @@ -411,7 +411,7 @@ def test_batch_and_expert_independence(self): [5.0, 5.0, 5.0]], # batch 1, expert 1 ]) # (B=2, I=2, M=3) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) # batch0, expert0: [3,1,2] -> ranks [2,0,1] / 2 torch.testing.assert_close(result[0, 0], torch.tensor([2/2, 0/2, 1/2])) @@ -425,7 +425,7 @@ def test_batch_and_expert_independence(self): def test_negative_values(self): """Negative values should be ranked correctly.""" x = torch.tensor([[[-3.0, -1.0, -2.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) # Ascending: -3(0), -2(1), -1(2); normalized by 2 expected = torch.tensor([[[0/2, 2/2, 1/2]]]) torch.testing.assert_close(result, expected) @@ -433,14 +433,14 @@ def test_negative_values(self): def test_already_sorted_ascending(self): """Input already sorted ascending.""" x = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) expected = torch.tensor([[[0/4, 1/4, 2/4, 3/4, 4/4]]]) torch.testing.assert_close(result, expected) def test_sorted_descending(self): """Input sorted descending.""" x = torch.tensor([[[5.0, 4.0, 3.0, 2.0, 1.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) expected = torch.tensor([[[4/4, 3/4, 2/4, 1/4, 0/4]]]) torch.testing.assert_close(result, expected) @@ -450,7 +450,7 @@ def test_competitive_not_dense_ranking(self): For [10, 20, 20, 40]: competitive ranks are [0,1,1,3], not [0,1,1,2]. 
""" x = torch.tensor([[[10.0, 20.0, 20.0, 40.0]]]) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) # Competitive: [0, 1, 1, 3] / 3 expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) torch.testing.assert_close(result, expected) @@ -461,25 +461,25 @@ def test_competitive_not_dense_ranking(self): def test_preserves_device(self): """Output should be on the same device as input.""" x = torch.randn(2, 3, 5) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) self.assertEqual(result.device, x.device) def test_dtype_float32(self): """Works with float32 input.""" x = torch.randn(2, 2, 6, dtype=torch.float32) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) self.assertEqual(result.dtype, torch.float32) def test_dtype_float64(self): """Works with float64 input.""" x = torch.randn(2, 2, 6, dtype=torch.float64) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) self.assertEqual(result.dtype, torch.float64) def test_large_tensor(self): """Smoke test on a larger tensor.""" x = torch.randn(8, 5, 100) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) self.assertEqual(result.shape, x.shape) self.assertTrue((result >= 0).all()) self.assertTrue((result <= 1).all()) @@ -487,7 +487,7 @@ def test_large_tensor(self): def test_max_is_one_min_is_zero_when_all_distinct(self): """When all items are distinct, min rank = 0, max rank = 1.""" x = torch.randn(3, 2, 10) - result = Ensemble.competitive_ranking_noramlize(x) + result = Ensemble._competitive_ranking_noramlize(x) for b in range(3): for i in range(2): self.assertAlmostEqual(result[b, i].min().item(), 0.0, places=6) From 8adf83584ef3515c2e76cda48339fbc9b3bfb596 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 11 Feb 2026 02:47:29 -0600 Subject: [PATCH 04/13] Implement ensemble methods --- pyhealth/interpret/methods/ensemble.py | 57 ++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/pyhealth/interpret/methods/ensemble.py b/pyhealth/interpret/methods/ensemble.py index c218ef1fa..827ab9e59 100644 --- a/pyhealth/interpret/methods/ensemble.py +++ b/pyhealth/interpret/methods/ensemble.py @@ -22,11 +22,16 @@ def __init__( self, model: BaseModel, experts: list[BaseInterpreter], + n_iter: int = 20, + low_confidence_threshold: float | None = None, + early_stopping_threshold: float | None = None, ): super().__init__(model) assert len(experts) >= 3, "Ensemble must contain at least three interpreters for majority voting" self.experts = experts - + self.n_iter = n_iter + self.low_confidence_threshold = low_confidence_threshold + self.early_stopping_threshold = early_stopping_threshold # ------------------------------------------------------------------ # Public API @@ -66,7 +71,7 @@ def attribute( # Combine the flattened attributions from all interpreters attributions = torch.stack(attr_lst, dim=1) # shape (B, I, M) # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) - attributions = self._competitive_ranking_noramlize(attributions) # shape (B, I, M) + attributions = self._competitive_ranking_normalize(attributions) # shape (B, I, M) @@ -139,7 +144,7 @@ def _unflatten_attributions( @staticmethod - def _competitive_ranking_noramlize(x: torch.Tensor) -> torch.Tensor: + def _competitive_ranking_normalize(x: 
torch.Tensor) -> torch.Tensor: """Normalize a tensor via competitive (standard competition) ranking. For each (batch, expert) slice, items are ranked ascendingly from @@ -201,4 +206,48 @@ def _conflict_resolution(self, attributions: torch.Tensor) -> torch.Tensor: Returns: Tensor of shape (B, M) in [0, 1]. """ - pass \ No newline at end of file + # Step 1: Initialize truth as median across experts (B, M) + t = torch.median(attributions, dim=1).values # (B, M) + + # Iterative refinement + eps = 1e-6 + + for _ in range(self.n_iter): + t_old = t.clone() + + # Step 2: Compute expert reliability per batch + # errors: (B, I) - mean squared error per expert per batch + errors = torch.mean((attributions - t.unsqueeze(1)) ** 2, dim=2) # (B, I) + + # weights: (B, I) + w = 1.0 / (eps + errors) # (B, I) + w = w / w.sum(dim=1, keepdim=True) # normalize per batch + + # Step 3: Update truth as weighted average + # t: (B, M) = sum over experts of w * attributions + t = torch.sum(w.unsqueeze(2) * attributions, dim=1) # (B, M) + + # Early stopping: check convergence per batch + if self.early_stopping_threshold is not None: + if torch.allclose(t, t_old, atol=self.early_stopping_threshold): + break + + if self.low_confidence_threshold is None: + # If no low confidence threshold is set, just return the CRH result + return t + + # Detect low-confidence batches where all experts are equally weighted + # If std(w) is very low, it means no expert is clearly better + w_std = torch.std(w, dim=1) # type: ignore[assignment] (B,) + low_confidence = w_std < self.low_confidence_threshold # (B,) + + # For low-confidence batches, fall back to uniform weighting (mean) + if low_confidence.any(): + uniform_consensus = torch.mean(attributions, dim=1) # (B, M) + t = torch.where( + low_confidence.unsqueeze(1), # (B, 1) + uniform_consensus, + t + ) + + return t \ No newline at end of file From 937d198266856a39d967b27c02030f1c505974af Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 11 Feb 2026 02:53:26 -0600 Subject: [PATCH 05/13] Fixup --- pyhealth/interpret/methods/ensemble.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyhealth/interpret/methods/ensemble.py b/pyhealth/interpret/methods/ensemble.py index 827ab9e59..2579230bd 100644 --- a/pyhealth/interpret/methods/ensemble.py +++ b/pyhealth/interpret/methods/ensemble.py @@ -73,9 +73,10 @@ def attribute( # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) attributions = self._competitive_ranking_normalize(attributions) # shape (B, I, M) - - - raise NotImplementedError("Ensemble attribution method is not implemented yet. 
This is a placeholder for future development.") + # Resolve conflicts and aggregate across interpreters using CRH + consensus = self._conflict_resolution(attributions) # shape (B, M) + assert out_shape is not None, "Output shape should have been determined from the first interpreter" + return self._unflatten_attributions(consensus, out_shape) # dict of tensors with original shapes # ------------------------------------------------------------------ From 4923dfe9a23035c3cbe9f367e9aaeba29aa34e70 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 12 Feb 2026 04:26:53 -0600 Subject: [PATCH 06/13] Create base class for ensembles --- pyhealth/interpret/methods/base_ensemble.py | 223 ++++++++++++++++++++ tests/core/test_interpret_ensemble.py | 195 ++++------------- 2 files changed, 262 insertions(+), 156 deletions(-) create mode 100644 pyhealth/interpret/methods/base_ensemble.py diff --git a/pyhealth/interpret/methods/base_ensemble.py b/pyhealth/interpret/methods/base_ensemble.py new file mode 100644 index 000000000..691b1f28a --- /dev/null +++ b/pyhealth/interpret/methods/base_ensemble.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from typing import Dict, Optional + +import torch + +from pyhealth.models import BaseModel +from .base_interpreter import BaseInterpreter + + +class BaseInterpreterEnsemble(BaseInterpreter): + """Abstract base class for ensemble interpreters. + + Provides the shared workflow for ensemble-based attribution: + + 1. Each expert interpreter independently computes attributions. + 2. The per-expert attribution maps are flattened, then normalized to + a common [0, 1] scale via competitive ranking. + 3. The normalized attributions are passed to :meth:`_ensemble`, which + concrete subclasses must override to implement a specific + aggregation strategy (e.g., CRH truth discovery, simple averaging, + majority voting). + 4. The aggregated result is unflattened back to the original tensor + shapes. + + Subclasses only need to implement :meth:`_ensemble`. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + ): + super().__init__(model) + assert len(experts) >= 3, "Ensemble must contain at least three interpreters for majority voting" + self.experts = experts + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def attribute( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...], + ) -> Dict[str, torch.Tensor]: + """Compute consensus attributions by ensembling all expert interpreters. + + Each expert's ``attribute`` method is called with the same inputs. + The resulting attribution maps are flattened, competitively ranked + to a common [0, 1] scale, and aggregated via the subclass-defined + :meth:`_ensemble` strategy. + + Args: + **kwargs: Input data dictionary from a dataloader batch. + Should contain feature tensors (or tuples of tensors) + keyed by the model's feature keys, plus optional label + or metadata tensors (which are forwarded to experts). + + Returns: + Dictionary mapping each feature key to a consensus attribution + tensor whose shape matches the corresponding input tensor. 
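+
+        Example (``MeanEnsemble`` is a hypothetical subclass shown only for
+        illustration; ``ig``, ``shap``, ``lime`` stand for any three fitted
+        interpreters and ``batch`` is a dataloader batch dict)::
+
+            class MeanEnsemble(BaseInterpreterEnsemble):
+                def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor:
+                    # average the rank-normalized scores across experts
+                    return attributions.mean(dim=1)  # (B, I, M) -> (B, M)
+
+            ensemble = MeanEnsemble(model, experts=[ig, shap, lime])
+            feature_attrs = ensemble.attribute(**batch)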
+ """ + out_shape: dict[str, torch.Size] | None = None + attr_lst: list[torch.Tensor] = [] + for expert in self.experts: + attr = expert.attribute(**kwargs) + + # record the output shape from the first interpreter, + # since all interpreters should produce the same shape + if out_shape is None: + out_shape = {k: v.shape for k, v in attr.items()} + + flat_attr = self._flatten_attributions(attr) # shape (B, M) + attr_lst.append(flat_attr) + + # Combine the flattened attributions from all interpreters + attributions = torch.stack(attr_lst, dim=1) # shape (B, I, M) + # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) + attributions = self._competitive_ranking_normalize(attributions) # shape (B, I, M) + + # Resolve conflicts and aggregate across interpreters using CRH + consensus = self._ensemble(attributions) # shape (B, M) + assert out_shape is not None, "Output shape should have been determined from the first interpreter" + return self._unflatten_attributions(consensus, out_shape) # dict of tensors with original shapes + + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate normalized expert attributions into a single consensus. + + Subclasses must override this method to define the aggregation + strategy (e.g., iterative truth discovery, simple averaging). + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Aggregated tensor of shape ``(B, M)`` with values in [0, 1]. + """ + raise NotImplementedError("Subclasses must implement their ensemble aggregation strategy in this method") + + # ------------------------------------------------------------------ + # Private helper methods + # ------------------------------------------------------------------ + @staticmethod + def _flatten_attributions( + values: dict[str, torch.Tensor], + ) -> torch.Tensor: + """Flatten values dictionary to a single tensor. + + Takes a dictionary of tensors with shape (B, *) and flattens each to (B, M_i), + then concatenates them along the feature dimension to get (B, M). + + Args: + values: Dictionary mapping feature keys to tensors of shape (B, *). + + Returns: + Flattened tensor of shape (B, M) where M is the sum of all flattened dimensions. + """ + flattened_list = [] + for key in sorted(values.keys()): # Sort for consistency + tensor = values[key] + batch_size = tensor.shape[0] + # Flatten all dimensions except batch + flattened = tensor.reshape(batch_size, -1) + flattened_list.append(flattened) + + # Concatenate along feature dimension + return torch.cat(flattened_list, dim=1) + + @staticmethod + def _unflatten_attributions( + flattened: torch.Tensor, + shapes: dict[str, torch.Size], + ) -> dict[str, torch.Tensor]: + """Unflatten tensor back to values dictionary. + + Takes a flattened tensor of shape (B, M) and original shapes, + and reconstructs the original dictionary of tensors. + + Args: + flattened: Flattened tensor of shape (B, M). + shapes: Dictionary mapping feature keys to original tensor shapes. + + Returns: + Dictionary mapping feature keys to tensors with original shapes. 
+ """ + values = {} + offset = 0 + + for key in sorted(shapes.keys()): # Must match the order in _flatten_values + shape = shapes[key] + batch_size = shape[0] + + # Calculate the size of the flattened feature + feature_size = 1 + for s in shape[1:]: + feature_size *= s + + # Extract the relevant portion and reshape + values[key] = flattened[:, offset : offset + feature_size].reshape(shape) + offset += feature_size + + return values + + + @staticmethod + def _competitive_ranking_normalize(x: torch.Tensor) -> torch.Tensor: + """Normalize a tensor via competitive (standard competition) ranking. + + For each (batch, expert) slice, items are ranked ascendingly from + 0 to ``total_item - 1``. Tied scores receive the same rank — the + smallest position index among the tied group (standard competition / + "1224" ranking). The ranks are then divided by ``total_item - 1`` + so that the output lies in [0, 1]. + + Args: + x: Tensor of shape ``(B, I, M)`` + containing unbounded floating-point scores. + + Returns: + Tensor of the same shape with values in [0, 1]. + """ + batch_size, num_experts, num_items = x.shape + + if num_items <= 1: + # With a single item the rank is 0 and 0/0 is undefined; + # return zeros as a safe default. + return torch.zeros_like(x) + + # 1. Sort ascending along the item dimension + sorted_vals, sort_indices = x.sort(dim=-1) + + # 2. Build a mask that is True at positions where the value changes + # from the previous position (i.e. the start of a new rank group). + change_mask = torch.ones(batch_size, num_experts, num_items, dtype=torch.bool, device=x.device) + change_mask[..., 1:] = sorted_vals[..., 1:] != sorted_vals[..., :-1] + + # 3. Assign competitive ranks in sorted order. + # At change positions the rank equals the position index; + # at tie positions we propagate the rank of the first occurrence + # via cummax (all non-change positions are set to -1 so cummax + # naturally carries forward the last "real" rank). + positions = torch.arange(num_items, device=x.device, dtype=torch.long).expand(batch_size, num_experts, num_items) + ranks_sorted = torch.where( + change_mask, + positions, + torch.full_like(positions, -1), + ) + ranks_sorted, _ = ranks_sorted.cummax(dim=-1) + + # 4. Scatter the ranks back to the original (unsorted) order + ranks = torch.zeros_like(x) + ranks.scatter_(-1, sort_indices, ranks_sorted.to(x.dtype)) + + # 5. 
Normalize to [0, 1] + return ranks / (num_items - 1) diff --git a/tests/core/test_interpret_ensemble.py b/tests/core/test_interpret_ensemble.py index 7ec7cc737..43655db39 100644 --- a/tests/core/test_interpret_ensemble.py +++ b/tests/core/test_interpret_ensemble.py @@ -7,89 +7,9 @@ import torch.nn as nn from pyhealth.models import BaseModel -from pyhealth.interpret.methods import Ensemble +from pyhealth.interpret.methods.base_ensemble import BaseInterpreterEnsemble from pyhealth.interpret.methods.base_interpreter import BaseInterpreter - -# --------------------------------------------------------------------------- -# Mock helpers -# --------------------------------------------------------------------------- - -class _MockProcessor: - """Mock feature processor with configurable schema.""" - - def __init__(self, schema_tuple=("value",)): - self._schema = schema_tuple - - def schema(self): - return self._schema - - def is_token(self): - return False - - -class _MockDataset: - """Lightweight stand-in for SampleDataset in unit tests.""" - - def __init__(self, input_schema, output_schema, processors=None): - self.input_schema = input_schema - self.output_schema = output_schema - self.input_processors = processors or { - k: _MockProcessor() for k in input_schema - } - - -# --------------------------------------------------------------------------- -# Test model helpers -# --------------------------------------------------------------------------- - -class _SimpleModel(BaseModel): - """Minimal model for testing Ensemble.""" - - def __init__(self): - dataset = _MockDataset( - input_schema={"x": "tensor"}, - output_schema={"y": "binary"}, - ) - super().__init__(dataset=dataset) - self.linear = nn.Linear(3, 1, bias=True) - - def forward(self, **kwargs) -> dict: - x = kwargs["x"] - if isinstance(x, tuple): - x = x[0] - y = kwargs.get("y", None) - - logit = self.linear(x) - y_prob = torch.sigmoid(logit) - - result = { - "logit": logit, - "y_prob": y_prob, - "loss": torch.zeros((), device=y_prob.device), - } - if y is not None: - result["y_true"] = y.to(y_prob.device) - return result - - def forward_from_embedding(self, **kwargs) -> dict: - return self.forward(**kwargs) - - def get_embedding_model(self): - return None - - -class _DummyInterpreter(BaseInterpreter): - """Dummy interpreter that returns random attributions.""" - - def attribute(self, **kwargs) -> dict: - attributions = {} - for key, value in kwargs.items(): - if isinstance(value, torch.Tensor): - attributions[key] = torch.randn_like(value) - return attributions - - # --------------------------------------------------------------------------- # Test classes # --------------------------------------------------------------------------- @@ -97,25 +17,12 @@ def attribute(self, **kwargs) -> dict: class TestEnsembleFlattenValues(unittest.TestCase): """Tests for the _flatten_values method.""" - def setUp(self): - self.model = _SimpleModel() - self.model.eval() - - # Create 3 dummy interpreters - self.interpreters = [ - _DummyInterpreter(self.model), - _DummyInterpreter(self.model), - _DummyInterpreter(self.model), - ] - - self.ensemble = Ensemble(self.model, self.interpreters) - def test_flatten_single_feature_1d(self): """Test flattening a single 1D feature.""" values = { "feature": torch.randn(2, 3), # (B=2, M=3) } - flattened = self.ensemble._flatten_attributions(values) + flattened = BaseInterpreterEnsemble._flatten_attributions(values) self.assertEqual(flattened.shape, torch.Size([2, 3])) torch.testing.assert_close(flattened, 
values["feature"]) @@ -125,7 +32,7 @@ def test_flatten_single_feature_2d(self): values = { "feature": torch.randn(2, 4, 5), # (B=2, 4, 5) -> (B=2, 20) } - flattened = self.ensemble._flatten_attributions(values) + flattened = BaseInterpreterEnsemble._flatten_attributions(values) self.assertEqual(flattened.shape, torch.Size([2, 20])) expected = values["feature"].reshape(2, -1) @@ -138,7 +45,7 @@ def test_flatten_multiple_features(self): "feature_b": torch.randn(3, 2, 3), # (B=3, 2, 3) -> (B=3, 6) "feature_c": torch.randn(3, 4), # (B=3, 4) -> (B=3, 4) } - flattened = self.ensemble._flatten_attributions(values) + flattened = BaseInterpreterEnsemble._flatten_attributions(values) # Total flattened size = 5 + 6 + 4 = 15 self.assertEqual(flattened.shape, torch.Size([3, 15])) @@ -150,7 +57,7 @@ def test_flatten_preserves_batch_size(self): "feature_a": torch.randn(batch_size, 3), "feature_b": torch.randn(batch_size, 2, 4), } - flattened = self.ensemble._flatten_attributions(values) + flattened = BaseInterpreterEnsemble._flatten_attributions(values) self.assertEqual(flattened.shape[0], batch_size) @@ -162,7 +69,7 @@ def test_flatten_consistency_with_sorted_keys(self): "banana": torch.randn(2, 2), } - flattened = self.ensemble._flatten_attributions(values) + flattened = BaseInterpreterEnsemble._flatten_attributions(values) # Should be ordered alphabetically: apple (4), banana (2), zebra (3) self.assertEqual(flattened.shape, torch.Size([2, 9])) @@ -180,24 +87,12 @@ def test_flatten_consistency_with_sorted_keys(self): class TestEnsembleUnflattenValues(unittest.TestCase): """Tests for the _unflatten_values method.""" - def setUp(self): - self.model = _SimpleModel() - self.model.eval() - - self.interpreters = [ - _DummyInterpreter(self.model), - _DummyInterpreter(self.model), - _DummyInterpreter(self.model), - ] - - self.ensemble = Ensemble(self.model, self.interpreters) - def test_unflatten_single_feature_1d(self): """Test unflattening a single 1D feature.""" shapes = {"feature": torch.Size([2, 3])} flattened = torch.randn(2, 3) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) self.assertIn("feature", unflattened) self.assertEqual(unflattened["feature"].shape, torch.Size([2, 3])) @@ -209,7 +104,7 @@ def test_unflatten_single_feature_2d(self): shapes = {"feature": original_shape} flattened = torch.randn(2, 20) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) self.assertEqual(unflattened["feature"].shape, original_shape) @@ -222,7 +117,7 @@ def test_unflatten_multiple_features(self): } flattened = torch.randn(3, 15) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) self.assertEqual(len(unflattened), 3) self.assertEqual(unflattened["feature_a"].shape, torch.Size([3, 5])) @@ -238,7 +133,7 @@ def test_unflatten_preserves_batch_size(self): } flattened = torch.randn(batch_size, 11) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) self.assertEqual(unflattened["feature_a"].shape[0], batch_size) self.assertEqual(unflattened["feature_b"].shape[0], batch_size) @@ -247,18 +142,6 @@ def test_unflatten_preserves_batch_size(self): class TestEnsembleRoundtrip(unittest.TestCase): """Tests for 
flatten/unflatten roundtrip consistency.""" - def setUp(self): - self.model = _SimpleModel() - self.model.eval() - - self.interpreters = [ - _DummyInterpreter(self.model), - _DummyInterpreter(self.model), - _DummyInterpreter(self.model), - ] - - self.ensemble = Ensemble(self.model, self.interpreters) - def test_roundtrip_single_feature(self): """Test that flatten->unflatten recovers original single feature.""" original = { @@ -266,8 +149,8 @@ def test_roundtrip_single_feature(self): } shapes = {k: v.shape for k, v in original.items()} - flattened = self.ensemble._flatten_attributions(original) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) torch.testing.assert_close(unflattened["feature"], original["feature"]) @@ -280,8 +163,8 @@ def test_roundtrip_multiple_features(self): } shapes = {k: v.shape for k, v in original.items()} - flattened = self.ensemble._flatten_attributions(original) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) for key in original: torch.testing.assert_close(unflattened[key], original[key]) @@ -294,8 +177,8 @@ def test_roundtrip_high_dimensional(self): } shapes = {k: v.shape for k, v in original.items()} - flattened = self.ensemble._flatten_attributions(original) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) for key in original: torch.testing.assert_close(unflattened[key], original[key]) @@ -308,8 +191,8 @@ def test_roundtrip_maintains_device(self): } shapes = {k: v.shape for k, v in original.items()} - flattened = self.ensemble._flatten_attributions(original) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) for key in original: self.assertEqual(unflattened[key].device, original[key].device) @@ -322,8 +205,8 @@ def test_roundtrip_with_gradients(self): } shapes = {k: v.shape for k, v in original.items()} - flattened = self.ensemble._flatten_attributions(original) - unflattened = self.ensemble._unflatten_attributions(flattened, shapes) + flattened = BaseInterpreterEnsemble._flatten_attributions(original) + unflattened = BaseInterpreterEnsemble._unflatten_attributions(flattened, shapes) for key in original: torch.testing.assert_close(unflattened[key].detach(), original[key].detach()) @@ -336,7 +219,7 @@ def test_all_distinct_values(self): """When all values are distinct, ranks should be 0..M-1 (normalized).""" # (B=1, I=1, M=5) with distinct values x = torch.tensor([[[3.0, 1.0, 4.0, 1.5, 2.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) # Sorted ascending: 1.0(idx1)=0, 1.5(idx3)=1, 2.0(idx4)=2, 3.0(idx0)=3, 4.0(idx2)=4 # Normalized by (M-1)=4 @@ -347,13 +230,13 @@ def test_output_shape(self): """Output shape must match input shape.""" for B, I, M in [(2, 3, 5), (1, 1, 10), (4, 2, 7)]: x = torch.randn(B, I, M) - result = Ensemble._competitive_ranking_noramlize(x) + result = 
BaseInterpreterEnsemble._competitive_ranking_normalize(x) self.assertEqual(result.shape, x.shape) def test_output_range(self): """All output values must lie in [0, 1].""" x = torch.randn(4, 3, 20) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) self.assertTrue((result >= 0).all()) self.assertTrue((result <= 1).all()) @@ -361,28 +244,28 @@ def test_tied_scores_get_same_rank(self): """Tied scores must receive the same (minimum) rank — '1224' ranking.""" # [1, 2, 2, 4] -> ranks [0, 1, 1, 3], normalized by 3 x = torch.tensor([[[1.0, 2.0, 2.0, 4.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) torch.testing.assert_close(result, expected) def test_all_tied(self): """When every item has the same score, all ranks should be 0.""" x = torch.tensor([[[5.0, 5.0, 5.0, 5.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) expected = torch.zeros_like(x) torch.testing.assert_close(result, expected) def test_single_item(self): """M=1 edge case: should return zeros.""" x = torch.tensor([[[7.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) expected = torch.zeros_like(x) torch.testing.assert_close(result, expected) def test_two_items_distinct(self): """M=2, distinct values.""" x = torch.tensor([[[3.0, 1.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) # 1.0 -> rank 0, 3.0 -> rank 1; normalized by 1 expected = torch.tensor([[[1.0, 0.0]]]) torch.testing.assert_close(result, expected) @@ -390,14 +273,14 @@ def test_two_items_distinct(self): def test_two_items_tied(self): """M=2, tied values.""" x = torch.tensor([[[3.0, 3.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) expected = torch.zeros_like(x) torch.testing.assert_close(result, expected) def test_multiple_tie_groups(self): """Multiple distinct tie groups: [1,1,3,3,5].""" x = torch.tensor([[[1.0, 1.0, 3.0, 3.0, 5.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) # ranks: [0, 0, 2, 2, 4], normalized by 4 expected = torch.tensor([[[0/4, 0/4, 2/4, 2/4, 4/4]]]) torch.testing.assert_close(result, expected) @@ -411,7 +294,7 @@ def test_batch_and_expert_independence(self): [5.0, 5.0, 5.0]], # batch 1, expert 1 ]) # (B=2, I=2, M=3) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) # batch0, expert0: [3,1,2] -> ranks [2,0,1] / 2 torch.testing.assert_close(result[0, 0], torch.tensor([2/2, 0/2, 1/2])) @@ -425,7 +308,7 @@ def test_batch_and_expert_independence(self): def test_negative_values(self): """Negative values should be ranked correctly.""" x = torch.tensor([[[-3.0, -1.0, -2.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) # Ascending: -3(0), -2(1), -1(2); normalized by 2 expected = torch.tensor([[[0/2, 2/2, 1/2]]]) torch.testing.assert_close(result, expected) @@ -433,14 +316,14 @@ def test_negative_values(self): def test_already_sorted_ascending(self): """Input already sorted ascending.""" x = torch.tensor([[[1.0, 2.0, 
3.0, 4.0, 5.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) expected = torch.tensor([[[0/4, 1/4, 2/4, 3/4, 4/4]]]) torch.testing.assert_close(result, expected) def test_sorted_descending(self): """Input sorted descending.""" x = torch.tensor([[[5.0, 4.0, 3.0, 2.0, 1.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) expected = torch.tensor([[[4/4, 3/4, 2/4, 1/4, 0/4]]]) torch.testing.assert_close(result, expected) @@ -450,7 +333,7 @@ def test_competitive_not_dense_ranking(self): For [10, 20, 20, 40]: competitive ranks are [0,1,1,3], not [0,1,1,2]. """ x = torch.tensor([[[10.0, 20.0, 20.0, 40.0]]]) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) # Competitive: [0, 1, 1, 3] / 3 expected = torch.tensor([[[0/3, 1/3, 1/3, 3/3]]]) torch.testing.assert_close(result, expected) @@ -461,25 +344,25 @@ def test_competitive_not_dense_ranking(self): def test_preserves_device(self): """Output should be on the same device as input.""" x = torch.randn(2, 3, 5) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) self.assertEqual(result.device, x.device) def test_dtype_float32(self): """Works with float32 input.""" x = torch.randn(2, 2, 6, dtype=torch.float32) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) self.assertEqual(result.dtype, torch.float32) def test_dtype_float64(self): """Works with float64 input.""" x = torch.randn(2, 2, 6, dtype=torch.float64) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) self.assertEqual(result.dtype, torch.float64) def test_large_tensor(self): """Smoke test on a larger tensor.""" x = torch.randn(8, 5, 100) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) self.assertEqual(result.shape, x.shape) self.assertTrue((result >= 0).all()) self.assertTrue((result <= 1).all()) @@ -487,7 +370,7 @@ def test_large_tensor(self): def test_max_is_one_min_is_zero_when_all_distinct(self): """When all items are distinct, min rank = 0, max rank = 1.""" x = torch.randn(3, 2, 10) - result = Ensemble._competitive_ranking_noramlize(x) + result = BaseInterpreterEnsemble._competitive_ranking_normalize(x) for b in range(3): for i in range(2): self.assertAlmostEqual(result[b, i].min().item(), 0.0, places=6) From b1691bf99438e44c25aab01d5e9977d815a8d97d Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 12 Feb 2026 04:27:06 -0600 Subject: [PATCH 07/13] Fixup --- pyhealth/interpret/methods/__init__.py | 4 +- pyhealth/interpret/methods/ensemble.py | 254 ------------------------- 2 files changed, 2 insertions(+), 256 deletions(-) delete mode 100644 pyhealth/interpret/methods/ensemble.py diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 9c3004ea6..99e209514 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -8,7 +8,7 @@ from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients from pyhealth.interpret.methods.shap import ShapExplainer from pyhealth.interpret.methods.lime import LimeExplainer -from pyhealth.interpret.methods.ensemble import Ensemble +from 
pyhealth.interpret.methods.crh_ensemble import CrhEnsemble __all__ = [ "BaseInterpreter", @@ -21,5 +21,5 @@ "RandomBaseline", "ShapExplainer", "LimeExplainer", - "Ensemble" + "CrhEnsemble" ] diff --git a/pyhealth/interpret/methods/ensemble.py b/pyhealth/interpret/methods/ensemble.py deleted file mode 100644 index 2579230bd..000000000 --- a/pyhealth/interpret/methods/ensemble.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Random baseline attribution method. - -This module implements a simple random attribution method that assigns -uniformly random importance scores to each input feature. It serves as a -baseline for evaluating the quality of more sophisticated interpretability -methods — any useful attribution technique should outperform random -assignments. -""" - -from __future__ import annotations - -from typing import Dict, Optional - -import torch - -from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter - - -class Ensemble(BaseInterpreter): - def __init__( - self, - model: BaseModel, - experts: list[BaseInterpreter], - n_iter: int = 20, - low_confidence_threshold: float | None = None, - early_stopping_threshold: float | None = None, - ): - super().__init__(model) - assert len(experts) >= 3, "Ensemble must contain at least three interpreters for majority voting" - self.experts = experts - self.n_iter = n_iter - self.low_confidence_threshold = low_confidence_threshold - self.early_stopping_threshold = early_stopping_threshold - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - def attribute( - self, - **kwargs: torch.Tensor | tuple[torch.Tensor, ...], - ) -> Dict[str, torch.Tensor]: - """Compute random attributions for input features. - - Generates random importance scores with the same shape as each - input feature tensor. No gradients or forward passes are needed. - - Args: - **kwargs: Input data dictionary from a dataloader batch. - Should contain feature tensors (or tuples of tensors) - keyed by the model's feature keys, plus optional label - or metadata tensors (which are ignored). - - Returns: - Dictionary mapping each feature key to a random attribution - tensor whose shape matches the raw input values. 
- """ - out_shape: dict[str, torch.Size] | None = None - attr_lst: list[torch.Tensor] = [] - for expert in self.experts: - attr = expert.attribute(**kwargs) - - # record the output shape from the first interpreter, - # since all interpreters should produce the same shape - if out_shape is None: - out_shape = {k: v.shape for k, v in attr.items()} - - flat_attr = self._flatten_attributions(attr) # shape (B, M) - attr_lst.append(flat_attr) - - # Combine the flattened attributions from all interpreters - attributions = torch.stack(attr_lst, dim=1) # shape (B, I, M) - # Normalize the attributions across items for each interpreter (e.g., by competitive ranking) - attributions = self._competitive_ranking_normalize(attributions) # shape (B, I, M) - - # Resolve conflicts and aggregate across interpreters using CRH - consensus = self._conflict_resolution(attributions) # shape (B, M) - assert out_shape is not None, "Output shape should have been determined from the first interpreter" - return self._unflatten_attributions(consensus, out_shape) # dict of tensors with original shapes - - - # ------------------------------------------------------------------ - # Private helper methods - # ------------------------------------------------------------------ - @staticmethod - def _flatten_attributions( - values: dict[str, torch.Tensor], - ) -> torch.Tensor: - """Flatten values dictionary to a single tensor. - - Takes a dictionary of tensors with shape (B, *) and flattens each to (B, M_i), - then concatenates them along the feature dimension to get (B, M). - - Args: - values: Dictionary mapping feature keys to tensors of shape (B, *). - - Returns: - Flattened tensor of shape (B, M) where M is the sum of all flattened dimensions. - """ - flattened_list = [] - for key in sorted(values.keys()): # Sort for consistency - tensor = values[key] - batch_size = tensor.shape[0] - # Flatten all dimensions except batch - flattened = tensor.reshape(batch_size, -1) - flattened_list.append(flattened) - - # Concatenate along feature dimension - return torch.cat(flattened_list, dim=1) - - @staticmethod - def _unflatten_attributions( - flattened: torch.Tensor, - shapes: dict[str, torch.Size], - ) -> dict[str, torch.Tensor]: - """Unflatten tensor back to values dictionary. - - Takes a flattened tensor of shape (B, M) and original shapes, - and reconstructs the original dictionary of tensors. - - Args: - flattened: Flattened tensor of shape (B, M). - shapes: Dictionary mapping feature keys to original tensor shapes. - - Returns: - Dictionary mapping feature keys to tensors with original shapes. - """ - values = {} - offset = 0 - - for key in sorted(shapes.keys()): # Must match the order in _flatten_values - shape = shapes[key] - batch_size = shape[0] - - # Calculate the size of the flattened feature - feature_size = 1 - for s in shape[1:]: - feature_size *= s - - # Extract the relevant portion and reshape - values[key] = flattened[:, offset : offset + feature_size].reshape(shape) - offset += feature_size - - return values - - - @staticmethod - def _competitive_ranking_normalize(x: torch.Tensor) -> torch.Tensor: - """Normalize a tensor via competitive (standard competition) ranking. - - For each (batch, expert) slice, items are ranked ascendingly from - 0 to ``total_item - 1``. Tied scores receive the same rank — the - smallest position index among the tied group (standard competition / - "1224" ranking). The ranks are then divided by ``total_item - 1`` - so that the output lies in [0, 1]. 
- - Args: - x: Tensor of shape ``(B, I, M)`` - containing unbounded floating-point scores. - - Returns: - Tensor of the same shape with values in [0, 1]. - """ - batch_size, num_experts, num_items = x.shape - - if num_items <= 1: - # With a single item the rank is 0 and 0/0 is undefined; - # return zeros as a safe default. - return torch.zeros_like(x) - - # 1. Sort ascending along the item dimension - sorted_vals, sort_indices = x.sort(dim=-1) - - # 2. Build a mask that is True at positions where the value changes - # from the previous position (i.e. the start of a new rank group). - change_mask = torch.ones(batch_size, num_experts, num_items, dtype=torch.bool, device=x.device) - change_mask[..., 1:] = sorted_vals[..., 1:] != sorted_vals[..., :-1] - - # 3. Assign competitive ranks in sorted order. - # At change positions the rank equals the position index; - # at tie positions we propagate the rank of the first occurrence - # via cummax (all non-change positions are set to -1 so cummax - # naturally carries forward the last "real" rank). - positions = torch.arange(num_items, device=x.device, dtype=torch.long).expand(batch_size, num_experts, num_items) - ranks_sorted = torch.where( - change_mask, - positions, - torch.full_like(positions, -1), - ) - ranks_sorted, _ = ranks_sorted.cummax(dim=-1) - - # 4. Scatter the ranks back to the original (unsorted) order - ranks = torch.zeros_like(x) - ranks.scatter_(-1, sort_indices, ranks_sorted.to(x.dtype)) - - # 5. Normalize to [0, 1] - return ranks / (num_items - 1) - - def _conflict_resolution(self, attributions: torch.Tensor) -> torch.Tensor: - """Truth discovery using CRH algorithm. This try to estimate the true importance scores - by iteratively reweighting the experts based on their agreement with the current consensus. - - Args: - attributions: A normalized attribution tensor of shape (B, I, M) - where values are in [0, 1]. - - Returns: - Tensor of shape (B, M) in [0, 1]. 
- """ - # Step 1: Initialize truth as median across experts (B, M) - t = torch.median(attributions, dim=1).values # (B, M) - - # Iterative refinement - eps = 1e-6 - - for _ in range(self.n_iter): - t_old = t.clone() - - # Step 2: Compute expert reliability per batch - # errors: (B, I) - mean squared error per expert per batch - errors = torch.mean((attributions - t.unsqueeze(1)) ** 2, dim=2) # (B, I) - - # weights: (B, I) - w = 1.0 / (eps + errors) # (B, I) - w = w / w.sum(dim=1, keepdim=True) # normalize per batch - - # Step 3: Update truth as weighted average - # t: (B, M) = sum over experts of w * attributions - t = torch.sum(w.unsqueeze(2) * attributions, dim=1) # (B, M) - - # Early stopping: check convergence per batch - if self.early_stopping_threshold is not None: - if torch.allclose(t, t_old, atol=self.early_stopping_threshold): - break - - if self.low_confidence_threshold is None: - # If no low confidence threshold is set, just return the CRH result - return t - - # Detect low-confidence batches where all experts are equally weighted - # If std(w) is very low, it means no expert is clearly better - w_std = torch.std(w, dim=1) # type: ignore[assignment] (B,) - low_confidence = w_std < self.low_confidence_threshold # (B,) - - # For low-confidence batches, fall back to uniform weighting (mean) - if low_confidence.any(): - uniform_consensus = torch.mean(attributions, dim=1) # (B, M) - t = torch.where( - low_confidence.unsqueeze(1), # (B, 1) - uniform_consensus, - t - ) - - return t \ No newline at end of file From 350e8ba563f38aa9022f0d430d64567aaf9c9c82 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 12 Feb 2026 04:31:37 -0600 Subject: [PATCH 08/13] Add CRH ensemble --- pyhealth/interpret/methods/__init__.py | 4 +- pyhealth/interpret/methods/crh_ensemble.py | 134 +++++++++++++++++++++ 2 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 pyhealth/interpret/methods/crh_ensemble.py diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 99e209514..88fe90c52 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -8,7 +8,7 @@ from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients from pyhealth.interpret.methods.shap import ShapExplainer from pyhealth.interpret.methods.lime import LimeExplainer -from pyhealth.interpret.methods.crh_ensemble import CrhEnsemble +from pyhealth.interpret.methods.crh_ensemble import CrhInterpreterEnsemble __all__ = [ "BaseInterpreter", @@ -21,5 +21,5 @@ "RandomBaseline", "ShapExplainer", "LimeExplainer", - "CrhEnsemble" + "CrhInterpreterEnsemble" ] diff --git a/pyhealth/interpret/methods/crh_ensemble.py b/pyhealth/interpret/methods/crh_ensemble.py new file mode 100644 index 000000000..5c9a3ace6 --- /dev/null +++ b/pyhealth/interpret/methods/crh_ensemble.py @@ -0,0 +1,134 @@ +"""CRH ensemble interpreter. + +This module implements the Conflict Resolution on Heterogeneous data (CRH) +ensemble strategy for aggregating attributions from multiple interpretability +experts into a single consensus attribution. +""" + +from __future__ import annotations + +import torch + +from pyhealth.models import BaseModel +from .base_ensemble import BaseInterpreterEnsemble +from .base_interpreter import BaseInterpreter + + +class CrhInterpreterEnsemble(BaseInterpreterEnsemble): + """Ensemble interpreter using Conflict Resolution on Heterogeneous data (CRH). 
+ + Iteratively estimates a consensus attribution by reweighting experts + according to their agreement with the current consensus estimate. + Experts whose attributions are closer to the consensus receive higher + weights, which in turn pulls the consensus toward more reliable + experts. + + This implements the truth-discovery algorithm from: + + Li, Q., Li, Y., Gao, J., Zhao, B., Fan, W., and Han, J. + "Resolving Conflicts in Heterogeneous Data by Truth Discovery + and Source Reliability Estimation." In *Proceedings of the 2014 + ACM SIGMOD International Conference on Management of Data* + (SIGMOD'14), pp. 1187–1198, 2014. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. + n_iter: Maximum number of CRH refinement iterations. Higher values + allow more precise convergence at the cost of computation. + low_confidence_threshold: If set, batches where the standard + deviation of the final expert weights falls below this value + are considered low-confidence, and their consensus is replaced + by a simple uniform average of all experts. + early_stopping_threshold: If set, the CRH loop terminates early + when the maximum absolute change in the consensus vector + between successive iterations is below this value. + + Example: + >>> from pyhealth.interpret.methods import GradientShap, IntegratedGradients, Saliency + >>> experts = [GradientShap(model), IntegratedGradients(model), Saliency(model)] + >>> ensemble = CrhInterpreterEnsemble(model, experts, n_iter=30) + >>> attrs = ensemble.attribute(**batch) + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + n_iter: int = 20, + low_confidence_threshold: float | None = None, + early_stopping_threshold: float | None = None, + ): + super().__init__(model, experts) + self.n_iter = n_iter + self.low_confidence_threshold = low_confidence_threshold + self.early_stopping_threshold = early_stopping_threshold + + # ------------------------------------------------------------------ + # Ensemble implementation + # ------------------------------------------------------------------ + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate expert attributions via the CRH truth-discovery algorithm. + + The consensus is initialised as the median across experts and then + iteratively refined: on each iteration, expert weights are set + inversely proportional to their mean squared error against the + current consensus, and the consensus is updated as the weighted + average of all experts. + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Consensus tensor of shape ``(B, M)`` with values in [0, 1]. 
+ """ + # Step 1: Initialize truth as median across experts (B, M) + t = torch.median(attributions, dim=1).values # (B, M) + + # Iterative refinement + eps = 1e-6 + + for _ in range(self.n_iter): + t_old = t.clone() + + # Step 2: Compute expert reliability per batch + # errors: (B, I) - mean squared error per expert per batch + errors = torch.mean((attributions - t.unsqueeze(1)) ** 2, dim=2) # (B, I) + + # weights: (B, I) + w = 1.0 / (eps + errors) # (B, I) + w = w / w.sum(dim=1, keepdim=True) # normalize per batch + + # Step 3: Update truth as weighted average + # t: (B, M) = sum over experts of w * attributions + t = torch.sum(w.unsqueeze(2) * attributions, dim=1) # (B, M) + + # Early stopping: check convergence per batch + if self.early_stopping_threshold is not None: + if torch.allclose(t, t_old, atol=self.early_stopping_threshold): + break + + if self.low_confidence_threshold is None: + # If no low confidence threshold is set, just return the CRH result + return t + + # Detect low-confidence batches where all experts are equally weighted + # If std(w) is very low, it means no expert is clearly better + w_std = torch.std(w, dim=1) # type: ignore[assignment] (B,) + low_confidence = w_std < self.low_confidence_threshold # (B,) + + # For low-confidence batches, fall back to uniform weighting (mean) + if low_confidence.any(): + uniform_consensus = torch.mean(attributions, dim=1) # (B, M) + t = torch.where( + low_confidence.unsqueeze(1), # (B, 1) + uniform_consensus, + t + ) + + return t \ No newline at end of file From bd5e648c3c17c9c9c01bd6a938e06123be40b958 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 12 Feb 2026 04:35:30 -0600 Subject: [PATCH 09/13] Add AGGMean --- pyhealth/interpret/methods/__init__.py | 4 +- pyhealth/interpret/methods/avg_ensemble.py | 66 ++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 pyhealth/interpret/methods/avg_ensemble.py diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 88fe90c52..9b688e48b 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -9,6 +9,7 @@ from pyhealth.interpret.methods.shap import ShapExplainer from pyhealth.interpret.methods.lime import LimeExplainer from pyhealth.interpret.methods.crh_ensemble import CrhInterpreterEnsemble +from pyhealth.interpret.methods.avg_ensemble import AvgInterpreterEnsemble __all__ = [ "BaseInterpreter", @@ -21,5 +22,6 @@ "RandomBaseline", "ShapExplainer", "LimeExplainer", - "CrhInterpreterEnsemble" + "CrhInterpreterEnsemble", + "AvgInterpreterEnsemble" ] diff --git a/pyhealth/interpret/methods/avg_ensemble.py b/pyhealth/interpret/methods/avg_ensemble.py new file mode 100644 index 000000000..2104c5b66 --- /dev/null +++ b/pyhealth/interpret/methods/avg_ensemble.py @@ -0,0 +1,66 @@ +"""Average ensemble interpreter. + +This module implements the AGGMean ensemble strategy, which aggregates +attributions from multiple interpretability experts by taking the uniform +average of their competitively-ranked importance scores. +""" + +from __future__ import annotations + +import torch + +from pyhealth.models import BaseModel +from .base_ensemble import BaseInterpreterEnsemble +from .base_interpreter import BaseInterpreter + + +class AvgInterpreterEnsemble(BaseInterpreterEnsemble): + """Ensemble interpreter using uniform averaging (AGGMean). + + Computes the consensus attribution as the simple arithmetic mean + of the competitively-ranked attributions from all expert interpreters. 
+ This is the simplest ensemble strategy — every expert contributes + equally regardless of its agreement with the others. + + Implements the AGGMean method from: + + Rieger, L. and Hansen, L. K. "Aggregating Explanation Methods + for Stable and Robust Explainability." arXiv preprint + arXiv:1903.00519, 2019. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. + + Example: + >>> from pyhealth.interpret.methods import IntegratedGradients, DeepLift, LimeExplainer + >>> experts = [IntegratedGradients(model), DeepLift(model), LimeExplainer(model)] + >>> ensemble = AvgInterpreterEnsemble(model, experts) + >>> attrs = ensemble.attribute(**batch) + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + ): + super().__init__(model, experts) + + # ------------------------------------------------------------------ + # Ensemble implementation + # ------------------------------------------------------------------ + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate expert attributions by uniform averaging. + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Consensus tensor of shape ``(B, M)`` with values in [0, 1]. + """ + return torch.mean(attributions, dim=1) \ No newline at end of file From 85f0e878b31204b94575002ac00e27b02aed04cf Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 12 Feb 2026 04:36:34 -0600 Subject: [PATCH 10/13] rename --- pyhealth/interpret/methods/{avg_ensemble.py => ensemble_avg.py} | 0 pyhealth/interpret/methods/{crh_ensemble.py => ensemble_crh.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename pyhealth/interpret/methods/{avg_ensemble.py => ensemble_avg.py} (100%) rename pyhealth/interpret/methods/{crh_ensemble.py => ensemble_crh.py} (100%) diff --git a/pyhealth/interpret/methods/avg_ensemble.py b/pyhealth/interpret/methods/ensemble_avg.py similarity index 100% rename from pyhealth/interpret/methods/avg_ensemble.py rename to pyhealth/interpret/methods/ensemble_avg.py diff --git a/pyhealth/interpret/methods/crh_ensemble.py b/pyhealth/interpret/methods/ensemble_crh.py similarity index 100% rename from pyhealth/interpret/methods/crh_ensemble.py rename to pyhealth/interpret/methods/ensemble_crh.py From 104090b49e2c1147e90a6cf968a7e4024c440c92 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 12 Feb 2026 04:39:57 -0600 Subject: [PATCH 11/13] add AGGVar method --- pyhealth/interpret/methods/__init__.py | 8 ++- pyhealth/interpret/methods/ensemble_var.py | 73 ++++++++++++++++++++++ 2 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 pyhealth/interpret/methods/ensemble_var.py diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 9b688e48b..dc08fcb84 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -8,8 +8,9 @@ from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients from pyhealth.interpret.methods.shap import ShapExplainer from pyhealth.interpret.methods.lime import LimeExplainer -from pyhealth.interpret.methods.crh_ensemble import CrhInterpreterEnsemble -from pyhealth.interpret.methods.avg_ensemble import AvgInterpreterEnsemble +from 
pyhealth.interpret.methods.ensemble_crh import CrhInterpreterEnsemble +from pyhealth.interpret.methods.ensemble_avg import AvgInterpreterEnsemble +from pyhealth.interpret.methods.ensemble_var import VarInterpreterEnsemble __all__ = [ "BaseInterpreter", @@ -23,5 +24,6 @@ "ShapExplainer", "LimeExplainer", "CrhInterpreterEnsemble", - "AvgInterpreterEnsemble" + "AvgInterpreterEnsemble", + "VarInterpreterEnsemble" ] diff --git a/pyhealth/interpret/methods/ensemble_var.py b/pyhealth/interpret/methods/ensemble_var.py new file mode 100644 index 000000000..355735873 --- /dev/null +++ b/pyhealth/interpret/methods/ensemble_var.py @@ -0,0 +1,73 @@ +"""Variance-weighted ensemble interpreter. + +This module implements the AGGVar ensemble strategy, which aggregates +attributions from multiple interpretability experts by dividing the mean +attribution by the standard deviation, penalising features where experts +disagree. +""" + +from __future__ import annotations + +import torch + +from pyhealth.models import BaseModel +from .base_ensemble import BaseInterpreterEnsemble +from .base_interpreter import BaseInterpreter + + +class VarInterpreterEnsemble(BaseInterpreterEnsemble): + """Ensemble interpreter using variance-weighted averaging (AGGVar). + + Computes the consensus attribution by dividing the mean of the + competitively-ranked expert attributions by their standard deviation + (plus a small constant ε for numerical stability). Features that + all experts agree are important receive high scores, while features + with high inter-expert disagreement are suppressed. + + Implements the AGGVar method from: + + Rieger, L. and Hansen, L. K. "Aggregating Explanation Methods + for Stable and Robust Explainability." arXiv preprint + arXiv:1903.00519, 2019. + + Args: + model: The PyHealth model to interpret. + experts: A list of at least three :class:`BaseInterpreter` instances + whose ``attribute`` methods will be called to produce individual + attribution maps. + + Example: + >>> from pyhealth.interpret.methods import IntegratedGradients, DeepLift, LimeExplainer + >>> experts = [IntegratedGradients(model), DeepLift(model), LimeExplainer(model)] + >>> ensemble = VarInterpreterEnsemble(model, experts) + >>> attrs = ensemble.attribute(**batch) + """ + + def __init__( + self, + model: BaseModel, + experts: list[BaseInterpreter], + ): + super().__init__(model, experts) + + # ------------------------------------------------------------------ + # Ensemble implementation + # ------------------------------------------------------------------ + def _ensemble(self, attributions: torch.Tensor) -> torch.Tensor: + """Aggregate expert attributions via variance-weighted averaging. + + Computes ``mean / (std + ε)`` across experts for each feature, + rewarding consensus and penalising disagreement. + + Args: + attributions: Normalized attribution tensor of shape + ``(B, I, M)`` with values in [0, 1], where *B* is the + batch size, *I* is the number of experts, and *M* is the + total number of flattened features. + + Returns: + Consensus tensor of shape ``(B, M)``. 
+ """ + mean = torch.mean(attributions, dim=1) # (B, M) + std = torch.std(attributions, dim=1) # (B, M) + return mean / (std + 1e-6) \ No newline at end of file From c76a3d346dc02611703a50a7b3cb89b4c3567e80 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 12 Feb 2026 04:43:42 -0600 Subject: [PATCH 12/13] Update docs --- pyhealth/interpret/methods/ensemble_avg.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pyhealth/interpret/methods/ensemble_avg.py b/pyhealth/interpret/methods/ensemble_avg.py index 2104c5b66..99abf50b6 100644 --- a/pyhealth/interpret/methods/ensemble_avg.py +++ b/pyhealth/interpret/methods/ensemble_avg.py @@ -15,19 +15,30 @@ class AvgInterpreterEnsemble(BaseInterpreterEnsemble): - """Ensemble interpreter using uniform averaging (AGGMean). + """Ensemble interpreter using uniform averaging (AGGMean / Borda). Computes the consensus attribution as the simple arithmetic mean of the competitively-ranked attributions from all expert interpreters. This is the simplest ensemble strategy — every expert contributes equally regardless of its agreement with the others. + Because the inputs are already competitively ranked, averaging is + equivalent (up to a constant factor) to the Borda count, which sums + the ranks instead. The two methods therefore produce identical + feature orderings. + Implements the AGGMean method from: Rieger, L. and Hansen, L. K. "Aggregating Explanation Methods for Stable and Robust Explainability." arXiv preprint arXiv:1903.00519, 2019. + See also the Borda aggregation in: + + Chen, Y., Mancini, M., Zhu, X., and Akata, Z. "Ensemble + Interpretation: A Unified Method for Interpretable Machine + Learning." arXiv preprint arXiv:2312.06255, 2023. + Args: model: The PyHealth model to interpret. 
experts: A list of at least three :class:`BaseInterpreter` instances From 1e5bb5c5a5828ab31accbc653f84530477dd8217 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 14 Feb 2026 18:02:05 -0600 Subject: [PATCH 13/13] rename classes --- pyhealth/interpret/methods/__init__.py | 12 ++++++------ pyhealth/interpret/methods/ensemble_avg.py | 4 ++-- pyhealth/interpret/methods/ensemble_crh.py | 4 ++-- pyhealth/interpret/methods/ensemble_var.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index dc08fcb84..6c92cb6e4 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -8,9 +8,9 @@ from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients from pyhealth.interpret.methods.shap import ShapExplainer from pyhealth.interpret.methods.lime import LimeExplainer -from pyhealth.interpret.methods.ensemble_crh import CrhInterpreterEnsemble -from pyhealth.interpret.methods.ensemble_avg import AvgInterpreterEnsemble -from pyhealth.interpret.methods.ensemble_var import VarInterpreterEnsemble +from pyhealth.interpret.methods.ensemble_crh import CrhEnsemble +from pyhealth.interpret.methods.ensemble_avg import AvgEnsemble +from pyhealth.interpret.methods.ensemble_var import VarEnsemble __all__ = [ "BaseInterpreter", @@ -23,7 +23,7 @@ "RandomBaseline", "ShapExplainer", "LimeExplainer", - "CrhInterpreterEnsemble", - "AvgInterpreterEnsemble", - "VarInterpreterEnsemble" + "CrhEnsemble", + "AvgEnsemble", + "VarEnsemble" ] diff --git a/pyhealth/interpret/methods/ensemble_avg.py b/pyhealth/interpret/methods/ensemble_avg.py index 99abf50b6..7f66f14eb 100644 --- a/pyhealth/interpret/methods/ensemble_avg.py +++ b/pyhealth/interpret/methods/ensemble_avg.py @@ -14,7 +14,7 @@ from .base_interpreter import BaseInterpreter -class AvgInterpreterEnsemble(BaseInterpreterEnsemble): +class AvgEnsemble(BaseInterpreterEnsemble): """Ensemble interpreter using uniform averaging (AGGMean / Borda). Computes the consensus attribution as the simple arithmetic mean @@ -48,7 +48,7 @@ class AvgInterpreterEnsemble(BaseInterpreterEnsemble): Example: >>> from pyhealth.interpret.methods import IntegratedGradients, DeepLift, LimeExplainer >>> experts = [IntegratedGradients(model), DeepLift(model), LimeExplainer(model)] - >>> ensemble = AvgInterpreterEnsemble(model, experts) + >>> ensemble = AvgEnsemble(model, experts) >>> attrs = ensemble.attribute(**batch) """ diff --git a/pyhealth/interpret/methods/ensemble_crh.py b/pyhealth/interpret/methods/ensemble_crh.py index 5c9a3ace6..2a901eb9b 100644 --- a/pyhealth/interpret/methods/ensemble_crh.py +++ b/pyhealth/interpret/methods/ensemble_crh.py @@ -14,7 +14,7 @@ from .base_interpreter import BaseInterpreter -class CrhInterpreterEnsemble(BaseInterpreterEnsemble): +class CrhEnsemble(BaseInterpreterEnsemble): """Ensemble interpreter using Conflict Resolution on Heterogeneous data (CRH). 
Iteratively estimates a consensus attribution by reweighting experts @@ -49,7 +49,7 @@ class CrhInterpreterEnsemble(BaseInterpreterEnsemble): Example: >>> from pyhealth.interpret.methods import GradientShap, IntegratedGradients, Saliency >>> experts = [GradientShap(model), IntegratedGradients(model), Saliency(model)] - >>> ensemble = CrhInterpreterEnsemble(model, experts, n_iter=30) + >>> ensemble = CrhEnsemble(model, experts, n_iter=30) >>> attrs = ensemble.attribute(**batch) """ diff --git a/pyhealth/interpret/methods/ensemble_var.py b/pyhealth/interpret/methods/ensemble_var.py index 355735873..629adf26c 100644 --- a/pyhealth/interpret/methods/ensemble_var.py +++ b/pyhealth/interpret/methods/ensemble_var.py @@ -15,7 +15,7 @@ from .base_interpreter import BaseInterpreter -class VarInterpreterEnsemble(BaseInterpreterEnsemble): +class VarEnsemble(BaseInterpreterEnsemble): """Ensemble interpreter using variance-weighted averaging (AGGVar). Computes the consensus attribution by dividing the mean of the @@ -39,7 +39,7 @@ class VarInterpreterEnsemble(BaseInterpreterEnsemble): Example: >>> from pyhealth.interpret.methods import IntegratedGradients, DeepLift, LimeExplainer >>> experts = [IntegratedGradients(model), DeepLift(model), LimeExplainer(model)] - >>> ensemble = VarInterpreterEnsemble(model, experts) + >>> ensemble = VarEnsemble(model, experts) >>> attrs = ensemble.attribute(**batch) """