diff --git a/linopy/common.py b/linopy/common.py index 279bb5c6..0823deac 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -287,6 +287,16 @@ def as_dataarray( def broadcast_mask(mask: DataArray, labels: DataArray) -> DataArray: + """ + Broadcast a boolean mask to match the shape of labels. + + Ensures that mask dimensions are a subset of labels dimensions, broadcasts + the mask accordingly, and fills any NaN values (from missing coordinates) + with False while emitting a FutureWarning. + """ + assert set(mask.dims).issubset(labels.dims), ( + "Dimensions of mask not a subset of resulting labels dimensions." + ) mask = mask.broadcast_like(labels) if mask.isnull().any(): warn( diff --git a/linopy/model.py b/linopy/model.py index a58e09e5..d5d4830a 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -552,9 +552,6 @@ def add_variables( if mask is not None: mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool) - assert set(mask.dims).issubset(data.dims), ( - "Dimensions of mask not a subset of resulting labels dimensions." - ) mask = broadcast_mask(mask, data.labels) # Auto-mask based on NaN in bounds (use numpy for speed) @@ -750,9 +747,6 @@ def add_constraints( if mask is not None: mask = as_dataarray(mask).astype(bool) - assert set(mask.dims).issubset(data.dims), ( - "Dimensions of mask not a subset of resulting labels dimensions." - ) mask = broadcast_mask(mask, data.labels) # Auto-mask based on null expressions or NaN RHS (use numpy for speed) diff --git a/test/test_constraints.py b/test/test_constraints.py index 0f0ae35a..01aebb69 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -191,6 +191,11 @@ def test_masked_constraints_broadcast() -> None: assert (m.constraints.labels.bc3[2:5, :] == -1).all() assert (m.constraints.labels.bc3[5:10, :] == -1).all() + # Mask with extra dimension not in data should raise + mask4 = xr.DataArray([True, False], dims=["extra_dim"]) + with pytest.raises(AssertionError, match="not a subset"): + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc4", mask=mask4) + def test_non_aligned_constraints() -> None: m: Model = Model() diff --git a/test/test_variables.py b/test/test_variables.py index 824ec77a..37de6aff 100644 --- a/test/test_variables.py +++ b/test/test_variables.py @@ -134,6 +134,11 @@ def test_variables_mask_broadcast() -> None: assert (z.labels[2:5, :] == -1).all() assert (z.labels[5:10, :] == -1).all() + # Mask with extra dimension not in data should raise + mask4 = xr.DataArray([True, False], dims=["extra_dim"]) + with pytest.raises(AssertionError, match="not a subset"): + m.add_variables(lower, upper, name="w", mask=mask4) + def test_variables_get_name_by_label(m: Model) -> None: assert m.variables.get_name_by_label(4) == "x"