Draft
Changes from all commits
46 commits
388a104
Adding the comm file from Makani and making the necessary changes to …
odiazib Oct 16, 2025
a72ef22
Implement a split of the dataset for spatial parallelism and create a…
odiazib Oct 16, 2025
a3e42cc
Adding the necessary files from Makani for spatial parallelism.
odiazib Oct 16, 2025
81e9e60
Adding the necessary files from Makani for spatial parallelism.
odiazib Oct 16, 2025
8f6e71a
Adding spatial parallelism to the model, layers, and FFT. Testing thi…
odiazib Oct 17, 2025
0ad9ffc
Adding NVIDIA PhysicsNemo
odiazib Oct 17, 2025
352f6fe
Move the block of code from conf to the xarray class initialization.
odiazib Oct 20, 2025
ac25394
Reintroduce the logic to run the case in serial and on the CPU.
odiazib Oct 21, 2025
30ffac8
Split domain for spatial parallelism.
odiazib Oct 22, 2025
3273809
Moving code to distributed class.
odiazib Oct 23, 2025
eac6d17
Fixing the xarray test with spatial parallelism.
odiazib Oct 23, 2025
1169228
Spatial distributed model test is working, but tolerance error is 1e-3.
odiazib Oct 28, 2025
ff60738
Adding init_gradient_reduction_hooks to the model and testing trainin…
odiazib Oct 29, 2025
aeb1d1e
Fixing the dataset reader to make training work for the E3SM case.
odiazib Oct 29, 2025
50f1bfb
Getting spatial parallelism input parameters from the CLI.
odiazib Nov 3, 2025
bbf615e
Saving and loading checkpoints when a model uses spatial parallelism.
odiazib Nov 3, 2025
4aa6cd6
Moving init_gradient_reduction_hooks to the Distributed class.
odiazib Nov 3, 2025
ee2b865
Fix initialization and checkpoint handling in distribute class
odiazib Nov 5, 2025
a8e8e31
Only use comm.cleanup() if spatial parallelism is enabled.
odiazib Nov 5, 2025
1402149
Removing old code.
odiazib Nov 5, 2025
8ea20c6
Adding recommendations for PR review and deleting old code.
odiazib Nov 6, 2025
d639dca
Adding recommendations for PR review
odiazib Nov 6, 2025
f48d1e6
Apply suggestion from @mcgibbon
odiazib Nov 6, 2025
3ae7bda
Removing 'comm' and moving the routine to the distribute class.
odiazib Nov 6, 2025
2033b03
Moving logic to distribute class.
odiazib Nov 6, 2025
1b82ca0
Removing 'comm' from the model implementation. This change makes the …
odiazib Nov 7, 2025
31e1835
Removing old code.
odiazib Nov 7, 2025
61425d3
Cleaning up code based on PR review.
odiazib Nov 7, 2025
4e1fc4a
Build the NeMo version of SFNO if spatial parallelism is on.
odiazib Nov 7, 2025
9b10095
Adding review recommendations.
odiazib Nov 7, 2025
44a601a
Fixing unit test for xarray sp
odiazib Nov 7, 2025
550572a
Fixing test for sfnonet with spatial parallelism.
odiazib Nov 8, 2025
07ed091
routine to create a directory.
odiazib Nov 10, 2025
d84b438
Saving test
odiazib Nov 10, 2025
8c8b5e6
unit test for loss function.
odiazib Nov 11, 2025
7e43b44
The ERA5 dataset uses latitude and longitude instead of lon and lat. …
odiazib Nov 17, 2025
2ef0082
Adding code back.
odiazib Nov 18, 2025
79869d4
Unit test for coordinates and annual aggregator. The spatial parallel…
odiazib Nov 19, 2025
7d43bbc
Fixing the error in batch size, we must use 'batch' and 'comm' to get…
odiazib Nov 19, 2025
44ba51f
Split data after it is read.
odiazib Nov 25, 2025
d714fbb
Using more than 1 sample.
odiazib Nov 25, 2025
ea4ee0d
Gather tensors in a snapshot so that the plots display the whole domain.
odiazib Nov 25, 2025
e0c7a7a
updates for unit tests.
odiazib Nov 25, 2025
664840b
tmp cleanup/modularization
mahf708 Nov 26, 2025
5257831
Cleaning up.
odiazib Dec 1, 2025
aad61d6
save unit test.
odiazib Dec 10, 2025
3 changes: 3 additions & 0 deletions fme/ace/aggregator/inference/enso/dynamic_index.py
@@ -48,6 +48,9 @@ def __post_init__(self):
torch.logical_and(lat_mask, lon_mask), 1.0, 0.0
)

dist = Distributed.get_instance()
self._regional_weights = self._regional_weights[*dist.get_local_slices(self._regional_weights.shape)]

@property
def regional_weights(self) -> torch.Tensor:
return self._regional_weights
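Note on the cropping above: it assumes Distributed.get_local_slices returns the (lat, lon) slices owned by the current spatial rank. The sketch below illustrates that assumed block decomposition for a 2x2 spatial layout; the rank-to-block mapping is hypothetical and stands in for the real Distributed implementation.

    import torch

    def local_slices_sketch(shape, h_rank, w_rank, h_size=2, w_size=2):
        # Assumed behavior: split the last two dims into contiguous blocks,
        # one block per (h_rank, w_rank) position in the spatial grid.
        lat, lon = shape[-2], shape[-1]
        h_step, w_step = lat // h_size, lon // w_size
        return (
            slice(h_rank * h_step, (h_rank + 1) * h_step),
            slice(w_rank * w_step, (w_rank + 1) * w_step),
        )

    weights = torch.ones(8, 16)
    sh, sw = local_slices_sketch(weights.shape, h_rank=1, w_rank=0)
    local_weights = weights[sh, sw]  # (4, 8) block held by this hypothetical rank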
12 changes: 12 additions & 0 deletions fme/ace/aggregator/inference/main.py
Expand Up @@ -35,6 +35,7 @@
from .time_mean import TimeMeanAggregator, TimeMeanEvaluatorAggregator
from .video import VideoAggregator
from .zonal_mean import ZonalMeanAggregator
from fme.core.distributed import Distributed

wandb = WandB.get_instance()
APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=730)
@@ -157,12 +158,23 @@ def build(
monthly_reference_data = xr.open_dataset(
self.monthly_reference_data, decode_timedelta=False
)
dist = Distributed.get_instance()
# CHECK: Is there another way to get lat_length and lon_length?
# Should we move this splitting operation inside the InferenceEvaluatorAggregator?
lat_length = len(monthly_reference_data.coords['lat'])
lon_length = len(monthly_reference_data.coords['lon'])
crop_shape = (lat_length, lon_length)
slice_h, slice_w = dist.get_local_slices(crop_shape)
monthly_reference_data = monthly_reference_data.isel(lat=slice_h, lon=slice_w)

if self.time_mean_reference_data is None:
time_mean = None
else:
time_mean = xr.open_dataset(
self.time_mean_reference_data, decode_timedelta=False
)


return InferenceEvaluatorAggregator(
dataset_info=dataset_info,
n_timesteps=n_timesteps,
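On the CHECK above: one possible alternative, offered here only as an untested suggestion, is to read the lengths from the dataset's sizes mapping rather than calling len() on the coordinate variables:

    # Suggested variant of the block above; monthly_reference_data and dist
    # are the objects already defined in build().
    crop_shape = (
        monthly_reference_data.sizes["lat"],
        monthly_reference_data.sizes["lon"],
    )
    slice_h, slice_w = dist.get_local_slices(crop_shape)
    monthly_reference_data = monthly_reference_data.isel(lat=slice_h, lon=slice_w)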
21 changes: 10 additions & 11 deletions fme/ace/aggregator/one_step/snapshot.py
@@ -9,7 +9,7 @@

from ..plotting import plot_paneled_data


from fme.core.distributed import Distributed
class SnapshotAggregator:
"""
An aggregator that records the first sample of the last batch of data.
@@ -65,23 +65,22 @@ def _get_data(self) -> tuple[TensorMapping, TensorMapping, TensorMapping]:
input_time = 0
target_time = 1
gen, target, input = {}, {}, {}
+dist = Distributed.get_instance()
for name in self._gen_data.keys():
# use first sample in batch
-gen[name] = (
-self._gen_data[name]
-.select(dim=time_dim, index=target_time)[0]
-.cpu()
-.numpy()
-)
+gen_data_local=self._gen_data[name].select(dim=time_dim, index=target_time)[0]
+gen_data = dist.gather_spatial_distributed(gen_data_local)
+gen[name] = (gen_data.cpu().numpy())
+
+target_local=self._target_data[name].select(dim=time_dim, index=target_time)[0]
+target_data = dist.gather_spatial_distributed(target_local)
target[name] = (
-self._target_data[name]
-.select(dim=time_dim, index=target_time)[0]
+target_data
.cpu()
.numpy()
)
input[name] = (
-self._target_data[name]
-.select(dim=time_dim, index=input_time)[0]
+target_data
.cpu()
.numpy()
)
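The change above assumes gather_spatial_distributed reassembles the per-rank (lat, lon) shards into the full-domain tensor before plotting, and is presumably a no-op when spatial parallelism is disabled. A standalone sketch of that reassembly, using an explicit list of shards and slices instead of the real communicator:

    import torch

    def assemble_full_field(shards, slices, full_shape):
        # shards: one 2D tensor per spatial rank; slices: the matching
        # (slice_h, slice_w) pairs, e.g. from Distributed.get_local_slices.
        full = torch.empty(full_shape, dtype=shards[0].dtype)
        for shard, (sh, sw) in zip(shards, slices):
            full[sh, sw] = shard
        return full

    full_shape = (4, 4)
    slices = [(slice(0, 2), slice(0, 4)), (slice(2, 4), slice(0, 4))]
    shards = [torch.zeros(2, 4), torch.ones(2, 4)]
    snapshot = assemble_full_field(shards, slices, full_shape)  # full 4x4 field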
97 changes: 97 additions & 0 deletions fme/ace/aggregator/one_step/test_reduced_sp.py
@@ -0,0 +1,97 @@
import numpy as np
import pytest
import torch
import os
from fme.ace.aggregator.one_step.reduced import MeanAggregator
from fme.core.device import get_device
from fme.core.gridded_ops import LatLonOperations

from fme.core.distributed import Distributed


def test_loss_wo_sp():
"""
Basic test that the aggregator combines the loss correctly
across multiple batches without distributed training.
"""
nx=8
ny=8
torch.manual_seed(0)
example_data = {
"a": torch.randn(1, 2, nx, ny, device=get_device()),
}
area_weights = torch.ones(nx,ny).to(get_device())
aggregator = MeanAggregator(LatLonOperations(area_weights))
aggregator.record_batch(
loss=1.0,
target_data=example_data,
gen_data=example_data,
target_data_norm=example_data,
gen_data_norm=example_data,
)
aggregator.record_batch(
loss=2.0,
target_data=example_data,
gen_data=example_data,
target_data_norm=example_data,
gen_data_norm=example_data,
)
logs = aggregator.get_logs(label="metrics")
print("lost", logs["metrics/loss"] )
assert logs["metrics/loss"] == 1.5
aggregator.record_batch(
loss=3.0,
target_data=example_data,
gen_data=example_data,
target_data_norm=example_data,
gen_data_norm=example_data,
)
logs = aggregator.get_logs(label="metrics")
print("lost", logs["metrics/loss"] )
assert logs["metrics/loss"] == 2.0

def test_loss_with_sp():
os.environ['H_PARALLEL_SIZE'] = '2'
os.environ['W_PARALLEL_SIZE'] = '2'
nx=8
ny=8
torch.manual_seed(0)
tensor_data_host=torch.randn(1, 2, nx, ny)
area_weights = torch.ones(nx,ny)
aggregator = MeanAggregator(LatLonOperations(area_weights))
dist = Distributed.get_instance()
this_shape=(tensor_data_host.shape[-2],tensor_data_host.shape[-1])
tensor_data_local_host = (tensor_data_host[:,:,*dist.get_local_slices(this_shape)]).detach().clone()
tensor_data_local=tensor_data_local_host.to(dist.local_rank)

example_data = {
"a": tensor_data_local
}

aggregator.record_batch(
loss=1.0,
target_data=example_data,
gen_data=example_data,
target_data_norm=example_data,
gen_data_norm=example_data,
)
aggregator.record_batch(
loss=2.0,
target_data=example_data,
gen_data=example_data,
target_data_norm=example_data,
gen_data_norm=example_data,
)
logs = aggregator.get_logs(label="metrics")
print("lost", logs["metrics/loss"] )
assert logs["metrics/loss"] == 1.5
aggregator.record_batch(
loss=3.0,
target_data=example_data,
gen_data=example_data,
target_data_norm=example_data,
gen_data_norm=example_data,
)
logs = aggregator.get_logs(label="metrics")
print("lost", logs["metrics/loss"] )
assert logs["metrics/loss"] == 2.0
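For reference, H_PARALLEL_SIZE=2 and W_PARALLEL_SIZE=2 in the test above imply four spatial ranks, each holding a 4x4 shard of the 8x8 grid (assuming an even block split), so the distributed branch only runs under a multi-process launcher such as torchrun. A rough sketch of the local-shape arithmetic:

    import os

    h_size = int(os.environ.get("H_PARALLEL_SIZE", "1"))  # 2 in the test above
    w_size = int(os.environ.get("W_PARALLEL_SIZE", "1"))  # 2 in the test above
    nx, ny = 8, 8
    # Assuming an even contiguous block decomposition, each spatial rank holds:
    local_shape = (nx // h_size, ny // w_size)  # (4, 4) for a 2x2 layout
    print(local_shape)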
6 changes: 1 addition & 5 deletions fme/ace/data_loading/config.py
@@ -82,11 +82,7 @@ def get_dataset(

def __post_init__(self):
dist = Distributed.get_instance()
-if self.batch_size % dist.world_size != 0:
-raise ValueError(
-"batch_size must be divisible by the number of parallel "
-f"workers, got {self.batch_size} and {dist.world_size}"
-)
+dist.check_local_batch_size(self.batch_size)
# TODO: remove following backwards compatibility code in a future release
if isinstance(self.dataset, Sequence):
warnings.warn(
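check_local_batch_size is assumed to replace the old world_size divisibility check with one against the data-parallel (batch) group only, since ranks that differ only in their spatial position see the same samples. A hypothetical sketch of that assumed check, not the actual Distributed method:

    def check_local_batch_size_sketch(batch_size: int, data_parallel_size: int) -> None:
        # Hypothetical: the global batch is divided across data-parallel ranks only;
        # spatial (H/W) ranks within a data-parallel group share the same batch.
        if batch_size % data_parallel_size != 0:
            raise ValueError(
                "batch_size must be divisible by the number of data-parallel "
                f"workers, got {batch_size} and {data_parallel_size}"
            )

    check_local_batch_size_sketch(batch_size=8, data_parallel_size=4)  # passes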
70 changes: 70 additions & 0 deletions fme/ace/models/makani_models/helpers.py
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed as dist

# from makani.utils import comm
from fme.ace.utils import comm


def count_parameters(model, device):
"""Counts model parameters"""

with torch.no_grad():
total_stats = torch.zeros(2, dtype=torch.long, device=device)
local_bytes = 0
for p in model.parameters():
if not p.requires_grad:
continue

# make sure complex weight tensors are accounted for correctly
pview = torch.view_as_real(p) if p.is_complex() else p
pstats = torch.tensor([pview.numel(), pview.nbytes], dtype=torch.long, device=device)
local_bytes += pview.nbytes

# if the weight is split, then we need to reduce
if hasattr(p, "sharded_dims_mp"):
for group in p.sharded_dims_mp:
if (group is not None) and (comm.get_size(group) > 1):
dist.all_reduce(pstats, group=comm.get_group(group))

# sum the total stats
total_stats += pstats

# transfer to cpu
total_stats_arr = total_stats.cpu().numpy()
total_count = total_stats_arr[0]
total_bytes = total_stats_arr[1]

return total_count, total_bytes, local_bytes


def compare_model_parameters(model1, model2):
"""Checks whether both models have the same parameters"""

for p1, p2 in zip(model1.parameters(), model2.parameters()):
if p1.data.ne(p2.data).any():
return False
return True


def check_parameters(model):
"""Prints shapes, strides and whether parameters are contiguous"""
for p in model.parameters():
if p.requires_grad:
print(p.shape, p.stride(), p.is_contiguous())

return
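A minimal usage sketch for count_parameters on a single, non-sharded process (the sharded_dims_mp reduction path is never triggered here); the import path follows the file added above:

    import torch
    from fme.ace.models.makani_models.helpers import count_parameters

    model = torch.nn.Linear(16, 4)  # 68 trainable parameters
    total_count, total_bytes, local_bytes = count_parameters(model, device="cpu")
    print(total_count, total_bytes, local_bytes)  # 68, 272, 272 for float32 weights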