diff --git a/fme/ace/aggregator/inference/enso/dynamic_index.py b/fme/ace/aggregator/inference/enso/dynamic_index.py index e569467cf..8256a7880 100644 --- a/fme/ace/aggregator/inference/enso/dynamic_index.py +++ b/fme/ace/aggregator/inference/enso/dynamic_index.py @@ -48,6 +48,9 @@ def __post_init__(self): torch.logical_and(lat_mask, lon_mask), 1.0, 0.0 ) + dist = Distributed.get_instance() + self._regional_weights = self._regional_weights[*dist.get_local_slices(self._regional_weights.shape)] + @property def regional_weights(self) -> torch.Tensor: return self._regional_weights diff --git a/fme/ace/aggregator/inference/main.py b/fme/ace/aggregator/inference/main.py index 241f33efb..e40b1a3f6 100644 --- a/fme/ace/aggregator/inference/main.py +++ b/fme/ace/aggregator/inference/main.py @@ -35,6 +35,7 @@ from .time_mean import TimeMeanAggregator, TimeMeanEvaluatorAggregator from .video import VideoAggregator from .zonal_mean import ZonalMeanAggregator +from fme.core.distributed import Distributed wandb = WandB.get_instance() APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=730) @@ -157,12 +158,23 @@ def build( monthly_reference_data = xr.open_dataset( self.monthly_reference_data, decode_timedelta=False ) + dist = Distributed.get_instance() + # CHECK: Is there another way to get lat_length and lon_length? + # Should we move this splitting operation inside the InferenceEvaluatorAggregator? + lat_length = len(monthly_reference_data.coords['lat']) + lon_length = len(monthly_reference_data.coords['lon']) + crop_shape = (lat_length, lon_length) + slice_h, slice_w = dist.get_local_slices(crop_shape) + monthly_reference_data = monthly_reference_data.isel(lat=slice_h, lon=slice_w) + if self.time_mean_reference_data is None: time_mean = None else: time_mean = xr.open_dataset( self.time_mean_reference_data, decode_timedelta=False ) + + return InferenceEvaluatorAggregator( dataset_info=dataset_info, n_timesteps=n_timesteps, diff --git a/fme/ace/aggregator/one_step/snapshot.py b/fme/ace/aggregator/one_step/snapshot.py index 71acb9e98..7c6d29e95 100644 --- a/fme/ace/aggregator/one_step/snapshot.py +++ b/fme/ace/aggregator/one_step/snapshot.py @@ -9,7 +9,7 @@ from ..plotting import plot_paneled_data - +from fme.core.distributed import Distributed class SnapshotAggregator: """ An aggregator that records the first sample of the last batch of data. 
@@ -65,23 +65,25 @@ def _get_data(self) -> tuple[TensorMapping, TensorMapping, TensorMapping]:
         input_time = 0
         target_time = 1
         gen, target, input = {}, {}, {}
+        dist = Distributed.get_instance()
         for name in self._gen_data.keys():
             # use first sample in batch
-            gen[name] = (
-                self._gen_data[name]
-                .select(dim=time_dim, index=target_time)[0]
-                .cpu()
-                .numpy()
-            )
+            # gather the spatially distributed shards before logging
+            gen_local = self._gen_data[name].select(dim=time_dim, index=target_time)[0]
+            gen[name] = dist.gather_spatial_distributed(gen_local).cpu().numpy()
+
+            target_local = self._target_data[name].select(dim=time_dim, index=target_time)[0]
+            target_data = dist.gather_spatial_distributed(target_local)
             target[name] = (
-                self._target_data[name]
-                .select(dim=time_dim, index=target_time)[0]
+                target_data
                 .cpu()
                 .numpy()
             )
+            # the input snapshot comes from the input_time index, not target_time
+            input_local = self._target_data[name].select(dim=time_dim, index=input_time)[0]
+            input_data = dist.gather_spatial_distributed(input_local)
             input[name] = (
-                self._target_data[name]
-                .select(dim=time_dim, index=input_time)[0]
+                input_data
                 .cpu()
                 .numpy()
             )
diff --git a/fme/ace/aggregator/one_step/test_reduced_sp.py b/fme/ace/aggregator/one_step/test_reduced_sp.py
new file mode 100644
index 000000000..df680d3ba
--- /dev/null
+++ b/fme/ace/aggregator/one_step/test_reduced_sp.py
@@ -0,0 +1,97 @@
+import os
+import numpy as np
+import pytest
+import torch
+from fme.ace.aggregator.one_step.reduced import MeanAggregator
+from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
+
+from fme.core.distributed import Distributed
+
+
+def test_loss_wo_sp():
+    """
+    Basic test that the aggregator combines losses correctly
+    across multiple batches without distributed training.
+    """
+    nx = 8
+    ny = 8
+    torch.manual_seed(0)
+    example_data = {
+        "a": torch.randn(1, 2, nx, ny, device=get_device()),
+    }
+    area_weights = torch.ones(nx, ny).to(get_device())
+    aggregator = MeanAggregator(LatLonOperations(area_weights))
+    aggregator.record_batch(
+        loss=1.0,
+        target_data=example_data,
+        gen_data=example_data,
+        target_data_norm=example_data,
+        gen_data_norm=example_data,
+    )
+    aggregator.record_batch(
+        loss=2.0,
+        target_data=example_data,
+        gen_data=example_data,
+        target_data_norm=example_data,
+        gen_data_norm=example_data,
+    )
+    logs = aggregator.get_logs(label="metrics")
+    print("loss", logs["metrics/loss"])
+    assert logs["metrics/loss"] == 1.5
+    aggregator.record_batch(
+        loss=3.0,
+        target_data=example_data,
+        gen_data=example_data,
+        target_data_norm=example_data,
+        gen_data_norm=example_data,
+    )
+    logs = aggregator.get_logs(label="metrics")
+    print("loss", logs["metrics/loss"])
+    assert logs["metrics/loss"] == 2.0
+
+def test_loss_with_sp():
+    os.environ["H_PARALLEL_SIZE"] = "2"
+    os.environ["W_PARALLEL_SIZE"] = "2"
+    nx = 8
+    ny = 8
+    torch.manual_seed(0)
+    tensor_data_host = torch.randn(1, 2, nx, ny)
+    area_weights = torch.ones(nx, ny)
+    aggregator = MeanAggregator(LatLonOperations(area_weights))
+    dist = Distributed.get_instance()
+    this_shape = (tensor_data_host.shape[-2], tensor_data_host.shape[-1])
+    tensor_data_local_host = tensor_data_host[:, :, *dist.get_local_slices(this_shape)].detach().clone()
+    tensor_data_local = tensor_data_local_host.to(dist.local_rank)
+
+    example_data = {
+        "a": tensor_data_local
+    }
+
+    aggregator.record_batch(
+        loss=1.0,
+        target_data=example_data,
+        gen_data=example_data,
+        target_data_norm=example_data,
+        gen_data_norm=example_data,
+    )
+    aggregator.record_batch(
+        loss=2.0,
+        target_data=example_data,
+        gen_data=example_data,
+        target_data_norm=example_data,
+        gen_data_norm=example_data,
+    )
+    logs = aggregator.get_logs(label="metrics")
+    print("loss", logs["metrics/loss"])
+    assert logs["metrics/loss"] == 1.5
logs["metrics/loss"] == 1.5 + aggregator.record_batch( + loss=3.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + print("lost", logs["metrics/loss"] ) + assert logs["metrics/loss"] == 2.0 diff --git a/fme/ace/data_loading/config.py b/fme/ace/data_loading/config.py index 1ab55063a..29109b342 100644 --- a/fme/ace/data_loading/config.py +++ b/fme/ace/data_loading/config.py @@ -82,11 +82,7 @@ def get_dataset( def __post_init__(self): dist = Distributed.get_instance() - if self.batch_size % dist.world_size != 0: - raise ValueError( - "batch_size must be divisible by the number of parallel " - f"workers, got {self.batch_size} and {dist.world_size}" - ) + dist.check_local_batch_size(self.batch_size) # TODO: remove following backwards compatibility code in a future release if isinstance(self.dataset, Sequence): warnings.warn( diff --git a/fme/ace/models/makani_models/helpers.py b/fme/ace/models/makani_models/helpers.py new file mode 100644 index 000000000..3f48f98d0 --- /dev/null +++ b/fme/ace/models/makani_models/helpers.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.distributed as dist + +# from makani.utils import comm +from fme.ace.utils import comm + + +def count_parameters(model, device): + """Counts model parameters""" + + with torch.no_grad(): + total_stats = torch.zeros(2, dtype=torch.long, device=device) + local_bytes = 0 + for p in model.parameters(): + if not p.requires_grad: + continue + + # make sure complex weight tensors are accounted for correctly + pview = torch.view_as_real(p) if p.is_complex() else p + pstats = torch.tensor([pview.numel(), pview.nbytes], dtype=torch.long, device=device) + local_bytes += pview.nbytes + + # if the weight is split, then we need to reduce + if hasattr(p, "sharded_dims_mp"): + for group in p.sharded_dims_mp: + if (group is not None) and (comm.get_size(group) > 1): + dist.all_reduce(pstats, group=comm.get_group(group)) + + # sum the total stats + total_stats += pstats + + # transfer to cpu + total_stats_arr = total_stats.cpu().numpy() + total_count = total_stats_arr[0] + total_bytes = total_stats_arr[1] + + return total_count, total_bytes, local_bytes + + +def compare_model_parameters(model1, model2): + """Checks whether both models have the same parameters""" + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + if p1.data.ne(p2.data).any(): + return False + return True + + +def check_parameters(model): + """Prints shapes, strides and whether parameters are contiguous""" + for p in model.parameters(): + if p.requires_grad: + print(p.shape, p.stride(), p.is_contiguous()) + + return diff --git a/fme/ace/models/makani_mpu/fft.py b/fme/ace/models/makani_mpu/fft.py new file mode 100644 index 000000000..db119184c --- /dev/null +++ b/fme/ace/models/makani_mpu/fft.py @@ -0,0 +1,409 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
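+
+"""Distributed real FFT wrappers (1D, 2D, and 3D) that mirror the distributed
+SHT interface: transpose so the transform dimension is rank-local, apply the
+FFT, truncate modes, then transpose back."""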
+ +from typing import Optional +import math +import torch +from torch import nn +import torch.nn.functional as F + +# from makani.utils import comm +from fme.ace.utils import comm +from physicsnemo.distributed.utils import compute_split_shapes +from torch_harmonics.distributed import distributed_transpose_azimuth as distributed_transpose_w +from torch_harmonics.distributed import distributed_transpose_polar as distributed_transpose_h + + +class DistributedRealFFT1(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlon: int, lmax: Optional[int] = None, mmax: Optional[int] = None): + super().__init__() + + # get the comms grid: + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nlon = nlon + self.lmax = min(lmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + self.mmax = min(mmax or self.lmax, self.lmax) + + # compute half modes + self.lmax_high = math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + # shapes + self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_w) + self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_w) + + def forward(self, x: torch.Tensor, norm: Optional[str] = "ortho", channel_dim: Optional[int] = -3) -> torch.Tensor: + # store number of chans + num_chans = x.shape[channel_dim] + + # We make w local by transposing into channel dim + if self.comm_size_w > 1: + x = distributed_transpose_w.apply(x, (channel_dim, -1), self.lon_shapes) + + # do first FFT + x = torch.fft.rfft(x, n=self.nlon, dim=-1, norm=norm) + + # mode truncation + x = x[..., : self.mmax].contiguous() + + # transpose: after this, m is split and c is local + if self.comm_size_w > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_w) + x = distributed_transpose_w.apply(x, (-1, channel_dim), chan_shapes) + + return x + + +class DistributedInverseRealFFT1(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlon: int, lmax: Optional[int] = None, mmax: Optional[int] = None): + super().__init__() + + # get the comms grid: + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nlon = nlon + self.lmax = min(lmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + self.mmax = min(mmax or self.lmax, self.lmax) + + # compute half modes + self.lmax_high = math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + # shapes + self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_w) + self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_w) + + def forward(self, x: torch.Tensor, norm: Optional[str] = "ortho", channel_dim: Optional[int] = -3) -> torch.Tensor: + # store number of channels + num_chans = x.shape[channel_dim] + + # transpose: after this, channels are split and m is local + if self.comm_size_w > 1: + x = distributed_transpose_w.apply(x, (channel_dim, -1), self.m_shapes) + + # apply the inverse (real) FFT + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm=norm) + + # transpose: after this, m is split and channels are local + if self.comm_size_w > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_w) + x = distributed_transpose_w.apply(x, (-1, channel_dim), chan_shapes) + + return x + + +class DistributedRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat: int, nlon: int, lmax: Optional[int] = None, mmax: Optional[int] = None): + super().__init__() + + # get the comms 
grid: + self.comm_size_h = comm.get_size("h") + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nlat = nlat + self.nlon = nlon + self.lmax = min(lmax or self.nlat, self.nlat) + self.mmax = min(mmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + + # compute half modes + self.lmax_high = math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + # shapes + self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_h) + self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_w) + self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_h) + self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_w) + + def forward(self, x: torch.Tensor, norm: Optional[str] = "ortho", channel_dim: Optional[int] = -3) -> torch.Tensor: + # store number of chans + num_chans = x.shape[channel_dim] + + # h and w is split. First we make w local by transposing into channel dim + if self.comm_size_w > 1: + x = distributed_transpose_w.apply(x, (channel_dim, -1), self.lon_shapes) + + # do first FFT + x = torch.fft.rfft(x, n=self.nlon, dim=-1, norm=norm) + + # mode truncation + x = x[..., : self.mmax].contiguous() + + # transpose: after this, m is split and c is local + if self.comm_size_w > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_w) + x = distributed_transpose_w.apply(x, (-1, channel_dim), chan_shapes) + + # transpose: after this, c is split and h is local + if self.comm_size_h > 1: + x = distributed_transpose_h.apply(x, (channel_dim, -2), self.lat_shapes) + + # do second FFT: + x = torch.fft.fft(x, n=self.nlat, dim=-2, norm=norm) + + # apply mode truncation: + x = torch.cat([x[..., : self.lmax_high, :], x[..., -self.lmax_low :, :]], dim=-2) + + # transpose: after this, l is split and c is local + if self.comm_size_h > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_h) + x = distributed_transpose_h.apply(x, (-2, channel_dim), chan_shapes) + + return x + + +class DistributedInverseRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat: int, nlon: int, lmax: Optional[int] = None, mmax: Optional[int] = None): + super().__init__() + + # get the comms grid: + self.comm_size_h = comm.get_size("h") + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nlat = nlat + self.nlon = nlon + self.lmax = min(lmax or self.nlat, self.nlat) + self.mmax = min(mmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + + # compute half modes + self.lmax_high = math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + # shapes + self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_h) + self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_w) + self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_h) + self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_w) + + def forward(self, x: torch.Tensor, norm: Optional[str] = "ortho", channel_dim: Optional[int] = -3) -> torch.Tensor: + # store number of channels + num_chans = x.shape[channel_dim] + + # transpose: after that, channels are split, l is local: + if self.comm_size_h > 1: + x = distributed_transpose_h.apply(x, (channel_dim, -2), self.l_shapes) + + # we should pad the middle here manually, so that the inverse FFT is correct + # TEST THIS + if self.lmax < self.nlat: + xh = x[..., : self.lmax_high, :] + xl = x[..., -self.lmax_low :, :] + xhp = F.pad(xh, (0, 0, 0, self.nlat - self.lmax), mode="constant") + x 
= torch.cat([xhp, xl], dim=-2) + + # do first fft + x = torch.fft.ifft(x, n=self.nlat, dim=-2, norm=norm) + + if self.comm_size_h > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_h) + x = distributed_transpose_h.apply(x, (-2, channel_dim), chan_shapes) + + # transpose: after this, channels are split and m is local + if self.comm_size_w > 1: + x = distributed_transpose_w.apply(x, (channel_dim, -1), self.m_shapes) + + # apply the inverse (real) FFT + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm=norm) + + # transpose: after this, m is split and channels are local + if self.comm_size_w > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_w) + x = distributed_transpose_w.apply(x, (-1, channel_dim), chan_shapes) + + return x + + +# 3D routines +# forward +class DistributedRealFFT3(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nd, nh, nw, ldmax=None, lhmax=None, lwmax=None): + super().__init__() + + # get the comms grid: + self.comm_size_h = comm.get_size("h") + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nd = nd + self.nh = nh + self.nw = nw + self.ldmax = min(ldmax or self.nd, self.nd) + self.lhmax = min(lhmax or self.nh, self.nh) + self.lwmax = min(lwmax or self.nw // 2 + 1, self.nw // 2 + 1) + + # half-modes + self.ldmax_high = math.ceil(self.ldmax / 2) + self.ldmax_low = math.floor(self.ldmax / 2) + self.lhmax_high = math.ceil(self.lhmax / 2) + self.lhmax_low = math.floor(self.lhmax / 2) + + # shapes, we assume the d-dim is always local + self.lat_shapes = compute_split_shapes(self.nh, self.comm_size_h) + self.lon_shapes = compute_split_shapes(self.nw, self.comm_size_w) + self.l_shapes = compute_split_shapes(self.lhmax, self.comm_size_h) + self.m_shapes = compute_split_shapes(self.lwmax, self.comm_size_w) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # make sure input is 5D + assert x.dim() == 5 + + # store number of chans + num_chans = x.shape[1] + + # h and w is split. 
First we make w local by transposing into channel dim + if self.comm_size_w > 1: + x = distributed_transpose_w.apply(x, (1, -1), self.lon_shapes) + + # do first 2D FFT + x = torch.fft.rfft2(x, s=(self.nd, self.nw), dim=(-3, -1), norm="ortho") + + # truncate width-modes + x = x[..., : self.lwmax] + + # truncate depth-modes + x = torch.cat([x[..., : self.ldmax_high, :, :], x[..., -self.ldmax_low :, :, :]], dim=-3) + + # transpose: after this, m is split and c is local + if self.comm_size_w > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_w) + x = distributed_transpose_w.apply(x, (-1, 1), chan_shapes) + + # transpose: after this, c is split and h is local + if self.comm_size_h > 1: + x = distributed_transpose_h.apply(x, (1, -2), self.lat_shapes) + + # do second FFT: + x = torch.fft.fft(x, n=self.nh, dim=-2, norm="ortho") + + # truncate the modes + x = torch.cat([x[..., : self.lhmax_high, :], x[..., -self.lhmax_low :, :]], dim=-2) + + # transpose: after this, l is split and c is local + if self.comm_size_h > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_h) + x = distributed_transpose_h.apply(x, (-2, 1), chan_shapes) + + return x + + +class DistributedInverseRealFFT3(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nd, nh, nw, ldmax=None, lhmax=None, lwmax=None): + super().__init__() + + # get the comms grid: + self.comm_size_h = comm.get_size("h") + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nd = nd + self.nh = nh + self.nw = nw + self.ldmax = min(ldmax or self.nd, self.nd) + self.lhmax = min(lhmax or self.nh, self.nh) + self.lwmax = min(lwmax or self.nw // 2 + 1, self.nw // 2 + 1) + + # half-modes + self.ldmax_high = math.ceil(self.ldmax / 2) + self.ldmax_low = math.floor(self.ldmax / 2) + self.lhmax_high = math.ceil(self.lhmax / 2) + self.lhmax_low = math.floor(self.lhmax / 2) + + # shapes, we assume the d-dim is always local + self.lat_shapes = compute_split_shapes(self.nh, self.comm_size_h) + self.lon_shapes = compute_split_shapes(self.nw, self.comm_size_w) + self.l_shapes = compute_split_shapes(self.lhmax, self.comm_size_h) + self.m_shapes = compute_split_shapes(self.lwmax, self.comm_size_w) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # make sure input is 5D + assert x.dim() == 5 + + # store number of chans + num_chans = x.shape[1] + + # transpose: after that, channels are split, lh is local: + if self.comm_size_h > 1: + x = distributed_transpose_h.apply(x, (1, -2), self.l_shapes) + + # we should pad the middle here manually, so that the inverse FFT is correct + if self.lhmax < self.nh: + xh = x[..., : self.lhmax_high, :] + xl = x[..., -self.lhmax_low :, :] + xhp = F.pad(xh, (0, 0, 0, self.nh - self.lhmax), mode="constant") + x = torch.cat([xhp, xl], dim=-2) + + if self.ldmax < self.nd: + xh = x[..., : self.ldmax_high, :, :] + xl = x[..., -self.ldmax_low :, :, :] + xhp = F.pad(xh, (0, 0, 0, 0, 0, self.nd - self.ldmax), mode="constant") + x = torch.cat([xhp, xl], dim=-3) + + # do first fft + x = torch.fft.ifft2(x, s=(self.nd, self.nh), dim=(-3, -2), norm="ortho") + + if self.comm_size_h > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_h) + x = distributed_transpose_h.apply(x, (-2, 1), chan_shapes) + + # transpose: after this, channels are split and m is local + if self.comm_size_w > 1: + x = distributed_transpose_w.apply(x, (1, -1), self.m_shapes) + + # apply the inverse (real) FFT + x = torch.fft.irfft(x, 
n=self.nw, dim=-1, norm="ortho") + + # transpose: after this, m is split and channels are local + if self.comm_size_w > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_w) + x = distributed_transpose_w.apply(x, (-1, 1), chan_shapes) + + return x diff --git a/fme/ace/models/makani_mpu/helpers.py b/fme/ace/models/makani_mpu/helpers.py new file mode 100644 index 000000000..445f824c6 --- /dev/null +++ b/fme/ace/models/makani_mpu/helpers.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +import torch.distributed as dist +from torch._utils import _flatten_dense_tensors + +from physicsnemo.distributed.utils import split_tensor_along_dim +# from makani.utils import comm +from fme.ace.utils import comm + + +def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False): + + # get comm params + comm_size = dist.get_world_size(group=group) + comm_rank = dist.get_rank(group=group) + + # split and local transposition + tsplit = split_tensor_along_dim(tensor, dim=dim0, num_chunks=comm_size) + x_send = [y.contiguous() for y in tsplit] + x_send_shapes = [x.shape for x in x_send] + x_recv = [] + x_shape = list(x_send_shapes[comm_rank]) + for dim1_len in dim1_split_sizes: + x_shape[dim1] = dim1_len + x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device)) + + # global transposition + req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op) + + # get dim0 split sizes + dim0_split_sizes = [x[dim0] for x in x_send_shapes] + + return x_recv, dim0_split_sizes, req + + +def gather_uneven(tensor, dim, comm_name): + if comm.get_size(comm_name) == 1: + return tensor + + # gather dims + dim_tensor = torch.tensor([tensor.shape[dim]], dtype=torch.int, device=tensor.device) + dim_list = [torch.empty_like(dim_tensor) for _ in range(comm.get_size(comm_name))] + dim_list[comm.get_rank(comm_name)] = dim_tensor + dist.all_gather(dim_list, dim_tensor, group=comm.get_group(comm_name)) + + # gather tensor + gathered_shape = list(tensor.shape) + tensor_list = [] + for rshape in dim_list: + gathered_shape[dim] = rshape.item() + tensor_list.append(torch.empty(gathered_shape, dtype=tensor.dtype, device=tensor.device)) + + tensor_list[comm.get_rank(comm_name)] = tensor + dist.all_gather(tensor_list, tensor, group=comm.get_group(comm_name)) + + # concatenate + result = torch.cat(tensor_list, dim=dim) + + return result + + +def sync_params(model, mode="broadcast"): + """Helper routine to ensure shared weights are the same after initialization""" + + def _sync_param(param, comm_group, mode): + if comm.get_size(comm_group) > 1: + if mode == "broadcast": + is_complex = param.is_complex() + if is_complex: + param_real = torch.view_as_real(param).clone() + else: + param_real = param.clone() + # tlist = [torch.empty_like(param_real) for x in 
range(comm.get_size(comm_group))] + # tlist[comm.get_rank(comm_group)] = param_real + # gather all weights in the comm group + dist.broadcast(param_real, src=comm.get_root(comm_group), group=comm.get_group(comm_group), async_op=False) + # use weight of rank 0 + # important to use copy here otherwise the handle gets detaches from the optimizer + if is_complex: + param.copy_(torch.view_as_complex(param_real)) + else: + param.copy_(param_real) + elif mode == "mean": + is_complex = param.is_complex() + if is_complex: + dist.all_reduce(torch.view_as_real(param), op=dist.ReduceOp.AVG, group=comm.get_group(comm_group), async_op=False) + else: + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=comm.get_group(comm_group), async_op=False) + else: + raise ValueError(f"Unknown weight synchronization mode {mode}") + + return + + with torch.no_grad(): + # distributed sync step + for param in model.parameters(): + # share along data dim + _sync_param(param, "data", mode) + + if not hasattr(param, "is_shared_mp"): + param.is_shared_mp = ["model"] + + for comm_group in param.is_shared_mp: + _sync_param(param, comm_group, mode) + + # synchronize the device to make sure all copies have finished + if dist.is_initialized(): + device = next(model.parameters()).device + dist.barrier(device_ids=[device.index]) + + return diff --git a/fme/ace/models/makani_mpu/layer_norm.py b/fme/ace/models/makani_mpu/layer_norm.py new file mode 100644 index 000000000..97308aee7 --- /dev/null +++ b/fme/ace/models/makani_mpu/layer_norm.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from torch import amp + +from typing import Tuple, List, Optional + +# for spatial model-parallelism +# from makani.utils import comm +from fme.ace.utils import comm +from physicsnemo.distributed.mappings import gather_from_parallel_region, copy_to_parallel_region + +# quadrature stuff +# from makani.utils.grids import grid_to_quadrature_rule, GridQuadrature +@torch.compile +def _normalize_transform_kernel(x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float) -> torch.Tensor: + + # normalization + x = (x - mean) / torch.sqrt(var + eps) + + # affine transformation + x = weight * x + bias + + return x + + +@torch.compile +def _normalize_kernel(x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor, eps: float) -> torch.Tensor: + + # normalization + x = (x - mean) / torch.sqrt(var + eps) + + return x + +@torch.compile +def _welford_kernel(vars: torch.Tensor, means: torch.Tensor, counts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # for weighted welford, replace counts by + # omega = sum_i w_i, where w_i are the individual weights + + # get m2s + m2s = vars * counts + + # do welford update + mean = means[0, ...] + m2 = m2s[0, ...] + count = counts[0, ...] 
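+    # pairwise merge rule (Chan et al.):
+    #   M2_ab = M2_a + M2_b + delta^2 * n_a * n_b / (n_a + n_b)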
+ + # use Welford's algorithm to accumulate them into a single mean and variance + for i in range(1, means.shape[0]): + delta = means[i, ...] - mean + m2 = m2 + m2s[i, ...] + delta**2 * count * counts[i, ...] / (count + counts[i, ...]) + if i == 1: + mean = (mean * count + means[i, ...] * counts[i, ...]) / (count + counts[i, ...]) + else: + mean = mean + delta * counts[i, ...] / (count + counts[i, ...]) + + # update the current count + count = count + counts[i, ...] + + var = m2 / count + + return var, mean, count + +def distributed_welford_variance(var: torch.Tensor, mean: torch.Tensor, count: torch.Tensor, group: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes the statistics locally, then uses the Welford online algorithm to reduce them""" + + # concatenate: + # this has the shape [3, 1, ...] + var_mean_count = torch.stack([var, mean, count], dim=0).unsqueeze(1) + + # gather + # this has the shape [3, spatial_size, ...], we split it up directly into individual tensors again + vars_means_counts = gather_from_parallel_region(var_mean_count, dim=1, shapes=None, group=group) + + # split up + vars = vars_means_counts[0, ...] + means = vars_means_counts[1, ...] + counts = vars_means_counts[2, ...] + + # do welford update + var, mean, count = _welford_kernel(vars, means, counts) + + return var, mean, count + +class DistributedInstanceNorm2d(nn.Module): + """ + Computes a distributed instance norm using Welford's online algorithm + """ + + def __init__(self, num_features, eps=1e-05, affine=False): + super().__init__() + + self.eps = eps + self.affine = affine + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.weight.is_shared_mp = ["spatial"] + self.bias.is_shared_mp = ["spatial"] + + def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes the statistics locally, then uses the Welford online algorithm to reduce them""" + + # extract shapes + B, C, H, W = x.shape + + # those have the shapes [B, C] + var, mean = torch.var_mean(x, dim=(-2, -1), unbiased=False, keepdim=False) + + # workaround to not use shapes, as otherwise cuda graphs won't work + # those have the shapes [B, C] + count = torch.ones_like(x, requires_grad=False) + count = torch.sum(count, dim=(-2, -1), keepdim=False) + var, mean, _ = distributed_welford_variance(var, mean, count, "spatial") + + # reshape + var = var.reshape(B, C, 1, 1) + mean = mean.reshape(B, C, 1, 1) + + return var, mean + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + with amp.autocast(device_type="cuda", enabled=False): + dtype = x.dtype + x = x.float() + + # start by computing std and mean + var, mean = self._stats_welford(x) + + # this is absolutely necessary to get the correct graph in the backward pass + mean = copy_to_parallel_region(mean, "spatial") + var = copy_to_parallel_region(var, "spatial") + + x = x.to(dtype) + mean = mean.to(dtype) + var = var.to(dtype) + + # apply the normalization + if self.affine: + x = _normalize_transform_kernel(x, mean, var, self.weight.reshape(-1, 1, 1), self.bias.reshape(-1, 1, 1), self.eps) + else: + x = _normalize_kernel(x, mean, var, self.eps) + + return x + +class DistributedLayerNorm(nn.Module): + """ + This is a lightweight wrapper which only computed norm across channels. + This norm breaks equivariance since the norm across channels is different per grid + point. 
+ """ + + def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None): + super().__init__() + + assert comm.get_size("matmul") == 1 + + self.norm = nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype) + + if elementwise_affine: + # set up weight sharing and sharding + self.norm.weight.is_shared_mp = ["model"] + self.norm.weight.sharded_dims_mp = [None] + if bias: + self.norm.bias.is_shared_mp = ["model"] + self.norm.bias.sharded_dims_mp = [None] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # assume input is NCHW, so we transpose + xt = torch.transpose(x, 1, 3) + xn = self.norm(xt) + x = torch.transpose(xn, 1, 3).contiguous() + + return x diff --git a/fme/ace/models/makani_mpu/layers.py b/fme/ace/models/makani_mpu/layers.py new file mode 100644 index 000000000..32a4a2ecc --- /dev/null +++ b/fme/ace/models/makani_mpu/layers.py @@ -0,0 +1,512 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.amp import custom_fwd, custom_bwd + +# from makani.utils import comm +from fme.ace.utils import comm + +# parallel helpers +from physicsnemo.distributed.utils import compute_split_shapes +from physicsnemo.distributed.mappings import reduce_from_parallel_region +from physicsnemo.distributed.mappings import scatter_to_parallel_region +from physicsnemo.distributed.mappings import gather_from_parallel_region +from physicsnemo.distributed.mappings import copy_to_parallel_region + +# use some distributed routines from torch harmonics +from torch_harmonics.distributed import distributed_transpose_azimuth as distributed_transpose_w +from torch_harmonics.distributed import distributed_transpose_polar as distributed_transpose_h + + +class _DistMatmulHelper(torch.autograd.Function): + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, X, weight, bias, inp_group_name, out_group_name): + # store some variables + ctx.save_for_backward(X, weight, bias) + ctx.out_group_name = out_group_name + + # matrix multiplication + xconv = F.conv2d(X, weight, bias=None) + + # reduce + if comm.get_size(inp_group_name) > 1: + dist.all_reduce(xconv, group=comm.get_group(inp_group_name)) + + # add bias + if bias is not None: + xconvbias = xconv + bias + else: + xconvbias = xconv + + return xconvbias + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_out): + X, weight, bias = ctx.saved_tensors + gname = ctx.out_group_name + + # do the bwd pass on dgrad + grad_input = F.conv_transpose2d(grad_out, weight, bias=None) + + # reduce across nodes + if comm.get_size(gname) > 1: + dgrad_handle = dist.all_reduce(grad_input, group=comm.get_group(gname), async_op=True) + + # weight grad + 
grad_weight = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1), bias=None).transpose(0, 1) + + if bias is not None: + grad_bias = torch.sum(grad_out, dim=(0, 2, 3), keepdim=True) + else: + grad_bias = None + + if comm.get_size(gname) > 1: + dgrad_handle.wait() + + return grad_input, grad_weight, grad_bias, None, None + + +class DistributedMatmul(nn.Module): + def __init__(self, inp_dim, out_dim, input_format="nchw", comm_inp_name="fin", comm_out_name="fout", bias=True): + super(DistributedMatmul, self).__init__() + + # get sizes + self.comm_inp_name = comm_inp_name + self.comm_out_name = comm_out_name + comm_inp_size = comm.get_size(self.comm_inp_name) + comm_out_size = comm.get_size(self.comm_out_name) + + # split: + assert inp_dim % comm_inp_size == 0, f"Error, the size of input feature dim ({inp_dim}) has to be evenly divisible by the input feature comm dim ({comm_inp_size})" + assert out_dim % comm_out_size == 0, f"Error, the size of output feature dim ({out_dim}) has to be evenly divisible by the output feature comm dim ({comm_out_size})" + + # compute reduced dims + inp_dim_local = inp_dim // comm_inp_size + out_dim_local = out_dim // comm_out_size + + # parameters + if input_format == "nchw": + self.weight = nn.Parameter(torch.ones(out_dim_local, inp_dim_local, 1, 1)) + self.weight.is_shared_mp = ["spatial"] + self.weight.sharded_dims_mp = [self.comm_out_name, self.comm_inp_name, None, None] + self.matmul_handle = F.conv2d + elif input_format == "traditional": + self.weight = nn.Parameter(torch.ones(out_dim_local, inp_dim_local)) + self.weight.sharded_dims_mp = [self.comm_out_name, self.comm_inp_name] + self.matmul_handle = F.linear + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # bias + self.bias = None + if bias: + if input_format == "nchw": + self.bias = nn.Parameter(torch.zeros(1, out_dim_local, 1, 1)) + self.bias.is_shared_mp = ["spatial"] + self.bias.sharded_dims_mp = [None, self.comm_out_name, None, None] + elif input_format == "traditional": + self.bias = nn.Parameter(torch.zeros(out_dim_local)) + self.bias.sharded_dims_mp = [self.comm_out_name] + + def forward(self, x): + x_cp = copy_to_parallel_region(x, self.comm_out_name) + x_loc = self.matmul_handle(x_cp, self.weight, bias=None) + x_out = reduce_from_parallel_region(x_loc, self.comm_inp_name) + if self.bias is not None: + x_out = x_out + self.bias + + return x_out + + +# distributed encoder/decoder +class DistributedEncoderDecoder(nn.Module): + def __init__(self, num_layers, input_dim, output_dim, hidden_dim, act_layer, gain=1.0, input_format="nchw", comm_inp_name="fin", comm_out_name="fout"): + super(DistributedEncoderDecoder, self).__init__() + + # get comms + comm_inp_size = comm.get_size(comm_inp_name) + comm_out_size = comm.get_size(comm_out_name) + + print("using DistributedEncoderDecoder") + + # get list of modules + encoder_modules = [] + current_dim = input_dim + comm_inp_name_tmp = comm_inp_name + comm_out_name_tmp = comm_out_name + for i in range(num_layers - 1): + encoder_modules.append( + DistributedMatmul(current_dim, hidden_dim, input_format=input_format, comm_inp_name=comm_inp_name_tmp, comm_out_name=comm_out_name_tmp, bias=True) + ) + + # proper initialization + # scale = math.sqrt(2.0 / current_dim) + # nn.init.normal_(encoder_modules[-1].weight, mean=0.0, std=scale) + if encoder_modules[-1].bias is not None: + nn.init.constant_(encoder_modules[-1].bias, 0.0) + + encoder_modules.append(act_layer()) + current_dim = hidden_dim + comm_inp_name_tmp, 
comm_out_name_tmp = (comm_out_name_tmp, comm_inp_name_tmp)
+
+        # final layer
+        encoder_modules.append(DistributedMatmul(current_dim, output_dim, input_format=input_format, comm_inp_name=comm_inp_name_tmp, comm_out_name=comm_out_name_tmp, bias=False))
+
+        # proper initialization of final layer
+        # scale = math.sqrt(gain / current_dim)
+        # nn.init.normal_(encoder_modules[-1].weight, mean=0.0, std=scale)
+        if encoder_modules[-1].bias is not None:
+            nn.init.constant_(encoder_modules[-1].bias, 0.0)
+
+        # create fwd sequence
+        self.fwd = nn.Sequential(*encoder_modules)
+
+        # store the comm names for in and out so that they can be queried
+        self.comm_inp_name = comm_inp_name
+        self.comm_out_name = comm_out_name_tmp
+
+    def forward(self, x):
+        return self.fwd(x)
+
+
+# more complicated layers
+class DistributedMLP(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        output_bias=True,
+        input_format="nchw",
+        comm_inp_name="fin",
+        comm_hidden_name="fout",
+        act_layer=nn.GELU,
+        drop_rate=0.0,
+        drop_type="iid",
+        checkpointing=False,
+        gain=1.0,
+    ):
+        super().__init__()
+        self.checkpointing = checkpointing
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+
+        # sanity checks:
+        if (input_format == "traditional") and (drop_type == "features"):
+            raise NotImplementedError("Error, traditional input format and feature dropout cannot be selected simultaneously")
+
+        # get effective embedding size:
+        comm_inp_size = comm.get_size(comm_inp_name)
+        comm_hid_size = comm.get_size(comm_hidden_name)
+
+        self.fc1 = DistributedMatmul(in_features, hidden_features, input_format=input_format, comm_inp_name=comm_inp_name, comm_out_name=comm_hidden_name, bias=True)
+
+        # initialize the weights correctly
+        scale = math.sqrt(2.0 / in_features)
+        nn.init.normal_(self.fc1.weight, mean=0.0, std=scale)
+        nn.init.constant_(self.fc1.bias, 0.0)
+
+        self.fc2 = DistributedMatmul(hidden_features, out_features, input_format=input_format, comm_inp_name=comm_hidden_name, comm_out_name=comm_inp_name, bias=output_bias)
+
+        # gain factor for the output determines the scaling of the output init
+        scale = math.sqrt(gain / hidden_features)
+        nn.init.normal_(self.fc2.weight, mean=0.0, std=scale)
+        if self.fc2.bias is not None:
+            nn.init.constant_(self.fc2.bias, 0.0)
+
+        self.act = act_layer()
+
+        if drop_rate > 0.0:
+            if drop_type == "iid":
+                self.drop = nn.Dropout(drop_rate)
+            elif drop_type == "features":
+                self.drop = nn.Dropout2d(drop_rate)
+            else:
+                raise NotImplementedError(f"Error, drop_type {drop_type} not supported")
+        else:
+            self.drop = nn.Identity()
+
+    def fwd(self, x):
+        # do the mlp
+        # first layer
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+
+        # second layer
+        x = self.fc2(x)
+        x = self.drop(x)
+
+        return x
+
+    @torch.compiler.disable(recursive=False)
+    def _checkpoint_forward(self, x):
+        # import here since checkpoint is not imported at module scope
+        from torch.utils.checkpoint import checkpoint
+        return checkpoint(self.fwd, x, use_reentrant=False)
+
+    def forward(self, x):
+        if self.checkpointing:
+            return self._checkpoint_forward(x)
+        else:
+            return self.fwd(x)
+
+
+class DistributedPatchEmbed(nn.Module):
+    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768, input_is_matmul_parallel=False, output_is_matmul_parallel=True):
+        super().__init__()
+
+        # store params
+        self.input_parallel = input_is_matmul_parallel
+        self.output_parallel = output_is_matmul_parallel
+
+        # get comm sizes:
+        matmul_comm_size = comm.get_size("matmul")
+        spatial_comm_size = comm.get_size("spatial")
+
+        # compute
parameters + assert (img_size[1] // patch_size[1]) % spatial_comm_size == 0, "Error, make sure that the spatial comm size evenly divides patched W" + num_patches = ((img_size[1] // patch_size[1]) // spatial_comm_size) * (img_size[0] // patch_size[0]) + self.img_size = (img_size[0], img_size[1] // spatial_comm_size) + self.patch_size = patch_size + self.num_patches = num_patches + + # get effective embedding size: + if self.output_parallel: + assert embed_dim % matmul_comm_size == 0, "Error, the embed_dim needs to be divisible by matmul_parallel_size" + out_chans_local = embed_dim // matmul_comm_size + else: + out_chans_local = embed_dim + + # the weights of this layer is shared across spatial parallel ranks + self.proj = nn.Conv2d(in_chans, out_chans_local, kernel_size=patch_size, stride=patch_size) + + # make sure we reduce them across rank + self.proj.weight.is_shared_mp = ["spatial"] + self.proj.bias.is_shared_mp = ["spatial"] + + # gather shapes + self.gather_shapes = compute_split_shapes(in_chans, comm.get_size("matmul")) + + def forward(self, x): + if self.input_parallel: + x = gather_from_parallel_region(x, 1, self.gather_shapes, "matmul") + + if self.output_parallel: + x = copy_to_parallel_region(x, "matmul") + + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # new: B, C, H*W + x = self.proj(x).flatten(2) + return x + + +class DistributedAttention(nn.Module): + """Distributed Attention layer""" + + def __init__( + self, + dim, + input_format="traditional", + comm_inp_name="fin", + comm_hidden_name="fout", + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop_rate=0.0, + proj_drop_rate=0.0, + norm_layer=nn.LayerNorm, + ): + super().__init__() + + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + + assert num_heads % comm.get_size(comm_hidden_name) == 0, "heads are not evenly split across model ranks" + self.num_heads_local = num_heads // comm.get_size(comm_hidden_name) + self.head_dim = dim // self.num_heads + + self.comm_inp_name = comm_inp_name + self.comm_hidden_name = comm_hidden_name + + self.qkv = DistributedMatmul(dim, dim * 3, input_format, comm_inp_name=comm_inp_name, comm_out_name=comm_hidden_name, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop_rate = attn_drop_rate + self.proj = DistributedMatmul(dim, dim, input_format, comm_inp_name=comm_hidden_name, comm_out_name=comm_inp_name, bias=False) + if proj_drop_rate > 0.0: + self.proj_drop = nn.Dropout(proj_drop_rate) + else: + self.proj_drop = nn.Identity() + + # set up weight sharing, depends on norm type + if isinstance(self.q_norm, nn.LayerNorm): + if hasattr(self.q_norm, "weight"): + self.q_norm.weight.is_shared_mp = [] + if hasattr(self.q_norm, "bias"): + self.q_norm.bias.is_shared_mp = [] + + if isinstance(self.k_norm, nn.LayerNorm): + if hasattr(self.k_norm, "weight"): + self.k_norm.weight.is_shared_mp = [] + if hasattr(self.k_norm, "bias"): + self.k_norm.bias.is_shared_mp = [] + + def forward(self, x): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads_local, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop_rate) + + # transpose back + x = x.transpose(1, 
2).reshape(B, N, self.num_heads_local * self.head_dim) + + # this is distributed again + x = self.proj(x) + + # generally we have to be super careful with dropout layers, since + # those are normalized over the dropouts. That would need to be reduced across nodes + x = self.proj_drop(x) + + return x + + +@torch.compile +def compl_mul_add_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + tmp = torch.einsum("bkixys,kiot->stbkoxy", a, b) + res = torch.stack([tmp[0, 0, ...] - tmp[1, 1, ...], tmp[1, 0, ...] + tmp[0, 1, ...]], dim=-1) + c + return res + + +@torch.compile +def compl_mul_add_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + cc = torch.view_as_complex(c) + tmp = torch.einsum("bkixy,kio->bkoxy", ac, bc) + res = tmp + cc + return torch.view_as_real(res) + + +class DistributedAFNO2Dv2(nn.Module): + def __init__( + self, + hidden_size, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1, + hidden_size_factor=1, + input_is_matmul_parallel=False, + output_is_matmul_parallel=False, + use_complex_kernels=False, + ): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + # get comm sizes: + matmul_comm_size = comm.get_size("matmul") + self.spatial_comm_size = comm.get_size("spatial") + + # select fft function handles + if self.spatial_comm_size > 1: + self.fft_handle = distributed_rfft2.apply + self.ifft_handle = distributed_irfft2.apply + else: + self.fft_handle = torch.fft.rfft2 + self.ifft_handle = torch.fft.irfft2 + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.gather_shapes = compute_split_shapes(self.num_blocks, matmul_comm_size) + self.num_blocks_local = self.gather_shapes[comm.get_rank("matmul")] + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + self.mult_handle = compl_mul_add_fwd_c if use_complex_kernels else compl_mul_add_fwd + + # model paralellism + self.input_is_matmul_parallel = input_is_matmul_parallel + self.output_is_matmul_parallel = output_is_matmul_parallel + + # new + # these weights need to be synced across all spatial ranks! 
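+        # the trailing dimension of size 2 stores (real, imag) pairs so the
+        # complex kernel path can use torch.view_as_complex on these weights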
+ self.w1 = nn.Parameter(self.scale * torch.randn(self.num_blocks_local, self.block_size, self.block_size * self.hidden_size_factor, 2)) + self.b1 = nn.Parameter(self.scale * torch.randn(self.num_blocks_local, self.block_size * self.hidden_size_factor, 1, 1, 2)) + self.w2 = nn.Parameter(self.scale * torch.randn(self.num_blocks_local, self.block_size * self.hidden_size_factor, self.block_size, 2)) + self.b2 = nn.Parameter(self.scale * torch.randn(self.num_blocks_local, self.block_size, 1, 1, 2)) + + # setting correct sharding and sharing + self.w1.is_shared_mp = ["spatial"] + self.w1.sharded_dims_mp = ["matmul", None, None, None] + + self.b1.is_shared_mp = ["spatial"] + self.b1.sharded_dims_mp = ["matmul", None, None, None, None] + + self.w2.is_shared_mp = ["spatial"] + self.w2.sharded_dims_mp = ["matmul", None, None, None] + + self.b2.is_shared_mp = ["spatial"] + self.b2.sharded_dims_mp = ["matmul", None, None, None, None] + + def forward(self, x): + if not self.input_is_matmul_parallel: + # distribute data + x = gather_from_parallel_region(x, 1, self.gather_shapes, "matmul") + + # bias + bias = x + + dtype = x.dtype + x = x.float() + B, C, H, W_local = x.shape + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + H_local = H // self.spatial_comm_size + W = W_local * self.spatial_comm_size + x = self.fft_handle(x, (H, W), (-2, -1), "ortho") + x = x.view(B, self.num_blocks_local, self.block_size, H_local, W // 2 + 1) + + # new + x = torch.view_as_real(x) + o2 = torch.zeros(x.shape, device=x.device) + + o1 = F.relu(self.mult_handle(x[:, :, :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes, :], self.w1, self.b1)) + o2[:, :, :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes, :] = self.mult_handle(o1, self.w2, self.b2) + + # finalize + x = F.softshrink(o2, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, C, H_local, W // 2 + 1) + x = self.ifft_handle(x, (H, W), (-2, -1), "ortho") + x = x.type(dtype) + bias + + # gather + if not self.output_is_matmul_parallel: + x = gather_from_parallel_region(x, 1, self.gather_shapes, "matmul") + + return x diff --git a/fme/ace/models/makani_mpu/mappings.py b/fme/ace/models/makani_mpu/mappings.py new file mode 100644 index 000000000..3c0545248 --- /dev/null +++ b/fme/ace/models/makani_mpu/mappings.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
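+
+"""Distributed transpose autograd function plus gradient-reduction hooks that
+combine data-parallel DDP with model-/spatial-parallel weight sharing."""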
+ +import types +from typing import Any + +import torch +from torch.amp import custom_fwd, custom_bwd +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +# from makani.utils import comm +from fme.ace.utils import comm + +# torch utils +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +# we need those +from fme.ace.models.makani_mpu.helpers import _transpose + +# we need the parameter counter +from fme.ace.models.makani_models.helpers import count_parameters + + +class distributed_transpose(torch.autograd.Function): + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, x, dims, dim1_split_sizes, comm_id): + # WAR for a potential contig check torch bug for channels last contig tensors + x = x.contiguous() + xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=comm.get_group(comm_id)) + x = torch.cat(xlist, dim=dims[1]).contiguous() + ctx.dims = dims + ctx.dim0_split_sizes = dim0_split_sizes + ctx.comm_id = comm_id + return x + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, go): + dims = ctx.dims + dim0_split_sizes = ctx.dim0_split_sizes + # WAR for a potential contig check torch bug for channels last contig tensors + go = go.contiguous() + gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=comm.get_group(ctx.comm_id)) + gi = torch.cat(gilist, dim=dims[0]).contiguous() + return gi, None, None, None + + +# handler for additional gradient reductions +# helper for gradient reduction across channel parallel ranks +def init_gradient_reduction_hooks(model, device, reduction_buffer_count=1, broadcast_buffers=True, find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=False, verbose=False): + # early exit if we are not in a distributed setting: + if not dist.is_initialized(): + return model + + # set this to false in init and then find out if we can use it: + need_hooks = False + ddp_group = comm.get_group("data") + + # this is the trivial case + if comm.get_size("model") == 1: + # the simple case, we can just continue then + ddp_group = None + else: + # count parameters and reduction groups + num_parameters_total = 0 + num_parameters_shared_model = 0 + for param in model.parameters(): + # if it does not have any annotation, we assume it is shared between all model ranks + if not hasattr(param, "is_shared_mp"): + if verbose: + print(f"Parameter {param.name} has no sharing mode specified, settting to globally shared.") + param.is_shared_mp = ["model"] + + # add the sharing type to the dict + num_parameters_total += 1 + if "model" in param.is_shared_mp: + num_parameters_shared_model += 1 + + # if all parameters are shared between all model ranks, then the situation is easy + if num_parameters_shared_model == num_parameters_total: + # we can always use DDP + ddp_group = None + + # register some pre-multiply reduction hooks + if verbose: + print("Setting up gradient hooks to account for shared parameter multiplicity") + for param in model.parameters(): + param.register_hook(lambda grad: grad * float(comm.get_size("model"))) + else: + ddp_group = comm.get_group("data") + broadcast_buffers = False + need_hooks = True + + # compute bucket cap in MB: + if need_hooks: + # if we need hooks, we can only use a single reduction buffer: + reduction_buffer_count = 1 + + # determine size of model. 
Only local number of parameters is relevant: + _, _, local_parameter_size_bytes = count_parameters(model, device) + + # compute reduction buffer size + reduction_size_bytes = (local_parameter_size_bytes + reduction_buffer_count - 1) // reduction_buffer_count + reduction_size_mb = (reduction_size_bytes + 1048575) // 1048576 + + # we should fuse the first bucket with the others + dist._DEFAULT_FIRST_BUCKET_BYTES = reduction_size_mb * 1048576 + + # we can set up DDP and exit here + if verbose: + print("Setting up DDP communication hooks") + model = DistributedDataParallel( + model, + device_ids=[device], + output_device=device, + bucket_cap_mb=reduction_size_mb, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + process_group=ddp_group, + ) + + if not need_hooks: + return model + + if verbose: + print("Setting up custom communication hooks") + + # define comm hook: + def reduction_comm_hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: + # allreduce everything first: + buff = bucket.buffer() + params = bucket.parameters() + + # define the grad reduction function + def grad_reduction(fut, grads, group, reduction="sum"): + # check if grads are complex + is_complex = [g.is_complex() for g in grads] + grads_real = [torch.view_as_real(g) if g.is_complex() else g for g in grads] + + # flatten + coalesced = _flatten_dense_tensors(grads_real) + + # reduce + if reduction == "sum": + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=comm.get_group(group), async_op=False) + elif reduction == "mean": + dist.all_reduce(coalesced, op=dist.ReduceOp.AVG, group=comm.get_group(group), async_op=False) + else: + raise NotImplementedError(f"Error, reduction {reduction} not supported.") + + # copy back + for buf, synced_real, is_comp in zip(grads, _unflatten_dense_tensors(coalesced, grads_real), is_complex): + if is_comp: + synced = torch.view_as_complex(synced_real) + else: + synced = synced_real + buf.copy_(synced) + + return bucket.buffer() + + # WAR: we need to add a workaround for complex gradients here, therefore we need to hack the allreduce step a little bit. 
+        # once this is fixed, the below line can be uncommented and we can remove the hack
+        # get future for allreduce
+        # fut = dist.all_reduce(buff, op=dist.ReduceOp.AVG, group=comm.get_group("data"), async_op=True).get_future()
+
+        # get future
+        fut = torch.futures.Future()
+        fut.set_result(bucket.buffer())
+
+        # get the data gradients first:
+        grads = []
+        for p in params:
+            if p.grad is not None:
+                grads.append(p.grad.data)
+
+        if grads:
+            # bind grads by value: `grads` is rebound in the loop below, and a
+            # plain closure would otherwise see the last binding when it runs
+            fut = fut.then(lambda x, grads=grads: grad_reduction(x, grads=grads, group="data", reduction="mean"))
+
+        # now go through the groups
+        for group in comm.get_comm_names():
+            if group == "data":
+                continue
+
+            # collect the gradients shared within this group
+            grads = []
+            for p in params:
+                if (p.grad is not None) and (group in p.is_shared_mp):
+                    grads.append(p.grad.data)
+
+            # append the new reduction functions, again binding by value
+            if grads:
+                fut = fut.then(lambda x, grads=grads, group=group: grad_reduction(x, grads=grads, group=group, reduction="sum"))
+
+        return fut
+
+    # register model comm hook
+    model.register_comm_hook(state=None, hook=reduction_comm_hook)
+
+    return model
diff --git a/fme/ace/models/makani_utils/checkpoint_helpers.py b/fme/ace/models/makani_utils/checkpoint_helpers.py
new file mode 100644
index 000000000..04746826c
--- /dev/null
+++ b/fme/ace/models/makani_utils/checkpoint_helpers.py
@@ -0,0 +1,218 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import glob
+import re
+
+from collections import OrderedDict
+from typing import Optional, Dict, Any
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+
+# from makani.utils import comm
+from fme.ace.utils import comm
+from fme.ace.models.makani_mpu.helpers import gather_uneven
+
+from physicsnemo.distributed.utils import split_tensor_along_dim
+
+
+def get_latest_checkpoint_version(checkpoint_path):
+    try:
+        checkpoint_path = max(glob.glob(checkpoint_path.format(mp_rank=0, checkpoint_version="*")), key=os.path.getmtime)
+        pathname, _ = os.path.splitext(checkpoint_path)
+        latest_version = int(re.match(r"^.*?_v(\d{1,})$", pathname).groups()[0])
+    except Exception:
+        print(f"Could not identify version for checkpoint {checkpoint_path}. Skipping detection.")
+        latest_version = 0
+
+    return latest_version
+
+
+def gather_model_state_dict(model: nn.Module, grads: bool = False) -> OrderedDict:
+    # create empty dict to hold the state
+    state_dict = OrderedDict()
+
+    # iterate over parameters and gather them from the ranks
+    for name, param in model.named_parameters():
+        weight = param.clone()
+        if hasattr(param, "sharded_dims_mp"):
+            # gather the weight across all sharded dimensions
+            for d, group in enumerate(param.sharded_dims_mp):
+                if group is not None:
+                    weight = gather_uneven(weight, d, group)
+
+        state_dict[name] = weight.cpu()
+
+        if grads:
+            if param.grad is not None:
+                grad = param.grad.clone()
+                if hasattr(param, "sharded_dims_mp"):
+                    for d, group in enumerate(param.sharded_dims_mp):
+                        if group is not None:
+                            grad = gather_uneven(grad, d, group)
+                grad = grad.cpu()
+            else:
+                grad = None
+
+            state_dict[name + ".grad"] = grad
+
+    return state_dict
+
+
+def scatter_model_state_dict(model: nn.Module, state_dict: OrderedDict, strict: bool = True) -> OrderedDict:
+
+    # iterate over model parameters and split accordingly
+    for name, param in model.named_parameters():
+
+        # make sure that the parameter is in the state dict
+        if name in state_dict.keys():
+
+            # in this case, we need to distribute the weight
+            if hasattr(param, "sharded_dims_mp"):
+
+                # make a copy
+                weight = state_dict[name].clone()
+
+                # split if necessary
+                for d, group in enumerate(param.sharded_dims_mp):
+                    # continue if there is nothing to do
+                    if (group is None) or (comm.get_size(group) == 1):
+                        continue
+
+                    weight = split_tensor_along_dim(weight, dim=d, num_chunks=comm.get_size(group))[comm.get_rank(group)]
+
+                # update state dict
+                state_dict[name] = weight
+
+        elif strict:
+            # TODO: maybe do at least a warning for non-strict mode
+            raise ValueError(f"Missing key {name}")
+
+    return state_dict
+
+
+def gather_optimizer_state_dict(model: nn.Module, optimizer: Optimizer) -> OrderedDict:
+
+    # if optimizer is SGD, we can just return the local dict:
+    if isinstance(optimizer, torch.optim.SGD):
+        return optimizer.state_dict()
+
+    # do sanity checks
+    if not (isinstance(optimizer, torch.optim.Adam) or isinstance(optimizer, torch.optim.AdamW)):
+        raise NotImplementedError("Error, only Adam and AdamW state can be stored in flexible format at the moment.")
+
+    # state dict:
+    state_dict = optimizer.state_dict()
+
+    # we need to copy the optimizer dict the hard way
+    optimizer_dict = OrderedDict()
+    optimizer_dict["param_groups"] = []
+    for pgroup in state_dict["param_groups"]:
+        pdict = {key: value for key, value in pgroup.items()}
+        optimizer_dict["param_groups"].append(pdict)
+
+    # check whether the corresponding model parameter is distributed.
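+    # (a parameter counts as distributed iff it carries `sharded_dims_mp`, which
+    # lists, per tensor dimension, the comm group that dimension is sharded
+    # over, or None if that dimension is not sharded)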
+    # if yes, we need to gather it
+    optimizer_dict["state"] = {}
+    for index, param in enumerate(model.parameters()):
+        optimizer_dict["state"][index] = {"step": state_dict["state"][index]["step"].clone()}
+        if hasattr(param, "sharded_dims_mp"):
+            exp_avg = state_dict["state"][index]["exp_avg"].clone()
+            exp_avg_sq = state_dict["state"][index]["exp_avg_sq"].clone()
+
+            # gather the optimizer state across all sharded dimensions
+            for d, group in enumerate(param.sharded_dims_mp):
+                if group is not None:
+                    exp_avg = gather_uneven(exp_avg, d, group)
+                    exp_avg_sq = gather_uneven(exp_avg_sq, d, group)
+
+            optimizer_dict["state"][index]["exp_avg"] = exp_avg
+            optimizer_dict["state"][index]["exp_avg_sq"] = exp_avg_sq
+        else:
+            optimizer_dict["state"][index]["exp_avg"] = state_dict["state"][index]["exp_avg"].clone()
+            optimizer_dict["state"][index]["exp_avg_sq"] = state_dict["state"][index]["exp_avg_sq"].clone()
+
+    return optimizer_dict
+
+
+def scatter_optimizer_state_dict(model: nn.Module, optimizer: Optimizer, optimizer_state_dict: OrderedDict) -> OrderedDict:
+
+    # some sanity checks
+    # if optimizer is SGD, we can just return the local dict:
+    if isinstance(optimizer, torch.optim.SGD):
+        return optimizer_state_dict
+
+    if not (isinstance(optimizer, torch.optim.Adam) or isinstance(optimizer, torch.optim.AdamW)):
+        raise NotImplementedError("Error, only Adam and AdamW state can be restored from flexible format at the moment.")
+
+    # iterate over model parameters and split accordingly
+    for idp, param in enumerate(model.parameters()):
+
+        # in this case, we need to distribute the weight
+        if hasattr(param, "sharded_dims_mp"):
+
+            # clone the state
+            exp_avg = optimizer_state_dict["state"][idp]["exp_avg"].clone()
+            exp_avg_sq = optimizer_state_dict["state"][idp]["exp_avg_sq"].clone()
+
+            for d, group in enumerate(param.sharded_dims_mp):
+                # continue if there is nothing to do
+                if (group is None) or (comm.get_size(group) == 1):
+                    continue
+
+                exp_avg = split_tensor_along_dim(exp_avg, dim=d, num_chunks=comm.get_size(group))[comm.get_rank(group)]
+                exp_avg_sq = split_tensor_along_dim(exp_avg_sq, dim=d, num_chunks=comm.get_size(group))[comm.get_rank(group)]
+
+            # update the state dict
+            optimizer_state_dict["state"][idp]["exp_avg"] = exp_avg
+            optimizer_state_dict["state"][idp]["exp_avg_sq"] = exp_avg_sq
+
+    return optimizer_state_dict
+
+
+def prepend_prefix_to_state_dict(
+    state_dict: Dict[str, Any],
+    prefix: str,
+) -> None:
+    r"""Prepend the prefix to states in state_dict in place.
+
+    .. note::
+        Given a `state_dict` from a local model, a DP/DDP model can load it by applying
+        `prepend_prefix_to_state_dict(state_dict, "module.")` before calling
+        :meth:`torch.nn.Module.load_state_dict`.
+
+    Args:
+        state_dict (OrderedDict): a state-dict to be loaded to the model.
+        prefix (str): prefix.
+    """
+    keys = list(state_dict.keys())
+    for key in keys:
+        newkey = prefix + key
+        state_dict[newkey] = state_dict.pop(key)
+
+    # also prepend the prefix in the metadata, if any.
+    if hasattr(state_dict, "_metadata"):
+        keys = list(state_dict._metadata.keys())
+        for key in keys:
+            # for the metadata dict, the key can be:
+            # '': for the root module,
+            # 'xx' / 'xx.yy': for the submodules;
+            # every key gets the prefix, including the root ''.
+            newkey = prefix + key
+            state_dict._metadata[newkey] = state_dict._metadata.pop(key)
diff --git a/fme/ace/models/makani_utils/makani_driver.py b/fme/ace/models/makani_utils/makani_driver.py
new file mode 100644
index 000000000..6d45dc071
--- /dev/null
+++ b/fme/ace/models/makani_utils/makani_driver.py
@@ -0,0 +1,100 @@
+from typing import Optional, Dict
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+import torch.distributed as dist
+
+from fme.ace.utils import comm
+
+from fme.ace.models.makani_utils.checkpoint_helpers import gather_model_state_dict, gather_optimizer_state_dict, prepend_prefix_to_state_dict, scatter_model_state_dict, scatter_optimizer_state_dict
+
+def _save_checkpoint_flexible(
+    checkpoint_path: str,
+    model: nn.Module,
+    loss: Optional[nn.Module] = None,
+    optimizer: Optional[optim.Optimizer] = None,
+    scheduler: Optional[lr_scheduler.LRScheduler] = None,
+    counters: Optional[Dict[str, int]] = None,
+):
+    # checkpoint name
+    checkpoint_fname = checkpoint_path.format(mp_rank=0)
+
+    # iterate over parameters and gather them from the ranks
+    if comm.get_size("model") > 1:
+        state_dict = gather_model_state_dict(model)
+    else:
+        state_dict = model.state_dict()
+
+    # drop module prefix in case DDP is being used
+    if isinstance(model, nn.parallel.DistributedDataParallel):
+        nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")
+
+    store_dict = {"model_state": state_dict}
+
+    if loss is not None:
+        store_dict["loss_state_dict"] = loss.state_dict()
+
+    if optimizer is not None:
+        if comm.get_size("model") > 1:
+            store_dict["optimizer_state_dict"] = gather_optimizer_state_dict(model, optimizer)
+        else:
+            store_dict["optimizer_state_dict"] = optimizer.state_dict()
+
+    if scheduler is not None:
+        store_dict["scheduler_state_dict"] = scheduler.state_dict()
+
+    if counters is not None:
+        store_dict["iters"] = counters["iters"]
+        store_dict["epoch"] = counters["epoch"]
+
+    # in flexible mode only rank 0 needs to save the data to disk
+    if comm.get_world_rank() == 0:
+        torch.save(store_dict, checkpoint_fname)
+
+    return
+
+def _restore_checkpoint_flexible(
+    checkpoint_path: str,
+    model: nn.Module,
+    loss: Optional[nn.Module] = None,
+    optimizer: Optional[optim.Optimizer] = None,
+    scheduler: Optional[lr_scheduler.LRScheduler] = None,
+    counters: Optional[Dict[str, int]] = None,
+    strict: bool = True,
+):
+    # when loading the weights in flexible mode we exclusively use mp_rank=0 and load them onto the cpu
+    checkpoint_fname = checkpoint_path.format(mp_rank=0)
+    checkpoint = torch.load(checkpoint_fname, map_location="cpu", weights_only=False)
+
+    # this is reworked to avoid loading modules related to the SHT
+    state_dict = checkpoint["model_state"]
+
+    if isinstance(model, nn.parallel.DistributedDataParallel):
+        # prepend module prefix to state dict:
+        prepend_prefix_to_state_dict(state_dict, "module.")
+
+    if comm.get_size("model") > 1:
+        state_dict = scatter_model_state_dict(model, state_dict, strict)
+
+    # load state dict
+    model.load_state_dict(state_dict, strict=strict)
+
+    # the loss is also restored, in case it carries state
+    if loss is not None:
+        loss.load_state_dict(checkpoint["loss_state_dict"])
+
+    # If finetuning, restore does not load the optimizer state; the config-specified lr is used instead.
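+    # (callers opt out by passing optimizer=None; otherwise the gathered Adam
+    # moments are re-scattered across the model-parallel ranks before loading)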
+ if optimizer is not None: + if comm.get_size("model") > 1: + checkpoint["optimizer_state_dict"] = scatter_optimizer_state_dict(model, optimizer, checkpoint["optimizer_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + if scheduler is not None: + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + + if counters is not None: + counters["iters"] = checkpoint["iters"] + counters["start_epoch"] = checkpoint["epoch"] diff --git a/fme/ace/models/modulus/s2convolutions.py b/fme/ace/models/modulus/s2convolutions.py index 22f569362..1cd27293d 100644 --- a/fme/ace/models/modulus/s2convolutions.py +++ b/fme/ace/models/modulus/s2convolutions.py @@ -24,6 +24,7 @@ import torch_harmonics as th import torch_harmonics.distributed as thd +from fme.core.distributed import Distributed # from tensorly.plugins import use_opt_einsum # use_opt_einsum('optimal') from tltorch.factorized_tensors.core import FactorizedTensor @@ -105,16 +106,8 @@ def __init__( if not self.separable: weight_shape += [out_channels] - if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT): - self.modes_lat_local = self.inverse_transform.lmax_local - self.modes_lon_local = self.inverse_transform.mmax_local - self.lpad_local = self.inverse_transform.lpad_local - self.mpad_local = self.inverse_transform.mpad_local - else: - self.modes_lat_local = self.modes_lat - self.modes_lon_local = self.modes_lon - self.lpad = 0 - self.mpad = 0 + dist = Distributed.get_instance() + self.modes_lat_local, self.modes_lon_local = dist.get_local_modes(inverse_transform) # padded weights # if self.operator_type == 'diagonal': @@ -148,8 +141,13 @@ def __init__( self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2)) if self.operator_type == "dhconv": self.weight.is_shared_mp = ["matmul", "w"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "h" else: self.weight.is_shared_mp = ["matmul"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "w" + self.weight.sharded_dims_mp[-2] = "h" # get the contraction handle self._contract = get_contract_fun( @@ -158,6 +156,8 @@ def __init__( if bias: self.bias = nn.Parameter(scale * torch.zeros(1, out_channels, 1, 1)) + self.bias.is_shared_mp = ["model"] + self.bias.sharded_dims_mp = [None, None, None, None] def forward(self, x): # pragma: no cover dtype = x.dtype diff --git a/fme/ace/models/modulus/sfnonet.py b/fme/ace/models/modulus/sfnonet.py index de66056c8..5a91325c8 100644 --- a/fme/ace/models/modulus/sfnonet.py +++ b/fme/ace/models/modulus/sfnonet.py @@ -20,9 +20,6 @@ import torch import torch.nn as nn -# get spectral transforms from torch_harmonics -import torch_harmonics as th -import torch_harmonics.distributed as thd from torch.utils.checkpoint import checkpoint from .initialization import trunc_normal_ @@ -33,6 +30,13 @@ from .layers import MLP, DropPath, RealFFT2, SpectralAttention2d from .s2convolutions import SpectralAttentionS2, SpectralConvS2 +# for annotation of models +from dataclasses import dataclass +import physicsnemo +from physicsnemo.models.meta import ModelMetaData + +from fme.core.distributed import Distributed + # layer normalization try: from apex.normalization import FusedLayerNorm @@ -65,9 +69,10 @@ def __init__( ): super(SpectralFilterLayer, self).__init__() + dist = Distributed.get_instance() + if filter_type == "non-linear" and ( - isinstance(forward_transform, th.RealSHT) - or isinstance(forward_transform, thd.DistributedRealSHT) + 
isinstance(forward_transform, dist.th_real_sht()) ): self.filter = SpectralAttentionS2( forward_transform, @@ -97,8 +102,7 @@ def __init__( # spectral transform is passed to the module elif filter_type == "linear" and ( - isinstance(forward_transform, th.RealSHT) - or isinstance(forward_transform, thd.DistributedRealSHT) + isinstance(forward_transform, dist.th_real_sht()) ): self.filter = SpectralConvS2( forward_transform, @@ -151,8 +155,10 @@ def __init__( ): super(FourierNeuralOperatorBlock, self).__init__() - self.input_shape_loc = (forward_transform.nlat, forward_transform.nlon) - self.output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon) + dist = Distributed.get_instance() + self.input_shape_loc, self.output_shape_loc = dist.get_input_out_shapes( + forward_transform, inverse_transform + ) # norm layer self.norm0 = norm_layer[0]() @@ -196,7 +202,7 @@ def __init__( self.norm1 = norm_layer[1]() if use_mlp == True: - MLPH = MLP + MLPH = dist.get_mlp(MLP) mlp_hidden_dim = int(embed_dim * mlp_ratio) self.mlp = MLPH( in_features=embed_dim, @@ -252,7 +258,7 @@ def forward(self, x): return x -class SphericalFourierNeuralOperatorNet(torch.nn.Module): +class SphericalFourierNeuralOperatorNetBase(torch.nn.Module): """ Spherical Fourier Neural Operator Network @@ -372,8 +378,8 @@ def __init__( spectral_layers: int = 3, checkpointing: int = 0, ): - super(SphericalFourierNeuralOperatorNet, self).__init__() - + super(SphericalFourierNeuralOperatorNetBase, self).__init__() + dist = Distributed.get_instance() self.params = params self.spectral_transform = ( params.spectral_transform @@ -482,13 +488,13 @@ def __init__( self.residual_filter_down = nn.Identity() self.residual_filter_up = nn.Identity() else: - self.residual_filter_down = th.RealSHT( + self.residual_filter_down = dist.th_real_sht()( *self.img_shape, lmax=modes_lat_residual, mmax=modes_lon_residual, grid=data_grid, ).float() - self.residual_filter_up = th.InverseRealSHT( + self.residual_filter_up = dist.th_inverse_real_sht()( *self.img_shape, lmax=modes_lat_residual, mmax=modes_lon_residual, @@ -497,8 +503,8 @@ def __init__( # prepare the spectral transforms if self.spectral_transform == "sht": - sht_handle = th.RealSHT - isht_handle = th.InverseRealSHT + sht_handle = dist.th_real_sht() + isht_handle = dist.th_inverse_real_sht() # set up self.trans_down = sht_handle( @@ -519,8 +525,8 @@ def __init__( raise NotImplementedError( "Residual filter factor is not implemented for FFT spectral transform" ) - fft_handle = th.RealFFT2 - ifft_handle = th.InverseRealFFT2 + fft_handle = dist.th_real_fft2() + ifft_handle = dist.th_inverse_real_fft2() # effective image size: self.img_shape_eff = ( @@ -548,11 +554,7 @@ def __init__( raise (ValueError("Unknown spectral transform")) # use the SHT/FFT to compute the local, downscaled grid dimensions - self.img_shape_loc = (self.trans_down.nlat, self.trans_down.nlon) - self.img_shape_eff = (self.trans_down.nlat, self.trans_down.nlon) - self.h_loc = self.itrans.nlat - self.w_loc = self.itrans.nlon - + self.img_shape_loc,self.img_shape_eff,self.h_loc, self.w_loc = dist.set_image_shapes(self.trans_down,self.itrans_up, self.itrans ) # determine activation function if self.activation_function == "relu": self.activation_function = nn.ReLU @@ -571,9 +573,16 @@ def __init__( encoder_modules.append( nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True) ) + # weight sharing + encoder_modules[-1].weight.is_shared_mp = ["spatial"] + if encoder_modules[-1].bias is not None: + 
encoder_modules[-1].bias.is_shared_mp = ["spatial"] encoder_modules.append(self.activation_function()) current_dim = encoder_hidden_dim + #final layer encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)) + # weight sharing + encoder_modules[-1].weight.is_shared_mp = ["spatial"] self.encoder = nn.Sequential(*encoder_modules) # dropout @@ -582,21 +591,26 @@ def __init__( # pick norm layer if self.normalization_layer == "layer_norm": - norm_layer0 = partial( - nn.LayerNorm, - normalized_shape=(self.img_shape_loc[0], self.img_shape_loc[1]), - eps=1e-6, - ) - norm_layer1 = partial( - nn.LayerNorm, normalized_shape=(self.h_loc, self.w_loc), eps=1e-6 - ) + if dist.comm_get_size("spatial") > 1: + ## CHECK ME: norm_layer0 and norm_layer1, as coded in makani + norm_layer0 = partial(dist.layer_norm(), normalized_shape=(self.embed_dim), elementwise_affine=True, eps=1e-6) + norm_layer1 = norm_layer0 + ## CHECK ME: norm_layer0 and norm_layer1, as coded in ace + else: + norm_layer0 = partial( + dist.layer_norm(), + normalized_shape=(self.img_shape_loc[0], self.img_shape_loc[1]), + eps=1e-6, + ) + norm_layer1 = partial( + dist.layer_norm(), normalized_shape=(self.h_loc, self.w_loc), eps=1e-6 + ) elif self.normalization_layer == "instance_norm": norm_layer0 = partial( - nn.InstanceNorm2d, + dist.instance_norm_2d(), num_features=self.embed_dim, eps=1e-6, - affine=True, - track_running_stats=False, + affine=True ) norm_layer1 = norm_layer0 elif self.normalization_layer == "none": @@ -665,9 +679,16 @@ def __init__( decoder_modules.append( nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True) ) + # weight sharing + decoder_modules[-1].weight.is_shared_mp = ["spatial"] + # decoder_modules[-1].weight.sharded_dims_mp = [None, None, None, None] + if decoder_modules[-1].bias is not None: + decoder_modules[-1].bias.is_shared_mp = ["spatial"] decoder_modules.append(self.activation_function()) current_dim = decoder_hidden_dim decoder_modules.append(nn.Conv2d(current_dim, self.out_chans, 1, bias=False)) + # weight sharing + decoder_modules[-1].weight.is_shared_mp = ["spatial"] self.decoder = nn.Sequential(*decoder_modules) # learned position embedding @@ -679,7 +700,13 @@ def __init__( ) ) # self.pos_embed = nn.Parameter( torch.zeros(1, self.embed_dim, self.img_shape_eff[0], self.img_shape_eff[1]) ) - self.pos_embed.is_shared_mp = ["matmul"] + if dist.spatial_parallelism: + self.pos_embed.is_shared_mp = [] + self.pos_embed.sharded_dims_mp = [None, None, "h", "w"] + self.pos_embed.type = "direct" + else: + self.pos_embed.is_shared_mp = ["matmul"] + trunc_normal_(self.pos_embed, std=0.02) self.apply(self._init_weights) @@ -747,3 +774,25 @@ def forward(self, x): x = self.decoder(x) return x + + +# this part exposes the model to modulus by constructing modulus Modules +@dataclass +class SphericalFourierNeuralOperatorNetMetaData(ModelMetaData): + name: str = "SphericalFourierNeuralOperatorNet" + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + +def init_sfno(): + """Helper function to initialize SFNO model""" + dist = Distributed.get_instance() + if dist.spatial_parallelism: + return physicsnemo.Module.from_torch( + SphericalFourierNeuralOperatorNetBase, + SphericalFourierNeuralOperatorNetMetaData() + ) + return SphericalFourierNeuralOperatorNetBase + +SphericalFourierNeuralOperatorNet = init_sfno() diff --git a/fme/ace/models/modulus/test_distributed_spectral_conv.py b/fme/ace/models/modulus/test_distributed_spectral_conv.py new file mode 100644 
index 000000000..dce91b4fa --- /dev/null +++ b/fme/ace/models/modulus/test_distributed_spectral_conv.py @@ -0,0 +1,435 @@ +import os + +import torch +import torch.distributed as dist +from fme.core.device import get_device +from fme.core.testing import validate_tensor + +from .sfnonet import SphericalFourierNeuralOperatorNet, SFNO + + +from .layers import MLP, DropPath, RealFFT2, SpectralAttention2d, InverseRealFFT2 +from .s2convolutions import SpectralAttentionS2, SpectralConvS2 + +from fme.ace.models.makani_mpu.fft import DistributedRealFFT1, DistributedInverseRealFFT1, DistributedRealFFT2, DistributedInverseRealFFT2, DistributedRealFFT3, DistributedInverseRealFFT3 +from fme.ace.models.makani_mpu.mappings import init_gradient_reduction_hooks + +from fme.ace.utils import comm +import torch_harmonics as th +import torch_harmonics.distributed as thd +from physicsnemo.distributed.utils import split_tensor_along_dim +DIR = os.path.abspath(os.path.dirname(__file__)) + +# this computes a relative error compatible with torch.allclose or np.allclose +def relative_error(tensor1, tensor2): + return torch.sum(torch.abs(tensor1-tensor2)) / torch.sum(torch.abs(tensor2)) + +# this computes an absolute error compatible with torch.allclose or np.allclose +def absolute_error(tensor1, tensor2): + return torch.max(torch.abs(tensor1-tensor2)) + +def split_helper(tensor, dim=None, group=None): + with torch.no_grad(): + if (dim is not None) and dist.get_world_size(group=group): + gsize = dist.get_world_size(group=group) + grank = dist.get_rank(group=group) + # split in dim + tensor_list_local = split_tensor_along_dim(tensor, dim=dim, num_chunks=gsize) + tensor_local = tensor_list_local[grank] + else: + tensor_local = tensor.clone() + + return tensor_local + + +def gather_helper(tensor, dim=None, group=None): + # get shapes + if (dim is not None) and (dist.get_world_size(group=group) > 1): + gsize = dist.get_world_size(group=group) + grank = dist.get_rank(group=group) + shape_loc = torch.tensor([tensor.shape[dim]], dtype=torch.long, device=tensor.device) + shape_list = [torch.empty_like(shape_loc) for _ in range(dist.get_world_size(group=group))] + shape_list[grank] = shape_loc + dist.all_gather(shape_list, shape_loc, group=group) + tshapes = [] + for ids in range(gsize): + tshape = list(tensor.shape) + tshape[dim] = shape_list[ids].item() + tshapes.append(tuple(tshape)) + tens_gather = [torch.empty(tshapes[ids], dtype=tensor.dtype, device=tensor.device) for ids in range(gsize)] + tens_gather[grank] = tensor + dist.all_gather(tens_gather, tensor, group=group) + tensor_gather = torch.cat(tens_gather, dim=dim) + else: + tensor_gather = tensor.clone() + + return tensor_gather + +def _split_helper(tensor, w_group, h_group): + tensor_local = split_helper(tensor, dim=-1, group=w_group) + tensor_local = split_helper(tensor_local, dim=-2, group=h_group) + return tensor_local + + +def _gather_helper(tensor, w_group, h_group): + tensor_gather = gather_helper(tensor, dim=-2, group=h_group) + tensor_gather = gather_helper(tensor_gather, dim=-1, group=w_group) + + return tensor_gather + +def _split_helper_conv(tensor, hdim=-2, wdim=-1, w_group=1, h_group=1): + tensor_local = split_helper(tensor, dim=hdim, group=h_group) + tensor_local = split_helper(tensor_local, dim=wdim, group=w_group) + return tensor_local + + +def _gather_helper_conv(tensor, hdim=-2, wdim=-1, w_group=1, h_group=1): + tensor_gather = gather_helper(tensor, dim=hdim, group=h_group) + tensor_gather = gather_helper(tensor_gather, dim=wdim, group=w_group) + 
return tensor_gather + +def setup_test(): + from mpi4py import MPI + mpi_comm = MPI.COMM_WORLD.Dup() + mpi_comm_rank = mpi_comm.Get_rank() + mpi_comm_size = mpi_comm.Get_size() + if torch.cuda.is_available(): + if mpi_comm_rank == 0: + print("Running test on GPU") + local_rank = mpi_comm_rank % torch.cuda.device_count() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.cuda.manual_seed(333) + else: + if mpi_comm_rank == 0: + print("Running test on CPU") + device = torch.device("cpu") + torch.manual_seed(333) + return mpi_comm, device + +def _init_comms(): + # set up distributed + grid_size_h = int(os.getenv("GRID_H", 1)) + grid_size_w = int(os.getenv("GRID_W", 1)) + grid_size_e = int(os.getenv("GRID_E", 1)) + world_size = grid_size_h * grid_size_w * grid_size_e + + # init groups + comm.init( + model_parallel_sizes=[grid_size_h, grid_size_w, 1, 1], + model_parallel_names=["h", "w", "fin", "fout"], + data_parallel_sizes=[grid_size_e, -1], + data_parallel_names=["ensemble", "batch"], + ) + world_rank = comm.get_world_rank() + + # store comm group parameters + wrank = comm.get_rank("w") + hrank = comm.get_rank("h") + erank = comm.get_rank("ensemble") + w_group = comm.get_group("w") + h_group = comm.get_group("h") + e_group = comm.get_group("ensemble") + # initializing sht process groups just to be sure + thd.init(h_group, w_group) + + if world_rank == 0: + print(f"Running distributed tests on grid H x W x E = {grid_size_h} x {grid_size_w} x {grid_size_e}") + + return w_group, h_group, e_group, world_rank, world_size + +def test_distributed_fft2(): + verbose=True + mpi_comm, device = setup_test() + w_group, h_group, e_group, world_rank, world_size = _init_comms() + + # 256, 512, 0, 32, 8, 1e-6 + # nlat, nlon, nalt, batch_size, num_chan, tol, + tol=1e-6 + B, C, H, W = 32, 8, 256, 512 + + # set up handles + forward_transform_local = RealFFT2(nlat=H, nlon=W).to(device) + forward_transform_dist = DistributedRealFFT2(nlat=H, nlon=W).to(device) + + # create tensors + inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device) + + ############################################################# + # local transform + ############################################################# + # FWD pass + inp_full.requires_grad = True + out_full = forward_transform_local(inp_full) + + # create grad for backward + with torch.no_grad(): + # create full grad + ograd_full = torch.randn_like(out_full) + + # BWD pass + out_full.backward(ograd_full) + igrad_full = inp_full.grad.clone() + + ############################################################# + # distributed transform + ############################################################# + # FWD pass + inp_local = _split_helper(inp_full, w_group, h_group) + inp_local.requires_grad = True + out_local = forward_transform_dist(inp_local) + + # BWD pass + ograd_local = _split_helper(ograd_full, w_group, h_group) + out_local = forward_transform_dist(inp_local) + out_local.backward(ograd_local) + igrad_local = inp_local.grad.clone() + + # set eval dims + dims = (-1,-2,-3) + + ############################################################# + # evaluate FWD pass + ############################################################# + with torch.no_grad(): + out_gather_full = _gather_helper(out_local, w_group, h_group) + err = relative_error(out_gather_full, out_full) + if verbose and (world_rank == 0): + print(f"final relative error of output: {err.item()}") + assert err.item() <= tol + + 
############################################################# + # evaluate BWD pass + ############################################################# + with torch.no_grad(): + igrad_gather_full = _gather_helper(igrad_local, w_group, h_group) + err = relative_error(igrad_gather_full, igrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of gradients: {err.item()}") + assert err.item() <= tol + +def _init_seed(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + return + +def test_distributed_ifft2(): + verbose=True + mpi_comm, device = setup_test() + w_group, h_group, e_group, world_rank, world_size = _init_comms() + # 256, 512, 0, 32, 8, 1e-6 + # nlat, nlon, nalt, batch_size, num_chan, tol, + tol=1e-6 + B, C, H, W = 32, 8, 256, 512 + forward_transform_local = RealFFT2(nlat=H, nlon=W).to(device) + backward_transform_local = InverseRealFFT2(nlat=H, nlon=W).to(device) + backward_transform_dist = DistributedInverseRealFFT2(nlat=H, nlon=W).to(device) + + # create tensors + dummy_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device) + inp_full = forward_transform_local(dummy_full) + + ############################################################# + # local transform + ############################################################# + # FWD pass + inp_full.requires_grad = True + out_full = backward_transform_local(inp_full) + + # create grad for backward + with torch.no_grad(): + # create full grad + ograd_full = torch.randn_like(out_full) + + # BWD pass + out_full.backward(ograd_full) + + # repeat once due to known irfft bug + inp_full.grad = None + out_full = backward_transform_local(inp_full) + out_full.backward(ograd_full) + igrad_full = inp_full.grad.clone() + + ############################################################# + # distributed transform + ############################################################# + # FWD pass + inp_local = _split_helper(inp_full, w_group, h_group) + inp_local.requires_grad = True + out_local = backward_transform_dist(inp_local) + + # BWD pass + ograd_local = _split_helper(ograd_full, w_group, h_group) + out_local = backward_transform_dist(inp_local) + out_local.backward(ograd_local) + igrad_local = inp_local.grad.clone() + + # set eval dims + dims = (-1,-2,-3) + + ############################################################# + # evaluate FWD pass + ############################################################# + with torch.no_grad(): + out_gather_full = _gather_helper(out_local, w_group, h_group) + err = relative_error(out_gather_full, out_full) + if verbose and (world_rank == 0): + print(f"final relative error of output: {err.item()}") + assert err.item() <= tol + + ############################################################# + # evaluate BWD pass + ############################################################# + with torch.no_grad(): + igrad_gather_full = _gather_helper(igrad_local, w_group, h_group) + err = relative_error(igrad_gather_full, igrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of gradients: {err.item()}") + assert err.item() <= tol + comm.cleanup() + +def test_distributed_spectral_conv(): + tol=1e-6 + verbose=True + mpi_comm, device = setup_test() + w_group, h_group, e_group, world_rank, world_size = _init_comms() + # set up handles + B, C, Hi, Wi, Ho, Wo = 32, 8, 256, 512, 256, 512 + print("world_rank", world_rank) + print("world_size", world_size) + + forward_transform_local = th.RealSHT(nlat=Hi, nlon=Wi).to(device) + 
inverse_transform_local = th.InverseRealSHT(nlat=Ho, nlon=Wo, lmax=forward_transform_local.lmax, mmax=forward_transform_local.mmax).to(device) + forward_transform_dist = thd.DistributedRealSHT(nlat=Hi, nlon=Wi).to(device) + inverse_transform_dist = thd.DistributedInverseRealSHT(nlat=Ho, nlon=Wo, lmax=forward_transform_dist.lmax, mmax=forward_transform_dist.mmax).to(device) + + _init_seed(333) + spect_conv_local = SpectralConvS2( + forward_transform_local, + inverse_transform_local, + C, + C, + operator_type="dhconv", + use_tensorly=False, + bias=True + ).to(device) + + spect_conv_dist = SpectralConvS2( + forward_transform_dist, + inverse_transform_dist, + C, + C, + operator_type="dhconv", + use_tensorly=False, + bias=True + ).to(device) + # set up wgrad reductions + spect_conv_dist = init_gradient_reduction_hooks( + spect_conv_dist, + device=device, + reduction_buffer_count=1, + broadcast_buffers=False, + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + verbose=False, + ) + # make sure weights are the same: + with torch.no_grad(): + weight = _split_helper_conv(spect_conv_local.weight, hdim=-2, wdim=None, w_group=w_group, h_group=h_group) + print("spect_conv_local.weight",spect_conv_local.weight.shape) + print("weight",weight.shape) + print("spect_conv_dist.module.weight",spect_conv_dist.module.weight.shape) + spect_conv_dist.module.weight.copy_(weight) + spect_conv_dist.module.bias.copy_(spect_conv_local.bias) + + # input + _init_seed(444) + inp_full = torch.randn((B, C, Hi, Wi), dtype=torch.float32, device=device) + # ############################################################# + # # local transform + # ############################################################# + # # FWD pass + inp_full.requires_grad = True + out_full, _ = spect_conv_local(inp_full) + # create grad for backward + _init_seed(555) + with torch.no_grad(): + # create full grad + ograd_full = torch.randn_like(out_full) + + # # BWD pass + out_full.backward(ograd_full) + igrad_full = inp_full.grad.clone() + wgrad_full = spect_conv_local.weight.grad.clone() + bgrad_full = spect_conv_local.bias.grad.clone() + + ############################################################# + # distributed transform + ############################################################# + # FWD pass + inp_local = _split_helper_conv(inp_full, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + print("inp_local", inp_local.shape) + print("inp_full", inp_full.shape) + inp_local.requires_grad = True + out_local, _ = spect_conv_dist(inp_local) + + # BWD pass + ograd_local = _split_helper_conv(ograd_full, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + print("ograd_local", ograd_local.shape) + print("ograd_full", ograd_full.shape) + out_local, _ = spect_conv_dist(inp_local) + out_local.backward(ograd_local) + igrad_local = inp_local.grad.clone() + wgrad_local = spect_conv_dist.module.weight.grad.clone() + bgrad_local = spect_conv_dist.module.bias.grad.clone() + mpi_comm.Barrier() + ############################################################# + # evaluate FWD pass + ############################################################# + with torch.no_grad(): + out_gather_full = _gather_helper_conv(out_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + err = relative_error(out_gather_full, out_full) + if verbose and (world_rank == 0): + print(f"final relative error of output: {err.item()}") + # self.assertTrue(err.item() <= tol) + assert err.item() <= tol + mpi_comm.Barrier() + 
############################################################# + # evaluate input grads + ############################################################# + with torch.no_grad(): + igrad_gather_full = _gather_helper_conv(igrad_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + err = relative_error(igrad_gather_full, igrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of input gradients: {err.item()}") + assert err.item() <= tol + # self.assertTrue(err.item() <= tol) + mpi_comm.Barrier() + ############################################################# + # evaluate Weight grads + ############################################################# + with torch.no_grad(): + wgrad_gather_full = _gather_helper_conv(wgrad_local, hdim=-2, wdim=None, w_group=w_group, h_group=h_group) + print("wgrad_gather_full", wgrad_local.shape) + print("wgrad_gather_full", wgrad_gather_full.shape) + err = relative_error(wgrad_gather_full, wgrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of weight gradients: {err.item()}") + # self.assertTrue(err.item() <= tol) + assert err.item() <= tol + mpi_comm.Barrier() + + with torch.no_grad(): + bgrad_gather_list = [torch.empty_like(bgrad_local) for _ in range(world_size)] + bgrad_gather_list[world_rank] = bgrad_local + dist.all_gather(bgrad_gather_list, bgrad_local, group=None) + errs = [] + for bgrad_gather_full in bgrad_gather_list: + errs.append(relative_error(bgrad_gather_full, bgrad_full)) + err = torch.mean(torch.stack(errs, dim=0)) + if verbose and (world_rank == 0): + print(f"final relative error of bias gradients: {err.item()}") + assert err.item() <= tol + comm.cleanup() diff --git a/fme/ace/models/modulus/test_sfnonet_spatial_dist.py b/fme/ace/models/modulus/test_sfnonet_spatial_dist.py new file mode 100644 index 000000000..2406dc243 --- /dev/null +++ b/fme/ace/models/modulus/test_sfnonet_spatial_dist.py @@ -0,0 +1,305 @@ +import os +import sys +import torch +from fme.core.device import get_device + +from .sfnonet import SFNO + +DIR = os.path.abspath(os.path.dirname(__file__)) + +from fme.ace.models.makani_mpu.mappings import init_gradient_reduction_hooks +from fme.core.distributed import Distributed +from fme.core.dataset.test_helper import gather_helper_conv, relative_error, init_seed +from fme.ace.models.makani_utils.makani_driver import _save_checkpoint_flexible, _restore_checkpoint_flexible +from physicsnemo.distributed.mappings import reduce_from_parallel_region + +def test_sfnonet_without_sp(): + ## without domain decomposition + os.environ['H_PARALLEL_SIZE'] = '1' + verbose=False + input_channels = 3 + output_channels = 3 + img_shape = (8, 16) + n_samples = 4 + embed_dim=16 + num_layers=2 + + model = SFNO( + params=None, + embed_dim=embed_dim, + num_layers=num_layers, + # operator_type="dhconv", + # normalization_layer="layer_norm", + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ) + # must initialize on CPU to get the same results on GPU + inp_full = torch.randn(n_samples, input_channels, *img_shape) + inp_full.requires_grad = True + # with torch.no_grad(): + out_full = model(inp_full) + loss_full = torch.sum(out_full) + + # perform backward pass + loss_full.backward() + igrad_full = inp_full.grad.clone() + + assert out_full.shape == (n_samples, output_channels, *img_shape) + tmp_path="testdata" + torch.save(out_full, os.path.join(tmp_path, "out_full.pt")) + torch.save(inp_full, os.path.join(tmp_path, "inp_full.pt")) + torch.save(loss_full, os.path.join(tmp_path, 
"loss_full.pt")) + torch.save(igrad_full, os.path.join(tmp_path, "igrad_full.pt")) + + _save_checkpoint_flexible(checkpoint_path=os.path.join(tmp_path, "checkpoint.pt"), + model=model) + +def test_sfnonet_with_sp(): + tmp_path="testdata" + os.environ['H_PARALLEL_SIZE'] = '2' + verbose=False + input_channels = 3 + output_channels = 3 + img_shape = (8, 16) + n_samples = 4 + embed_dim=16 + num_layers=2 + + dist = Distributed.get_instance() + mpi_comm_rank = dist.local_rank + + w_group = dist.comm_get_group("w") + h_group = dist.comm_get_group("h") + world_rank = dist.rank + + device=get_device() + + model_dist = SFNO( + params=None, + embed_dim=embed_dim, + num_layers=num_layers, + # operator_type="dhconv", + # normalization_layer="layer_norm", + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + + # save reduction hooks + model_dist = init_gradient_reduction_hooks( + model_dist, + device=device, + reduction_buffer_count=1, + broadcast_buffers=False, + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + verbose=True, + ) + + # load checkpoint + _restore_checkpoint_flexible(checkpoint_path=os.path.join(tmp_path, "checkpoint.pt"), + model=model_dist) + + # must initialize on CPU to get the same results on GPU + # inp_full = torch.randn(n_samples, input_channels, *img_shape) + inp_full = torch.load(os.path.join(tmp_path, "inp_full.pt")) + + # split input + # inputs: ntimes, nsamples, h, w + this_shape=(inp_full.shape[-2],inp_full.shape[-1]) + ## Create a leaf variable + inp_local_host = (inp_full[:,:,*dist.get_local_slices(this_shape)]).detach().clone() + inp_local=inp_local_host.to(device) + inp_local.requires_grad = True + if world_rank == 0: + print("inp_full", inp_full.shape) + print("inp_local", inp_local.shape) + + out_local = model_dist(inp_local) + loss_dist = reduce_from_parallel_region(torch.sum(out_local), "model") + loss_dist.backward() + igrad_local = inp_local.grad.clone() + + out_full = torch.load(os.path.join(tmp_path, "out_full.pt")) + + with torch.no_grad(): + out_full_device=out_full.to(device) + out_gather_full = gather_helper_conv(out_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + err = relative_error(out_gather_full, out_full_device) + if world_rank == 0: + print(f"final relative error of output: {err.item()}") + assert err < 0.0006 + + loss_full=torch.load(os.path.join(tmp_path, "loss_full.pt")) + + with torch.no_grad(): + loss_full_device=loss_full.to(device) + err = relative_error(loss_dist, loss_full) + if (world_rank == 0): + print(f"final relative error of loss: {err.item()}") + # mpi_comm.Barrier() + assert err < 1e-3 + + ############################################################# + # evaluate BWD pass + ############################################################# + # dgrad + igrad_full = torch.load(os.path.join(tmp_path, "igrad_full.pt")) + with torch.no_grad(): + igrad_full_device=igrad_full.to(device) + igrad_gather_full = gather_helper_conv(igrad_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + err = relative_error(igrad_gather_full, igrad_full_device) + if (world_rank == 0): + print(f"final relative error of input gradient: {err.item()}") + # cleanup + assert err < 1e-3 + +def test_sfnonet_spatial_dist_output_is_unchanged(): + # torch.manual_seed(0) + # fix seed + init_seed(333) + ## without domain decomposition + verbose=False + input_channels = 3 + output_channels = 3 + img_shape = (8, 16) + n_samples = 4 + embed_dim=16 + num_layers=2 + with 
Distributed.non_distributed(): + model = SFNO( + params=None, + embed_dim=embed_dim, + num_layers=num_layers, + # operator_type="dhconv", + # normalization_layer="layer_norm", + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ) + # must initialize on CPU to get the same results on GPU + inp_full = torch.randn(n_samples, input_channels, *img_shape) + inp_full.requires_grad = True + # with torch.no_grad(): + out_full = model(inp_full) + loss_full = torch.sum(out_full) + + # # perform backward pass + # loss_full.backward() + # igrad_full = inp_full.grad.clone() + + # assert out_full.shape == (n_samples, output_channels, *img_shape) + # tmp_path="testdata" + # torch.save(out_full, "testdata/test_sfnonet_spatial_dist_output_is_unchanged.pt") + + # # get state dict + # state_dict_full = checkpoint_helpers.gather_model_state_dict(model, grads=False) + + + # torch.save(out_full, os.path.join(tmp_path, "out_full.pt")) + # # torch.save(igrad_full, os.path.join(tmp_path, "igrad_full.pt")) + # # if mpi_comm_rank == 0: + # _save_checkpoint_flexible(checkpoint_path=os.path.join(tmp_path, "checkpoint.pt"), + # model=model) + # # delete local model + # del model + # print("--------------------------------------------------") + + # dist.shutdown() + + # ## with domain decomposition + # dist = Distributed() + # mpi_comm_rank = dist.local_rank + + # h_parallel_size=2 + # w_parallel_size=2 + # dist._init_distributed(h_parallel_size = h_parallel_size, w_parallel_size=w_parallel_size) + # # thd.init(h_parallel_size, w_parallel_size) + + + # comm = dist.get_comm() + # w_group = comm.get_group("w") + # h_group = comm.get_group("h") + # world_rank = comm.get_world_rank() + # device=get_device() + + # model_dist = SFNO( + # params=None, + # embed_dim=embed_dim, + # num_layers=num_layers, + # # operator_type="dhconv", + # # normalization_layer="layer_norm", + # img_shape=img_shape, + # in_chans=input_channels, + # out_chans=output_channels, + # ).to(device) + + + # # save reduction hooks + # model_dist = init_gradient_reduction_hooks( + # model_dist, + # device=device, + # reduction_buffer_count=1, + # broadcast_buffers=False, + # find_unused_parameters=False, + # gradient_as_bucket_view=True, + # static_graph=True, + # verbose=True, + # ) + + # # load checkpoint + # _restore_checkpoint_flexible(checkpoint_path=os.path.join(tmp_path, "checkpoint.pt"), + # model=model_dist) + + # # split input + # inp_full_device=inp_full.to(device) + # inp_local= split_helper_conv(inp_full_device, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + # inp_local.requires_grad = True + # if world_rank == 0: + # print("inp_full", inp_full.shape) + # print("inp_local", inp_local.shape) + + # # with torch.no_grad(): + # out_local = model_dist(inp_local) + # loss_dist = reduce_from_parallel_region(torch.sum(out_local), "model") + # loss_dist.backward() + # igrad_local = inp_local.grad.clone() + + # # get weights and wgrads + # state_dict_gather_full = checkpoint_helpers.gather_model_state_dict(model_dist, grads=True) + + # # output + # # mpi_comm.Barrier() + # with torch.no_grad(): + # out_full_device=out_full.to(device) + # out_gather_full = gather_helper_conv(out_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + # err = relative_error(out_gather_full, out_full_device) + # if world_rank == 0: + # print(f"final relative error of output: {err.item()}") + # # mpi_comm.Barrier() + # assert err < 1e-3 + # # loss + # with torch.no_grad(): + # loss_full_device=loss_full.to(device) + # err = 
relative_error(loss_dist, loss_full) + # if (world_rank == 0): + # print(f"final relative error of loss: {err.item()}") + # # mpi_comm.Barrier() + # assert err < 1e-3 + # ############################################################# + # # evaluate BWD pass + # ############################################################# + # # dgrad + # with torch.no_grad(): + # igrad_full_device=igrad_full.to(device) + # igrad_gather_full = gather_helper_conv(igrad_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + # err = relative_error(igrad_gather_full, igrad_full_device) + # if (world_rank == 0): + # print(f"final relative error of input gradient: {err.item()}") + # # cleanup + # assert err < 1e-3 + # # mpi_comm.Barrier() + + # comm.cleanup() diff --git a/fme/ace/registry/sfno.py b/fme/ace/registry/sfno.py index e1c636a3e..50028a3e3 100644 --- a/fme/ace/registry/sfno.py +++ b/fme/ace/registry/sfno.py @@ -6,7 +6,7 @@ ) from fme.ace.models.modulus.sfnonet import SphericalFourierNeuralOperatorNet from fme.ace.registry.registry import ModuleConfig, ModuleSelector - +from fme.core.distributed import Distributed # this is based on the call signature of SphericalFourierNeuralOperatorNet at # https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501 @@ -46,16 +46,13 @@ def build( n_out_channels: int, img_shape: tuple[int, int], ): - sfno_net = SphericalFourierNeuralOperatorNet( + return SphericalFourierNeuralOperatorNet( params=self, in_chans=n_in_channels, out_chans=n_out_channels, img_shape=img_shape, ) - return sfno_net - - @ModuleSelector.register("SFNO-v0.1.0") @dataclasses.dataclass class SFNO_V0_1_0(ModuleConfig): diff --git a/fme/ace/test_spatial_parallelism/test_annual_sp.py b/fme/ace/test_spatial_parallelism/test_annual_sp.py new file mode 100644 index 000000000..64a6b77f8 --- /dev/null +++ b/fme/ace/test_spatial_parallelism/test_annual_sp.py @@ -0,0 +1,153 @@ +import numpy as np +import pytest +import torch +import os + +import datetime +import xarray as xr +import cftime +import matplotlib.pyplot as plt +from fme.core.device import get_device +from fme.core.mask_provider import MaskProvider +from fme.ace.aggregator.inference.annual import ( + GlobalMeanAnnualAggregator, +) +from fme.core.coordinates import ( + LatLonCoordinates, +) +from fme.core.distributed import Distributed + +TIMESTEP = datetime.timedelta(hours=6) +tmp_path="testdata" +def int(): + # Define the sizes + torch.manual_seed(42) + torch.set_printoptions(precision=12, sci_mode=False) # Adjust precision as needed + n_sample = 4 # Example batch size + n_lat = 18 # Example size for latitude + n_lon = 36 # Example size for longitude + n_time = 365 * 4 * 30 + + # Create the latitude tensor + lat = torch.linspace(-90, 90, n_lat) + + # Create the longitude tensor + lon = torch.linspace(0, 360, n_lon) + + input_tensor = torch.randn(n_sample, n_time, n_lat, n_lon) + input_tensor.is_shared_mp = ["spatial"] + # torch.save(input_tensor, os.path.join(tmp_path, "input-annual-test.pt")) + return lat, lon, n_lat, n_lon, n_sample, n_time, input_tensor + +def test_annual_aggregator_wo_sp(): + os.environ['H_PARALLEL_SIZE'] = '1' + os.environ['W_PARALLEL_SIZE'] = '1' + # need to have two actual full years of data for plotting to get exercised + lat_host, lon_host, n_lat, n_lon, n_sample, n_time, input_ = int() + # input_ = torch.load(os.path.join(tmp_path, "input-annual-test.pt")) + + device=get_device() + lat = lat_host.to(device) + lon = lon_host.to(device) + coords = 
LatLonCoordinates(lat=lat, lon=lon) + gridded_ops = coords.get_gridded_operations(mask_provider=MaskProvider()) + + agg = GlobalMeanAnnualAggregator( + ops=gridded_ops, timestep=TIMESTEP + ) + data = {"a": input_.to(device)} + + time = xr.DataArray( + [ + [ + ( + cftime.DatetimeProlepticGregorian(2000, 1, 1) + + i * datetime.timedelta(hours=6) + ) + for i in range(n_time) + ] + for _ in range(n_sample) + ], + dims=["sample", "time"], + ) + agg.record_batch(time, data) + logs = agg.get_logs(label="test") + print(logs) + assert len(logs) > 0 + assert "test/a" in logs + assert isinstance(logs["test/a"], plt.Figure) + figure=logs["test/a"] + figure.savefig("test.png") + for ax in figure.get_axes(): + # Loop through each line in the axes + for line in ax.get_lines(): + x_data = line.get_xdata() + y_data = line.get_ydata() + np.savetxt(tmp_path+"/y_data.txt", y_data) + np.savetxt(tmp_path+"/x_data.txt",x_data) + # print("X data:", x_data) + # print("Y data:", y_data) + + +def test_annual_aggregator_w_sp(): + os.environ['H_PARALLEL_SIZE'] = '2' + os.environ['W_PARALLEL_SIZE'] = '1' + # need to have two actual full years of data for plotting to get exercised + lat_host, lon_host, n_lat, n_lon, n_sample, n_time, input_ = int() + dist = Distributed.get_instance() + device=get_device() + inp_local_host = (input_[:,:,*dist.get_local_slices((n_lat,n_lon))]).detach().clone() + inp_local=inp_local_host.to(device) + + device=get_device() + lat = lat_host.to(device) + lon = lon_host.to(device) + coords = LatLonCoordinates(lat=lat, lon=lon) + gridded_ops = coords.get_gridded_operations(mask_provider=MaskProvider()) + + agg = GlobalMeanAnnualAggregator( + ops=gridded_ops, timestep=TIMESTEP + ) + data = {"a": inp_local} + + time = xr.DataArray( + [ + [ + ( + cftime.DatetimeProlepticGregorian(2000, 1, 1) + + i * datetime.timedelta(hours=6) + ) + for i in range(n_time) + ] + for _ in range(n_sample) + ], + dims=["sample", "time"], + ) + agg.record_batch(time, data) + logs = agg.get_logs(label="test") + y_data_ref = np.loadtxt(tmp_path+"/y_data.txt") + x_data_ref = np.loadtxt(tmp_path+"/x_data.txt") + print(logs) + + if len(logs) > 0: + # assert len(logs) > 0 + assert "test/a" in logs + assert isinstance(logs["test/a"], plt.Figure) + figure =logs["test/a"] + figure.savefig("test-sp.png") + # for ax in figure.get_axes(): + # # Loop through each line in the axes + # for line in ax.get_lines(): + # x_data = line.get_xdata() + # y_data = line.get_ydata() + # np.testing.assert_allclose(y_data_ref, y_data, rtol=5e-05, atol=1e-10, equal_nan=False, err_msg='', verbose=True) + # np.testing.assert_allclose(x_data_ref, x_data, rtol=1e-08, atol=1e-13, equal_nan=False, err_msg='', verbose=True) +def test_annual_aggregator_w_sp2(): + os.environ['H_PARALLEL_SIZE'] = '2' + os.environ['W_PARALLEL_SIZE'] = '1' + # need to have two actual full years of data for plotting to get exercised + lat_host, lon_host, n_lat, n_lon, n_sample, n_time, input_ = int() + dist = Distributed.get_instance() + device=get_device() + inp_local_host = (input_[:,:,*dist.get_local_slices((n_lat,n_lon))]).detach().clone() + inp_local=inp_local_host.to(device) diff --git a/fme/ace/test_spatial_parallelism/test_coordinates_sp.py b/fme/ace/test_spatial_parallelism/test_coordinates_sp.py new file mode 100644 index 000000000..c1e285476 --- /dev/null +++ b/fme/ace/test_spatial_parallelism/test_coordinates_sp.py @@ -0,0 +1,58 @@ +import numpy as np +import pytest +import torch +import os +from fme.core.distributed import Distributed +from fme.core.coordinates 
import ( + LatLonCoordinates, +) + +from fme.core.mask_provider import MaskProvider +from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations + + +tmp_path="testdata" +def int(): + # Define the sizes + torch.manual_seed(42) + torch.set_printoptions(precision=12, sci_mode=False) # Adjust precision as needed + batch_size = 4 # Example batch size + nlat = 180 # Example size for latitude + nlon = 360 # Example size for longitude + + # Create the latitude tensor + lat = torch.linspace(-90, 90, nlat) + + # Create the longitude tensor + lon = torch.linspace(0, 360, nlon) + + input_tensor = torch.rand(batch_size, nlat, nlon) + return lat, lon, nlat, nlon, batch_size, input_tensor + +def test_lat_lon_ops_from_coords_wo_sp(): + os.environ['H_PARALLEL_SIZE'] = '1' + lat, lon, nlat, nlon, batch_size, input_ = int() + coords = LatLonCoordinates(lat=lat, lon=lon) + gridded_ops = coords.get_gridded_operations(mask_provider=MaskProvider()) + result = gridded_ops.area_weighted_mean(input_, name="T_0") + print(result) + +def test_lat_lon_ops_from_coords_w_sp(): + lat_host, lon_host, nlat, nlon, batch_size, input_= int() + os.environ['H_PARALLEL_SIZE'] = '2' + os.environ['W_PARALLEL_SIZE'] = '2' + dist = Distributed.get_instance() + device=get_device() + lat = lat_host.to(device) + lon = lon_host.to(device) + coords = LatLonCoordinates(lat=lat, lon=lon) + gridded_ops = coords.get_gridded_operations(mask_provider=MaskProvider()) + inp_local_host = (input_[:,*dist.get_local_slices((nlat,nlon))]).detach().clone() + inp_local=inp_local_host.to(device) + result_local = gridded_ops.area_weighted_mean(inp_local, name="T_0") + print("result_local",result_local) + print("dist._distributed", dist._distributed) + result = dist.reduce_mean(result_local) + print("result", result) + torch.testing.assert_close(result.to("cpu"), torch.tensor([0.501348972321, 0.500475645065, 0.500276744366, 0.497519612312])) diff --git a/fme/ace/test_spatial_parallelism/test_distributed_fft2_and_ifft2.py b/fme/ace/test_spatial_parallelism/test_distributed_fft2_and_ifft2.py new file mode 100644 index 000000000..76027b1dc --- /dev/null +++ b/fme/ace/test_spatial_parallelism/test_distributed_fft2_and_ifft2.py @@ -0,0 +1,209 @@ +import os +import torch + + +from fme.ace.models.modulus.layers import RealFFT2, InverseRealFFT2 + +from fme.ace.models.makani_mpu.fft import DistributedRealFFT2, DistributedInverseRealFFT2 +from fme.ace.utils import comm +import torch_harmonics as th +import torch_harmonics.distributed as thd +from test_helper import relative_error, _split_helper, _gather_helper + +def setup_test(): + from mpi4py import MPI + mpi_comm = MPI.COMM_WORLD.Dup() + mpi_comm_rank = mpi_comm.Get_rank() + mpi_comm_size = mpi_comm.Get_size() + if torch.cuda.is_available(): + if mpi_comm_rank == 0: + print("Running test on GPU") + local_rank = mpi_comm_rank % torch.cuda.device_count() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.cuda.manual_seed(333) + else: + if mpi_comm_rank == 0: + print("Running test on CPU") + device = torch.device("cpu") + torch.manual_seed(333) + return mpi_comm, device + +def _init_comms(): + # set up distributed + os.environ['GRID_H'] = '2' + os.environ['GRID_W'] = '2' + grid_size_h = int(os.getenv("GRID_H", 1)) + grid_size_w = int(os.getenv("GRID_W", 1)) + grid_size_e = int(os.getenv("GRID_E", 1)) + world_size = grid_size_h * grid_size_w * grid_size_e + + # init groups + comm.init( + model_parallel_sizes=[grid_size_h, grid_size_w, 1, 1], + 
model_parallel_names=["h", "w", "fin", "fout"], + data_parallel_sizes=[grid_size_e, -1], + data_parallel_names=["ensemble", "batch"], + ) + world_rank = comm.get_world_rank() + + # store comm group parameters + wrank = comm.get_rank("w") + hrank = comm.get_rank("h") + erank = comm.get_rank("ensemble") + w_group = comm.get_group("w") + h_group = comm.get_group("h") + e_group = comm.get_group("ensemble") + # initializing sht process groups just to be sure + thd.init(h_group, w_group) + + if world_rank == 0: + print(f"Running distributed tests on grid H x W x E = {grid_size_h} x {grid_size_w} x {grid_size_e}") + + return w_group, h_group, e_group, world_rank, world_size + +def test_distributed_fft2(): + verbose=True + mpi_comm, device = setup_test() + w_group, h_group, e_group, world_rank, world_size = _init_comms() + + # 256, 512, 0, 32, 8, 1e-6 + # nlat, nlon, nalt, batch_size, num_chan, tol, + tol=1e-6 + B, C, H, W = 32, 8, 256, 512 + + # set up handles + forward_transform_local = RealFFT2(nlat=H, nlon=W).to(device) + forward_transform_dist = DistributedRealFFT2(nlat=H, nlon=W).to(device) + + # create tensors + inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device) + + ############################################################# + # local transform + ############################################################# + # FWD pass + inp_full.requires_grad = True + out_full = forward_transform_local(inp_full) + + # create grad for backward + with torch.no_grad(): + # create full grad + ograd_full = torch.randn_like(out_full) + + # BWD pass + out_full.backward(ograd_full) + igrad_full = inp_full.grad.clone() + + ############################################################# + # distributed transform + ############################################################# + # FWD pass + inp_local = _split_helper(inp_full, w_group, h_group) + inp_local.requires_grad = True + out_local = forward_transform_dist(inp_local) + + # BWD pass + ograd_local = _split_helper(ograd_full, w_group, h_group) + out_local = forward_transform_dist(inp_local) + out_local.backward(ograd_local) + igrad_local = inp_local.grad.clone() + + # set eval dims + dims = (-1,-2,-3) + + ############################################################# + # evaluate FWD pass + ############################################################# + with torch.no_grad(): + out_gather_full = _gather_helper(out_local, w_group, h_group) + err = relative_error(out_gather_full, out_full) + if verbose and (world_rank == 0): + print(f"final relative error of output: {err.item()}") + assert err.item() <= tol + + ############################################################# + # evaluate BWD pass + ############################################################# + with torch.no_grad(): + igrad_gather_full = _gather_helper(igrad_local, w_group, h_group) + err = relative_error(igrad_gather_full, igrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of gradients: {err.item()}") + assert err.item() <= tol + + +def test_distributed_ifft2(): + verbose=True + mpi_comm, device = setup_test() + w_group, h_group, e_group, world_rank, world_size = _init_comms() + # 256, 512, 0, 32, 8, 1e-6 + # nlat, nlon, nalt, batch_size, num_chan, tol, + tol=1e-6 + B, C, H, W = 32, 8, 256, 512 + forward_transform_local = RealFFT2(nlat=H, nlon=W).to(device) + backward_transform_local = InverseRealFFT2(nlat=H, nlon=W).to(device) + backward_transform_dist = DistributedInverseRealFFT2(nlat=H, nlon=W).to(device) + + # create tensors + dummy_full = 
torch.randn((B, C, H, W), dtype=torch.float32, device=device) + inp_full = forward_transform_local(dummy_full) + + ############################################################# + # local transform + ############################################################# + # FWD pass + inp_full.requires_grad = True + out_full = backward_transform_local(inp_full) + + # create grad for backward + with torch.no_grad(): + # create full grad + ograd_full = torch.randn_like(out_full) + + # BWD pass + out_full.backward(ograd_full) + + # repeat once due to known irfft bug + inp_full.grad = None + out_full = backward_transform_local(inp_full) + out_full.backward(ograd_full) + igrad_full = inp_full.grad.clone() + + ############################################################# + # distributed transform + ############################################################# + # FWD pass + inp_local = _split_helper(inp_full, w_group, h_group) + inp_local.requires_grad = True + out_local = backward_transform_dist(inp_local) + + # BWD pass + ograd_local = _split_helper(ograd_full, w_group, h_group) + out_local = backward_transform_dist(inp_local) + out_local.backward(ograd_local) + igrad_local = inp_local.grad.clone() + + # set eval dims + dims = (-1,-2,-3) + + ############################################################# + # evaluate FWD pass + ############################################################# + with torch.no_grad(): + out_gather_full = _gather_helper(out_local, w_group, h_group) + err = relative_error(out_gather_full, out_full) + if verbose and (world_rank == 0): + print(f"final relative error of output: {err.item()}") + assert err.item() <= tol + + ############################################################# + # evaluate BWD pass + ############################################################# + with torch.no_grad(): + igrad_gather_full = _gather_helper(igrad_local, w_group, h_group) + err = relative_error(igrad_gather_full, igrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of gradients: {err.item()}") + assert err.item() <= tol + comm.cleanup() diff --git a/fme/ace/test_spatial_parallelism/test_distributed_spectral_conv.py b/fme/ace/test_spatial_parallelism/test_distributed_spectral_conv.py new file mode 100644 index 000000000..9d37d769a --- /dev/null +++ b/fme/ace/test_spatial_parallelism/test_distributed_spectral_conv.py @@ -0,0 +1,220 @@ +import os + +import torch +from fme.core.distributed import Distributed +# import torch.distributed as dist +from fme.core.device import get_device +from fme.core.testing import validate_tensor + +from fme.ace.models.modulus.layers import MLP, DropPath, RealFFT2, SpectralAttention2d, InverseRealFFT2 +from fme.ace.models.modulus.s2convolutions import SpectralAttentionS2, SpectralConvS2 + +from fme.ace.models.makani_mpu.fft import DistributedRealFFT1, DistributedInverseRealFFT1, DistributedRealFFT2, DistributedInverseRealFFT2, DistributedRealFFT3, DistributedInverseRealFFT3 +from fme.ace.models.makani_mpu.mappings import init_gradient_reduction_hooks + +from fme.ace.utils import comm +import torch_harmonics as th +import torch_harmonics.distributed as thd +from physicsnemo.distributed.utils import split_tensor_along_dim +from test_helper import gather_helper_conv, split_helper_conv, relative_error, _split_helper, _gather_helper, init_seed + +DIR = os.path.abspath(os.path.dirname(__file__)) + +def setup_test(): + from mpi4py import MPI + mpi_comm = MPI.COMM_WORLD.Dup() + mpi_comm_rank = mpi_comm.Get_rank() + mpi_comm_size = mpi_comm.Get_size() 
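+ # Map each MPI rank to a GPU on its node (rank % device_count) so that ranks sharing a node use distinct devices, and fall back to CPU when CUDA is unavailable.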
+ if torch.cuda.is_available(): + if mpi_comm_rank == 0: + print("Running test on GPU") + local_rank = mpi_comm_rank % torch.cuda.device_count() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.cuda.manual_seed(333) + else: + if mpi_comm_rank == 0: + print("Running test on CPU") + device = torch.device("cpu") + torch.manual_seed(333) + return mpi_comm, device
+ +def _init_comms(): + # set up distributed + os.environ['GRID_H'] = '2' + os.environ['GRID_W'] = '2' + os.environ['H_PARALLEL_SIZE'] = '2' + os.environ['W_PARALLEL_SIZE'] = '2' + grid_size_h = int(os.getenv("GRID_H", 1)) + grid_size_w = int(os.getenv("GRID_W", 1)) + grid_size_e = int(os.getenv("GRID_E", 1)) + world_size = grid_size_h * grid_size_w * grid_size_e + + # init groups + dist = Distributed.get_instance() + world_rank = dist.rank + + # store comm group parameters + wrank = dist.comm_get_rank("w") + hrank = dist.comm_get_rank("h") + erank = dist.comm_get_rank("ensemble") + w_group = dist.comm_get_group("w") + h_group = dist.comm_get_group("h") + e_group = dist.comm_get_group("ensemble") + + if world_rank == 0: + print(f"Running distributed tests on grid H x W x E = {grid_size_h} x {grid_size_w} x {grid_size_e}") + + return w_group, h_group, e_group, world_rank, world_size
+ +def test_distributed_spectral_conv(): + tol = 1e-6 + verbose = True + device = get_device() + w_group, h_group, e_group, world_rank, world_size = _init_comms() + # set up handles + B, C, Hi, Wi, Ho, Wo = 32, 8, 256, 512, 256, 512 + print("world_rank", world_rank) + print("world_size", world_size) + + # input + init_seed(444) + inp_full = torch.randn((B, C, Hi, Wi), dtype=torch.float32, device=device) + + init_seed(333) + + ## without domain decomposition + with Distributed.force_non_distributed(): + forward_transform_local = th.RealSHT(nlat=Hi, nlon=Wi).to(device) + inverse_transform_local = th.InverseRealSHT(nlat=Ho, nlon=Wo, lmax=forward_transform_local.lmax, mmax=forward_transform_local.mmax).to(device) + + spect_conv_local = SpectralConvS2( + forward_transform_local, + inverse_transform_local, + C, + C, + operator_type="dhconv", + use_tensorly=False, + bias=True + ).to(device)
+ + ############################################################# + # local transform + ############################################################# + # FWD pass + inp_full.requires_grad = True + out_full, _ = spect_conv_local(inp_full) + # create grad for backward + init_seed(555) + with torch.no_grad(): + # create full grad + ograd_full = torch.randn_like(out_full) + + # BWD pass + out_full.backward(ograd_full) + igrad_full = inp_full.grad.clone() + wgrad_full = spect_conv_local.weight.grad.clone() + bgrad_full = spect_conv_local.bias.grad.clone() + + forward_transform_dist = thd.DistributedRealSHT(nlat=Hi, nlon=Wi).to(device) + inverse_transform_dist = thd.DistributedInverseRealSHT(nlat=Ho, nlon=Wo, lmax=forward_transform_dist.lmax, mmax=forward_transform_dist.mmax).to(device) + + spect_conv_dist = SpectralConvS2( + forward_transform_dist, + inverse_transform_dist, + C, + C, + operator_type="dhconv", + use_tensorly=False, + bias=True + ).to(device) + # set up wgrad reductions + spect_conv_dist = init_gradient_reduction_hooks( + spect_conv_dist, + device=device, + reduction_buffer_count=1, + broadcast_buffers=False, + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + verbose=False, + ) + # make sure weights are the same: + with
torch.no_grad(): + weight = split_helper_conv(spect_conv_local.weight, hdim=-2, wdim=None, w_group=w_group, h_group=h_group) + print("spect_conv_local.weight", spect_conv_local.weight.shape) + print("weight", weight.shape) + print("spect_conv_dist.module.weight", spect_conv_dist.module.weight.shape) + spect_conv_dist.module.weight.copy_(weight) + spect_conv_dist.module.bias.copy_(spect_conv_local.bias) + + ############################################################# + # distributed transform + ############################################################# + # FWD pass + inp_local = split_helper_conv(inp_full, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + print("inp_local", inp_local.shape) + print("inp_full", inp_full.shape) + inp_local.requires_grad = True + out_local, _ = spect_conv_dist(inp_local) + + # BWD pass + ograd_local = split_helper_conv(ograd_full, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + print("ograd_local", ograd_local.shape) + print("ograd_full", ograd_full.shape) + out_local, _ = spect_conv_dist(inp_local) + out_local.backward(ograd_local) + igrad_local = inp_local.grad.clone() + wgrad_local = spect_conv_dist.module.weight.grad.clone() + bgrad_local = spect_conv_dist.module.bias.grad.clone() + dist = Distributed.get_instance() + dist.barrier()
+ ############################################################# + # evaluate FWD pass + ############################################################# + with torch.no_grad(): + out_gather_full = gather_helper_conv(out_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + err = relative_error(out_gather_full, out_full) + if verbose and (world_rank == 0): + print(f"final relative error of output: {err.item()}") + assert err.item() <= tol + dist.barrier() + ############################################################# + # evaluate input grads + ############################################################# + with torch.no_grad(): + igrad_gather_full = gather_helper_conv(igrad_local, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group) + err = relative_error(igrad_gather_full, igrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of input gradients: {err.item()}") + assert err.item() <= tol + dist.barrier() + ############################################################# + # evaluate weight grads + ############################################################# + with torch.no_grad(): + wgrad_gather_full = gather_helper_conv(wgrad_local, hdim=-2, wdim=None, w_group=w_group, h_group=h_group) + print("wgrad_local", wgrad_local.shape) + print("wgrad_gather_full", wgrad_gather_full.shape) + err = relative_error(wgrad_gather_full, wgrad_full) + if verbose and (world_rank == 0): + print(f"final relative error of weight gradients: {err.item()}") + assert err.item() <= tol + dist.barrier()
+ + with torch.no_grad(): + # The bias is replicated on every rank, so all-gather the local copies and check each one against the serial reference. + bgrad_gather_list = [torch.empty_like(bgrad_local) for _ in range(world_size)] + bgrad_gather_list[world_rank] = bgrad_local + dist.all_gather(bgrad_gather_list, bgrad_local, group=None) + errs = [] + for bgrad_gather_full in bgrad_gather_list: + errs.append(relative_error(bgrad_gather_full, bgrad_full)) + err = torch.mean(torch.stack(errs, dim=0)) + if verbose and (world_rank == 0): + print(f"final relative error of bias gradients: {err.item()}") + assert err.item() <= tol + dist.shutdown() diff --git a/fme/ace/test_spatial_parallelism/test_helper.py b/fme/ace/test_spatial_parallelism/test_helper.py new
file mode 100644 index 000000000..069924f37 --- /dev/null +++ b/fme/ace/test_spatial_parallelism/test_helper.py @@ -0,0 +1,95 @@ +import os + +import torch +import torch.distributed as dist +from physicsnemo.distributed.utils import split_tensor_along_dim + +from pathlib import Path + +def create_directory(directory_name): + """ + Create a directory if it does not already exist. + + Parameters: + directory_name (str): The name of the directory to create. + + Returns: + None + """ + try: + # Using pathlib to create the directory + Path(directory_name).mkdir(parents=True, exist_ok=True) + print(f"Directory '{directory_name}' created successfully or already exists.") + except Exception as e: + print(f"An error occurred while creating the directory: {e}")
+ +# this computes a relative error compatible with torch.allclose or np.allclose +def relative_error(tensor1, tensor2): + return torch.sum(torch.abs(tensor1-tensor2)) / torch.sum(torch.abs(tensor2)) + +# this computes an absolute error compatible with torch.allclose or np.allclose +def absolute_error(tensor1, tensor2): + return torch.max(torch.abs(tensor1-tensor2))
+ +def gather_helper(tensor, dim=None, group=None): + # get shapes + if (dim is not None) and (dist.get_world_size(group=group) > 1): + gsize = dist.get_world_size(group=group) + grank = dist.get_rank(group=group) + shape_loc = torch.tensor([tensor.shape[dim]], dtype=torch.long, device=tensor.device) + shape_list = [torch.empty_like(shape_loc) for _ in range(dist.get_world_size(group=group))] + shape_list[grank] = shape_loc + dist.all_gather(shape_list, shape_loc, group=group) + tshapes = [] + for ids in range(gsize): + tshape = list(tensor.shape) + tshape[dim] = shape_list[ids].item() + tshapes.append(tuple(tshape)) + tens_gather = [torch.empty(tshapes[ids], dtype=tensor.dtype, device=tensor.device) for ids in range(gsize)] + tens_gather[grank] = tensor + dist.all_gather(tens_gather, tensor, group=group) + tensor_gather = torch.cat(tens_gather, dim=dim) + else: + tensor_gather = tensor.clone() + + return tensor_gather
+ +def split_helper(tensor, dim=None, group=None): + with torch.no_grad(): + if (dim is not None) and (dist.get_world_size(group=group) > 1): + gsize = dist.get_world_size(group=group) + grank = dist.get_rank(group=group) + # split in dim + tensor_list_local = split_tensor_along_dim(tensor, dim=dim, num_chunks=gsize) + tensor_local = tensor_list_local[grank] + else: + tensor_local = tensor.clone() + + return tensor_local
+ +def gather_helper_conv(tensor, hdim=-2, wdim=-1, w_group=1, h_group=1): + tensor_gather = gather_helper(tensor, dim=hdim, group=h_group) + tensor_gather = gather_helper(tensor_gather, dim=wdim, group=w_group) + return tensor_gather + +def split_helper_conv(tensor, hdim=-2, wdim=-1, w_group=1, h_group=1): + tensor_local = split_helper(tensor, dim=hdim, group=h_group) + tensor_local = split_helper(tensor_local, dim=wdim, group=w_group) + return tensor_local + +def _split_helper(tensor, w_group, h_group): + tensor_local = split_helper(tensor, dim=-1, group=w_group) + tensor_local = split_helper(tensor_local, dim=-2, group=h_group) + return tensor_local + +def _gather_helper(tensor, w_group, h_group): + tensor_gather = gather_helper(tensor, dim=-2, group=h_group) + tensor_gather = gather_helper(tensor_gather, dim=-1, group=w_group) + + return tensor_gather + +def init_seed(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + return diff --git a/fme/ace/test_spatial_parallelism/test_loss_sp.py
b/fme/ace/test_spatial_parallelism/test_loss_sp.py new file mode 100644 index 000000000..558710c88 --- /dev/null +++ b/fme/ace/test_spatial_parallelism/test_loss_sp.py @@ -0,0 +1,119 @@ +import pytest +import torch +import os +import numpy as np +from fme.core import metrics +from fme.core.device import get_device +from fme.core.gridded_ops import GriddedOperations, LatLonOperations +from fme.core.loss import ( + AreaWeightedMSELoss, + CRPSLoss, + EnergyScoreLoss, + GlobalMeanLoss, + LossConfig, + StepLossConfig, + VariableWeightingLoss, + WeightedMappingLoss, + _construct_weight_tensor, +) +from fme.core.normalizer import StandardNormalizer +from fme.core.packer import Packer +from fme.ace.aggregator.one_step.reduced import MeanAggregator +from fme.core.distributed import Distributed
+ +@pytest.mark.parametrize("global_mean_type", [None]) +def test_loss_builds_and_runs_wo_sp(global_mean_type): + nx = 8 + ny = 8 + torch.manual_seed(0) + data_tensor = torch.randn(1, 2, nx, ny, device=get_device()) + example_data = { + "a": data_tensor, + } + area_weights = torch.ones(nx, ny).to(get_device()) * 5 + + config = LossConfig(global_mean_type=global_mean_type) + loss = config.build( + reduction="mean", + gridded_operations=LatLonOperations(area_weights), + ) + + x = torch.randn(1, 2, nx, ny, device=get_device()) + y = torch.randn(1, 2, nx, ny, device=get_device()) + + result = loss(x, y) + + aggregator = MeanAggregator(LatLonOperations(area_weights)) + aggregator.record_batch( + loss=result, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + # Save the inputs and the serial loss so the spatially parallel test below can compare against them. + tmp_path = "testdata-loss" + os.makedirs(tmp_path, exist_ok=True) + torch.save(area_weights, os.path.join(tmp_path, "area_weights.pt")) + torch.save(data_tensor, os.path.join(tmp_path, "example_data.pt")) + torch.save(x, os.path.join(tmp_path, "x.pt")) + torch.save(y, os.path.join(tmp_path, "y.pt")) + print("loss", logs["metrics/loss"]) + torch.save(logs["metrics/loss"], os.path.join(tmp_path, "loss.pt"))
+ +@pytest.mark.parametrize("global_mean_type", [None]) +def test_loss_builds_and_runs_with_sp(global_mean_type): + os.environ['H_PARALLEL_SIZE'] = '2' + os.environ['W_PARALLEL_SIZE'] = '2' + nx = 8 + ny = 8 + tmp_path = "testdata-loss" + tensor_data_host = torch.load(os.path.join(tmp_path, "example_data.pt")) + x_host = torch.load(os.path.join(tmp_path, "x.pt")) + y_host = torch.load(os.path.join(tmp_path, "y.pt")) + loss_serial = torch.load(os.path.join(tmp_path, "loss.pt")) + + torch.manual_seed(0) + + area_weights = torch.ones(nx, ny) * 5.0 + aggregator = MeanAggregator(LatLonOperations(area_weights)) + dist = Distributed.get_instance() + this_shape = (tensor_data_host.shape[-2], tensor_data_host.shape[-1]) + tensor_data_local_host = (tensor_data_host[:,:,*dist.get_local_slices(this_shape)]).detach().clone() + tensor_data_local = tensor_data_local_host.to(dist.local_rank) + example_data = { + "a": tensor_data_local + } + + config = LossConfig(global_mean_type=global_mean_type) + loss = config.build( + reduction="mean", + gridded_operations=LatLonOperations(area_weights), + ) + + this_shape_x = (x_host.shape[-2], x_host.shape[-1]) + x_local_host = (x_host[:,:,*dist.get_local_slices(this_shape_x)]).detach().clone() + x_local = x_local_host.to(dist.local_rank) + y_local_host = (y_host[:,:,*dist.get_local_slices(this_shape_x)]).detach().clone() + y_local = y_local_host.to(dist.local_rank) + 
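+ # Each rank evaluates the loss on its local subdomain only; the aggregator is then expected to reduce the recorded value across the spatial process group, so the logged metric should match the serial loss saved by the wo_sp test above.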
result_local = loss(x_local, y_local) + + aggregator = MeanAggregator(LatLonOperations(area_weights)) + aggregator.record_batch( + loss=result_local, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + + error_tol = 1e-13 + logs = aggregator.get_logs(label="metrics") + rel_diff = np.abs(loss_serial - logs["metrics/loss"]) / loss_serial + assert rel_diff < error_tol diff --git a/fme/ace/test_spatial_parallelism/test_reduced_sp.py b/fme/ace/test_spatial_parallelism/test_reduced_sp.py new file mode 100644 index 000000000..df680d3ba --- /dev/null +++ b/fme/ace/test_spatial_parallelism/test_reduced_sp.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest +import torch +import os +from fme.ace.aggregator.one_step.reduced import MeanAggregator +from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations + +from fme.core.distributed import Distributed
+ + +def test_loss_wo_sp(): + """ + Basic test that the aggregator combines loss correctly + with multiple batches and no distributed training. + """ + nx = 8 + ny = 8 + torch.manual_seed(0) + example_data = { + "a": torch.randn(1, 2, nx, ny, device=get_device()), + } + area_weights = torch.ones(nx, ny).to(get_device()) + aggregator = MeanAggregator(LatLonOperations(area_weights)) + aggregator.record_batch( + loss=1.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + aggregator.record_batch( + loss=2.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + print("loss", logs["metrics/loss"]) + assert logs["metrics/loss"] == 1.5 + aggregator.record_batch( + loss=3.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + print("loss", logs["metrics/loss"]) + assert logs["metrics/loss"] == 2.0
+ +def test_loss_with_sp(): + os.environ['H_PARALLEL_SIZE'] = '2' + os.environ['W_PARALLEL_SIZE'] = '2' + nx = 8 + ny = 8 + torch.manual_seed(0) + tensor_data_host = torch.randn(1, 2, nx, ny) + area_weights = torch.ones(nx, ny) + aggregator = MeanAggregator(LatLonOperations(area_weights)) + dist = Distributed.get_instance() + this_shape = (tensor_data_host.shape[-2], tensor_data_host.shape[-1]) + tensor_data_local_host = (tensor_data_host[:,:,*dist.get_local_slices(this_shape)]).detach().clone() + tensor_data_local = tensor_data_local_host.to(dist.local_rank) + + example_data = { + "a": tensor_data_local + } + + aggregator.record_batch( + loss=1.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + aggregator.record_batch( + loss=2.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + print("loss", logs["metrics/loss"]) + assert logs["metrics/loss"] == 1.5 + aggregator.record_batch( + loss=3.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + print("loss", logs["metrics/loss"]) + assert logs["metrics/loss"] == 2.0 diff --git a/fme/ace/test_train_sp.py b/fme/ace/test_train_sp.py new file mode 100755 index
000000000..d033da3e9 --- /dev/null +++ b/fme/ace/test_train_sp.py @@ -0,0 +1,544 @@ +import copy +import dataclasses +import pathlib +import subprocess +import tempfile +import unittest.mock +from typing import Literal +from pathlib import Path +import dacite +import numpy as np +import pytest +import torch +import xarray as xr +import yaml +import fme +from fme.ace.aggregator.inference.main import InferenceEvaluatorAggregatorConfig +from fme.ace.aggregator.one_step.main import OneStepAggregatorConfig +from fme.ace.data_loading.config import DataLoaderConfig +from fme.ace.data_loading.inference import ( + InferenceDataLoaderConfig, + InferenceInitialConditionIndices, +) +from fme.ace.inference.data_writer.file_writer import FileWriterConfig +from fme.ace.inference.data_writer.main import DataWriterConfig +from fme.ace.inference.evaluator import InferenceEvaluatorConfig +from fme.ace.inference.evaluator import main as inference_evaluator_main +from fme.ace.registry.test_hpx import ( + conv_next_block_config, + decoder_config, + down_sampling_block_config, + encoder_config, + output_layer_config, + recurrent_block_config, + up_sampling_block_config, +) +from fme.ace.stepper.single_module import StepperConfig +from fme.ace.stepper.time_length_probabilities import ( + TimeLengthProbabilities, + TimeLengthProbability, +) +from fme.ace.testing import ( + DimSizes, + MonthlyReferenceData, + save_nd_netcdf, + save_scalar_netcdf, +) +from fme.ace.train.train import build_trainer, prepare_directory +from fme.ace.train.train import main as train_main +from fme.ace.train.train_config import ( + InlineInferenceConfig, + TrainBuilders, + TrainConfig, + WeatherEvaluationConfig, +) +from fme.core.coordinates import ( + HEALPixCoordinates, + HorizontalCoordinates, + LatLonCoordinates, +) +from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig +from fme.core.dataset.xarray import XarrayDataConfig +from fme.core.generics.trainer import ( + _restore_checkpoint, + count_parameters, + epoch_checkpoint_enabled, +) +from fme.core.logging_utils import LoggingConfig +from fme.core.loss import StepLossConfig +from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig +from fme.core.ocean import OceanConfig +from fme.core.optimization import OptimizationConfig +from fme.core.registry.corrector import CorrectorSelector +from fme.core.registry.module import ModuleSelector +from fme.core.scheduler import SchedulerConfig +from fme.core.step.single_module import SingleModuleStepConfig +from fme.core.step.step import StepSelector +from fme.core.testing.model import compare_restored_parameters +from fme.core.testing.wandb import mock_wandb +from fme.core.typing_ import Slice +from fme.core.distributed import Distributed +JOB_SUBMISSION_SCRIPT_PATH = ( + pathlib.PurePath(__file__).parent / "run-train-and-inference.sh" +) + +def _get_test_yaml_files( + *, + train_data_path, + valid_data_path, + monthly_data_filename: pathlib.Path | None, + results_dir, + global_means_path, + global_stds_path, + in_variable_names, + out_variable_names, + mask_name, + n_forward_steps=2, + nettype="SphericalFourierNeuralOperatorNet", + log_to_wandb=False, + max_epochs=1, + segment_epochs=1, + inference_forward_steps=2, + use_healpix=False, + crps_training=False, + save_per_epoch_diagnostics=False, + 
log_validation_maps=False, + skip_inline_inference=False, + time_buffer=1, +): + input_time_size = 1 + output_time_size = 1 + if nettype == "HEALPixRecUNet": + in_channels = len(in_variable_names) + out_channels = len(out_variable_names) + prognostic_variables = min( + out_channels, in_channels + ) # how many variables in/out share. + # in practice, we will need to compare variable names, since there + # are some input-only and some output-only channels. + # TODO: https://github.com/ai2cm/full-model/issues/1046 + n_constants = 0 + decoder_input_channels = 0 # was 1, to indicate insolation - now 0 + input_time_size = 1 # TODO: change to 2 (issue #1177) + output_time_size = 1 # TODO: change to 4 (issue #1177) + + conv_next_block = conv_next_block_config(in_channels=in_channels) + down_sampling_block = down_sampling_block_config() + recurrent_block = recurrent_block_config() + encoder = encoder_config( + conv_next_block, down_sampling_block, n_channels=[16, 8, 4] + ) + up_sampling_block = up_sampling_block_config() + output_layer = output_layer_config() + decoder = decoder_config( + conv_next_block, + up_sampling_block, + output_layer, + recurrent_block, + n_channels=[4, 8, 16], + ) + net_config = dict( + encoder=encoder, + decoder=decoder, + prognostic_variables=prognostic_variables, + n_constants=n_constants, + decoder_input_channels=decoder_input_channels, + input_time_size=input_time_size, + output_time_size=output_time_size, + ) + spatial_dimensions_str: Literal["healpix", "latlon"] = "healpix" + elif nettype == "Samudra": + net_config = dict( + ch_width=[8, 16], + dilation=[2, 4], + n_layers=[1, 1], + ) + spatial_dimensions_str = "latlon" + elif nettype == "SphericalFourierNeuralOperatorNet": + net_config = dict( + num_layers=2, + embed_dim=12, + ) + spatial_dimensions_str = "latlon" + elif nettype == "NoiseConditionedSFNO": + net_config = dict( + num_layers=2, + embed_dim=12, + ) + if use_healpix: + net_config["data_grid"] = "healpix" + spatial_dimensions_str = "healpix" + else: + spatial_dimensions_str = "latlon" + + if nettype == "SphericalFourierNeuralOperatorNet": + corrector_config: AtmosphereCorrectorConfig | CorrectorSelector = ( + CorrectorSelector( + type="atmosphere_corrector", + config=dataclasses.asdict( + AtmosphereCorrectorConfig(conserve_dry_air=True) + ), + ) + ) + else: + corrector_config = AtmosphereCorrectorConfig() + + logging_config = LoggingConfig( + log_to_screen=True, + log_to_wandb=log_to_wandb, + log_to_file=True, + project="fme", + entity="ai2cm", + ) + if skip_inline_inference: + inline_inference_config = None + weather_evaluation_config = None + else: + inline_inference_config = InlineInferenceConfig( + aggregator=InferenceEvaluatorAggregatorConfig( + monthly_reference_data=( + str(monthly_data_filename) + if monthly_data_filename is not None + else None + ), + ), + loader=InferenceDataLoaderConfig( + dataset=XarrayDataConfig( + data_path=str(valid_data_path), + spatial_dimensions=spatial_dimensions_str, + ), + start_indices=InferenceInitialConditionIndices( + first=0, + n_initial_conditions=4, + interval=1, + ), + ), + n_forward_steps=inference_forward_steps, + forward_steps_in_memory=2, + ) + weather_evaluation_config = WeatherEvaluationConfig( + aggregator=InferenceEvaluatorAggregatorConfig( + monthly_reference_data=( + str(monthly_data_filename) + if monthly_data_filename is not None + else None + ), + ), + loader=InferenceDataLoaderConfig( + dataset=XarrayDataConfig( + data_path=str(valid_data_path), + spatial_dimensions=spatial_dimensions_str, + ), + 
start_indices=InferenceInitialConditionIndices( + first=0, + n_initial_conditions=4, + interval=1, + ), + ), + n_forward_steps=inference_forward_steps, + forward_steps_in_memory=2, + ) + + train_config = TrainConfig( + train_loader=DataLoaderConfig( + dataset=XarrayDataConfig( + data_path=str(train_data_path), + spatial_dimensions=spatial_dimensions_str, + ), + batch_size=4, + num_data_workers=0, + time_buffer=time_buffer, + sample_with_replacement=10, + ), + validation_loader=DataLoaderConfig( + dataset=XarrayDataConfig( + data_path=str(valid_data_path), + spatial_dimensions=spatial_dimensions_str, + ), + batch_size=4, + num_data_workers=0, + ), + optimization=OptimizationConfig( + use_gradient_accumulation=True, + optimizer_type="Adam", + lr=0.001, + kwargs=dict(weight_decay=0.01), + scheduler=SchedulerConfig( + type="CosineAnnealingLR", + kwargs=dict(T_max=1), + ), + ), + stepper=StepperConfig( + loss=StepLossConfig(type="MSE"), + crps_training=crps_training, + train_n_forward_steps=TimeLengthProbabilities( + outcomes=[ + TimeLengthProbability(steps=1, probability=0.5), + TimeLengthProbability(steps=n_forward_steps, probability=0.5), + ] + ), + step=StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + crps_training=crps_training, + in_names=in_variable_names, + out_names=out_variable_names, + normalization=NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + global_means_path=str(global_means_path), + global_stds_path=str(global_stds_path), + ), + residual=NormalizationConfig( + global_means_path=str(global_means_path), + global_stds_path=str(global_stds_path), + ), + ), + builder=ModuleSelector( + type=nettype, + config=net_config, + ), + ocean=OceanConfig( + surface_temperature_name=in_variable_names[0], + ocean_fraction_name=mask_name, + ), + corrector=corrector_config, + ) + ), + ), + ), + inference=inline_inference_config, + weather_evaluation=weather_evaluation_config, + max_epochs=max_epochs, + segment_epochs=segment_epochs, + save_checkpoint=True, + logging=logging_config, + experiment_dir=str(results_dir), + save_per_epoch_diagnostics=save_per_epoch_diagnostics, + validation_aggregator=OneStepAggregatorConfig( + log_snapshots=log_validation_maps, + log_mean_maps=log_validation_maps, + ), + ) + + inference_config = InferenceEvaluatorConfig( + experiment_dir=str(results_dir), + n_forward_steps=6, + forward_steps_in_memory=2, + checkpoint_path=str(results_dir / "training_checkpoints" / "best_ckpt.tar"), + data_writer=DataWriterConfig( + save_monthly_files=False, + save_prediction_files=False, + files=[FileWriterConfig("autoregressive")], + ), + aggregator=InferenceEvaluatorAggregatorConfig( + log_video=True, + ), + logging=logging_config, + loader=InferenceDataLoaderConfig( + dataset=XarrayDataConfig( + data_path=str(valid_data_path), + spatial_dimensions=spatial_dimensions_str, + ), + start_indices=InferenceInitialConditionIndices( + first=0, + n_initial_conditions=4, + interval=1, + ), + ), + ) + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f_train: + f_train.write(yaml.dump(dataclasses.asdict(train_config))) + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".yaml" + ) as f_inference: + f_inference.write(yaml.dump(dataclasses.asdict(inference_config))) + + return f_train.name, f_inference.name + + +def get_sizes( + spatial_dims: HorizontalCoordinates = LatLonCoordinates( + lon=torch.Tensor(np.arange(32)), + lat=torch.Tensor(np.arange(16)), + ), + n_time=3, + 
nz_interface=3, +) -> DimSizes: + return DimSizes( + n_time=n_time, + horizontal=copy.deepcopy(spatial_dims.loaded_sizes), + nz_interface=nz_interface, + ) + + +def _setup( + path, + nettype, + log_to_wandb=False, + max_epochs=1, + segment_epochs=1, + n_time=10, + timestep_days=5, + inference_forward_steps=2, + use_healpix=False, + save_per_epoch_diagnostics=False, + crps_training=False, + log_validation_maps=False, + skip_inline_inference=False, + time_buffer=1, +): + if not path.exists(): + path.mkdir() + seed = 0 + np.random.seed(seed) + in_variable_names = [ + "PRESsfc", + "specific_total_water_0", + "specific_total_water_1", + "surface_temperature", + "baz", + ] + out_variable_names = [ + "PRESsfc", + "specific_total_water_0", + "specific_total_water_1", + "surface_temperature", + ] + mask_name = "mask" + all_variable_names = list(set(in_variable_names + out_variable_names)) + + if use_healpix: + hpx_coords = HEALPixCoordinates( + face=torch.Tensor(np.arange(12)), + width=torch.Tensor(np.arange(16)), + height=torch.Tensor(np.arange(16)), + ) + dim_sizes = get_sizes(spatial_dims=hpx_coords, n_time=n_time) + else: + dim_sizes = get_sizes(n_time=n_time) + + data_dir = path / "data" + stats_dir = path / "stats" + results_dir = path / "results" + data_dir.mkdir() + stats_dir.mkdir() + results_dir.mkdir() + save_nd_netcdf( + data_dir / "data.nc", + dim_sizes, + variable_names=all_variable_names + [mask_name], + timestep_days=timestep_days, + ) + save_scalar_netcdf( + stats_dir / "stats-mean.nc", + variable_names=all_variable_names, + ) + save_scalar_netcdf( + stats_dir / "stats-stddev.nc", + variable_names=all_variable_names, + ) + + monthly_dim_sizes: DimSizes + if use_healpix: + # monthly reference functionality not supported for HEALPix + # see https://github.com/ai2cm/full-model/issues/1561 + monthly_data_filename = None + else: + monthly_dim_sizes = get_sizes(n_time=10 * 12, nz_interface=1) + monthly_reference_data = MonthlyReferenceData( + path=data_dir, + names=out_variable_names, + dim_sizes=monthly_dim_sizes, + n_ensemble=3, + ) + monthly_data_filename = monthly_reference_data.data_filename + + train_config_filename, inference_config_filename = _get_test_yaml_files( + train_data_path=data_dir, + valid_data_path=data_dir, + monthly_data_filename=monthly_data_filename, + results_dir=results_dir, + global_means_path=stats_dir / "stats-mean.nc", + global_stds_path=stats_dir / "stats-stddev.nc", + in_variable_names=in_variable_names, + out_variable_names=out_variable_names, + mask_name=mask_name, + nettype=nettype, + log_to_wandb=log_to_wandb, + max_epochs=max_epochs, + segment_epochs=segment_epochs, + inference_forward_steps=inference_forward_steps, + use_healpix=use_healpix, + crps_training=crps_training, + save_per_epoch_diagnostics=save_per_epoch_diagnostics, + log_validation_maps=log_validation_maps, + skip_inline_inference=skip_inline_inference, + time_buffer=time_buffer, + ) + return train_config_filename, inference_config_filename + + +@pytest.mark.parametrize( + "nettype, crps_training, log_validation_maps, use_healpix", + [ + ("SphericalFourierNeuralOperatorNet", False, True, False), + ], +) +def test_train_and_inference( + tmp_path, + nettype, + crps_training, + log_validation_maps: bool, + use_healpix: bool, + very_fast_only: bool, +): + """Ensure that ACE training and subsequent standalone inference run without errors. + + Args: + tmp_path: pytest fixture for temporary workspace. + nettype: parameter indicating model architecture to use.
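+ crps_training: parameter indicating whether to train with the CRPS ensemble loss. + log_validation_maps: parameter indicating whether to log validation snapshot and mean maps. + use_healpix: parameter indicating whether to run on the HEALPix grid instead of lat-lon.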
+ very_fast_only: parameter indicating whether to skip slow tests. + """ + if very_fast_only: + pytest.skip("Skipping non-fast tests") + # Generate the configuration files on a single processor. + with Distributed.non_distributed(): + train_config, inference_config = _setup( + tmp_path, + nettype, + log_to_wandb=False, + timestep_days=20, + n_time=int(366 * 3 / 20 + 1), + inference_forward_steps=50,  # must be even + use_healpix=use_healpix, + crps_training=crps_training, + save_per_epoch_diagnostics=True, + log_validation_maps=log_validation_maps, + ) + train_main( + yaml_config=train_config + ) + # TODO: once inline-inference logging works under spatial parallelism, run train_main inside mock_wandb() again and restore the assertions on the logged metrics (e.g. "inference/mean_step_20_norm/weighted_rmse/channel_mean" and "val/mean_norm/weighted_rmse/channel_mean", and that "inference/mean/forecast_step" is never logged). diff --git a/fme/ace/utils/comm.py b/fme/ace/utils/comm.py new file mode 100644 index 000000000..c17fcde13 --- /dev/null +++ b/fme/ace/utils/comm.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
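+ +"""Process-group management for spatial model parallelism, adapted from NVIDIA's Makani/PhysicsNeMo stack. `init` builds a communicator hierarchy on top of physicsnemo's DistributedManager (world -> model -> spatial -> h/w and matmul -> fin/fout, plus world -> data -> ensemble/batch), and the getters below look up sizes, ranks, groups, and root ranks by name. A minimal usage sketch (the 2 x 2 sizes are illustrative): + + comm.init(model_parallel_sizes=[2, 2, 1, 1], + model_parallel_names=["h", "w", "fin", "fout"]) + h_group = comm.get_group("h")  # process group spanning the lat dimension + hrank = comm.get_rank("h") +"""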
+ +import os +import time +import logging +import math +from typing import Union +import numpy as np + +# we are using the distributed manager from physicsnemo +from physicsnemo.distributed.manager import DistributedManager +from physicsnemo.distributed.config import ProcessGroupNode, ProcessGroupConfig + +# we need this +_DM = None +_COMM_ROOTS = {} + + +def get_size(name: str) -> int: + global _DM + if (_DM is not None) and (_DM.world_size > 1): + return _DM.group_size(name) + else: + return 1 + + +def get_rank(name: str) -> int: + global _DM + if (_DM is not None) and (_DM.world_size > 1): + return _DM.group_rank(name) + else: + return 0 + + +def get_group(name: str): + global _DM + if _DM is not None: + return _DM.group(name) + else: + return None + + +def get_root(name: str) -> int: + global _DM + global _COMM_ROOTS + if (name in _COMM_ROOTS) and (_DM.world_size > 1): + return _COMM_ROOTS[name] + else: + return 0 + + +# specialized routines for world comms +def get_world_size(): + global _DM + if _DM is not None: + return _DM.world_size + else: + return 1 + + +def get_world_rank(): + global _DM + if _DM is not None: + return _DM.rank + else: + return 0 + + +def get_local_rank(): + global _DM + if _DM is not None: + return _DM.local_rank + else: + return 0 + + +def get_comm_names(): + global _DM + if _DM is not None: + return [name for name in _DM.group_names if (not name.startswith("__orthogonal_to"))] + else: + return [] + + +def get_model_comm_names(): + return [x for x in get_comm_names() if x not in ["world", "data", "ensemble", "batch"]] + + +def is_distributed(name: str): + global _DM + if _DM is not None: + return name in _DM.group_names + else: + return False + + +def cleanup(): + global _DM + if _DM is not None: + _DM.cleanup() + _DM = None + return + + +# initialization routine +def init(model_parallel_sizes=[1, 1, 1, 1], model_parallel_names=["h", "w", "fin", "fout"], data_parallel_sizes=[1, -1], data_parallel_names=["ensemble", "batch"], verbose=False): + + # call basic init first + DistributedManager.initialize() + + # extract manager object + global _DM + _DM = DistributedManager() + + # create process group config: + world = ProcessGroupNode("world", size=_DM.world_size) + pconfig = ProcessGroupConfig(world) + + # add nodes: + # model + pconfig.add_node(ProcessGroupNode("model"), parent="world") + # spatial and matmul + pconfig.add_node(ProcessGroupNode("spatial"), parent="model") + pconfig.add_node(ProcessGroupNode("matmul"), parent="model") + # subgroups for spatial + pconfig.add_node(ProcessGroupNode("h"), parent="spatial") + pconfig.add_node(ProcessGroupNode("w"), parent="spatial") + # subgroups for matmul: + pconfig.add_node(ProcessGroupNode("fin"), parent="matmul") + pconfig.add_node(ProcessGroupNode("fout"), parent="matmul") + # add data node last + pconfig.add_node(ProcessGroupNode("data"), parent="world") + # other data parallel dims + for dgname in data_parallel_names: + pconfig.add_node(ProcessGroupNode(dgname), parent="data") + + # set up leaf sizes + # model + model_leaf_config = {} + for k, v in zip(model_parallel_names, model_parallel_sizes): + model_leaf_config[k] = v + # data + data_group_size = _DM.world_size // math.prod(model_leaf_config.values()) + data_leaf_config = {} + for k, v in zip(data_parallel_names, data_parallel_sizes): + data_leaf_config[k] = v + # determine some automatic shapes: only one is supported, the others will + # default to 1 + ndata = 1 + for k in data_leaf_config: + v = data_leaf_config[k] + if v > 0: + ndata *= v + for k in 
data_leaf_config: + v = data_leaf_config[k] + if v <= 0: + data_leaf_config[k] = data_group_size // ndata + # the others will automatically be sized 1 + ndata = data_group_size + # fuse leaf configs + leaf_config = model_leaf_config + for k in data_leaf_config: + leaf_config[k] = data_leaf_config[k] + # update sizes + pconfig.set_leaf_group_sizes(leaf_config, update_parent_sizes=True) + + # create remaining process groups + _DM.create_groups_from_config(pconfig, verbose=(verbose and (_DM.rank == 0))) + + # get comm roots: + global _COMM_ROOTS + for gname in get_comm_names(): + rank = _DM.rank + for grp in _DM._group_ranks[gname]: + if rank in grp: + _COMM_ROOTS[gname] = min(grp) + + if verbose: + import torch + + for rank in range(_DM.world_size): + if rank == _DM.rank: + print(f"{rank}: groups:") + for gname in get_comm_names(): + print(f"\t{gname}: {_DM._group_ranks[gname]}, root={_COMM_ROOTS[gname]}") + torch.distributed.barrier(device_ids=[_DM.local_rank]) + + return get_size("model") diff --git a/fme/core/dataset/test_helper.py b/fme/core/dataset/test_helper.py new file mode 100644 index 000000000..92883664b --- /dev/null +++ b/fme/core/dataset/test_helper.py @@ -0,0 +1,84 @@ +import os + +import torch +import torch.distributed as dist +from physicsnemo.distributed.utils import split_tensor_along_dim + +from pathlib import Path + +def create_directory(directory_name): + """ + Create a directory if it does not already exist. + + Parameters: + directory_name (str): The name of the directory to create. + + Returns: + None + """ + try: + # Using pathlib to create the directory + Path(directory_name).mkdir(parents=True, exist_ok=True) + print(f"Directory '{directory_name}' created successfully or already exists.") + except Exception as e: + print(f"An error occurred while creating the directory: {e}")
+ +# this computes a relative error compatible with torch.allclose or np.allclose +def relative_error(tensor1, tensor2): + return torch.sum(torch.abs(tensor1-tensor2)) / torch.sum(torch.abs(tensor2)) + +# this computes an absolute error compatible with torch.allclose or np.allclose +def absolute_error(tensor1, tensor2): + return torch.max(torch.abs(tensor1-tensor2))
+ +def gather_helper(tensor, dim=None, group=None): + # get shapes + if (dim is not None) and (dist.get_world_size(group=group) > 1): + gsize = dist.get_world_size(group=group) + grank = dist.get_rank(group=group) + shape_loc = torch.tensor([tensor.shape[dim]], dtype=torch.long, device=tensor.device) + shape_list = [torch.empty_like(shape_loc) for _ in range(dist.get_world_size(group=group))] + shape_list[grank] = shape_loc + dist.all_gather(shape_list, shape_loc, group=group) + tshapes = [] + for ids in range(gsize): + tshape = list(tensor.shape) + tshape[dim] = shape_list[ids].item() + tshapes.append(tuple(tshape)) + tens_gather = [torch.empty(tshapes[ids], dtype=tensor.dtype, device=tensor.device) for ids in range(gsize)] + tens_gather[grank] = tensor + dist.all_gather(tens_gather, tensor, group=group) + tensor_gather = torch.cat(tens_gather, dim=dim) + else: + tensor_gather = tensor.clone() + + return tensor_gather
+ +def split_helper(tensor, dim=None, group=None): + with torch.no_grad(): + if (dim is not None) and (dist.get_world_size(group=group) > 1): + gsize = dist.get_world_size(group=group) + grank = dist.get_rank(group=group) + # split in dim + tensor_list_local = split_tensor_along_dim(tensor, dim=dim, num_chunks=gsize) + tensor_local = tensor_list_local[grank] + else: + tensor_local = tensor.clone() + + return tensor_local
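+ +# Note: split_helper and gather_helper are inverses on every rank: for a process group g and dim d, gather_helper(split_helper(x, dim=d, group=g), dim=d, group=g) reconstructs x, because gather_helper first all-gathers the per-rank chunk sizes and then concatenates the chunks in rank order.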
+ +def gather_helper_conv(tensor, hdim=-2, wdim=-1, w_group=1, h_group=1): + tensor_gather = gather_helper(tensor, dim=hdim, group=h_group) + tensor_gather = gather_helper(tensor_gather, dim=wdim, group=w_group) + return tensor_gather + +def split_helper_conv(tensor, hdim=-2, wdim=-1, w_group=1, h_group=1): + tensor_local = split_helper(tensor, dim=hdim, group=h_group) + tensor_local = split_helper(tensor_local, dim=wdim, group=w_group) + return tensor_local + +def init_seed(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + return diff --git a/fme/core/dataset/test_xarray_sp_dist.py b/fme/core/dataset/test_xarray_sp_dist.py new file mode 100755 index 000000000..ba0b54e39 --- /dev/null +++ b/fme/core/dataset/test_xarray_sp_dist.py @@ -0,0 +1,414 @@ +"""This file contains unit tests of XarrayDataset.""" + +import dataclasses +import datetime +import os +from collections import namedtuple +from collections.abc import Iterable + +import cftime +import numpy as np +import pandas as pd +import pytest +import torch +import xarray as xr +from xarray.coding.times import CFDatetimeCoder + +from fme.core.coordinates import ( + DepthCoordinate, + HybridSigmaPressureCoordinate, + LatLonCoordinates, + NullVerticalCoordinate, +) +from fme.core.dataset.concat import XarrayConcat, get_dataset +from fme.core.dataset.merged import MergedXarrayDataset +from fme.core.dataset.time import RepeatedInterval, TimeSlice +from fme.core.dataset.utils import FillNaNsConfig +from fme.core.distributed import Distributed +import torch_harmonics.distributed as thd +from fme.core.dataset.xarray import ( + GET_RAW_TIMES_NUM_FILES_PARALLELIZATION_THRESHOLD, + OverwriteConfig, + XarrayDataConfig, + XarrayDataset, + XarraySubset, + _get_cumulative_timesteps, + _get_file_local_index, + _get_raw_times, + _get_timestep, + _get_vertical_coordinate, + _repeat_and_increment_time, + get_xarray_dataset, +) + +from fme.core.dataset.test_helper import gather_helper_conv, relative_error, init_seed +from fme.core.mask_provider import MaskProvider +from fme.core.typing_ import Slice + +from .utils import as_broadcasted_tensor + +SLICE_NONE = slice(None) +MOCK_DATA_FREQ = "3h" +MOCK_DATA_START_DATE = "2003-03" +MOCK_DATA_LAT_DIM, MOCK_DATA_LON_DIM = ("lat", "lon") + + +@dataclasses.dataclass +class VariableNames: + time_dependent_names: Iterable[str] + time_invariant_names: Iterable[str] + initial_condition_names: Iterable[str] + + def _concat(self, *lists): + return_value = [] + for list in lists: + return_value.extend(list) + return return_value + + @property + def all_names(self) -> list[str]: + return self._concat( + self.time_dependent_names, + self.time_invariant_names, + self.initial_condition_names, + ) + + @property + def spatial_resolved_names(self) -> list[str]: + return self._concat(self.time_dependent_names, self.time_invariant_names) + + +MockData = namedtuple( + "MockData", ("tmpdir", "obs_times", "start_times", "start_indices", "var_names") +) + + +def _get_data( + tmp_path_factory, + dirname, + start, + end, + file_freq, + step_freq, + calendar, + with_nans=False, + var_names=["foo", "bar"], + write_extra_vars=True, + add_ensemble_dim=False, +) -> MockData: + """Constructs an xarray dataset and saves to disk in netcdf format.""" + obs_times = xr.date_range( + start, + end, + freq=step_freq, + calendar=calendar, + inclusive="left", + use_cftime=True, + ) + start_times = xr.date_range( + start, + end, + freq=file_freq, + calendar=calendar, + inclusive="left", + use_cftime=True, + 
) + #NOTE: fixing random seed + np.random.seed(333) + obs_delta = obs_times[1] - obs_times[0] + n_levels = 2 + n_lat, n_lon = 4, 8 + n_sample = 3 + + non_time_dims = ("sample", "lat", "lon") if add_ensemble_dim else ("lat", "lon") + non_time_shape = (n_sample, n_lat, n_lon) if add_ensemble_dim else (n_lat, n_lon) + + constant_var = xr.DataArray( + np.random.randn(*non_time_shape).astype(np.float32), + dims=non_time_dims, + ) + constant_scalar_var = xr.DataArray(1.0).astype(np.float32) + ak = {f"ak_{i}": float(i) for i in range(n_levels)} + bk = {f"bk_{i}": float(i + 1) for i in range(n_levels)} + tmpdir = tmp_path_factory.mktemp(dirname) + filenames = [] + for i, first in enumerate(start_times): + if first != start_times[-1]: + last = start_times[i + 1] + else: + last = obs_times[-1] + obs_delta + time = xr.date_range( + first, + last, + freq=step_freq, + calendar=calendar, + inclusive="left", + use_cftime=True, + ) + data_vars: dict[str, float | xr.DataArray] = {**ak, **bk} + for var_name in var_names: + data = np.random.randn(len(time), *non_time_shape).astype(np.float32) + if with_nans: + data[0, :, 0] = np.nan + data_vars[var_name] = xr.DataArray(data, dims=("time", *non_time_dims)) + + data_varying_scalar = np.random.randn(len(time)).astype(np.float32) + if with_nans: + constant_var[0, 0] = np.nan + + if write_extra_vars: + data_vars["varying_scalar_var"] = xr.DataArray( + data_varying_scalar, dims=("time",) + ) + data_vars["constant_var"] = constant_var + data_vars["constant_scalar_var"] = constant_scalar_var + + coords = { + "time": xr.DataArray(time, dims=("time",)), + "lat": xr.DataArray(np.arange(n_lat, dtype=np.float32), dims=("lat",)), + "lon": xr.DataArray(np.arange(n_lon, dtype=np.float32), dims=("lon",)), + } + if add_ensemble_dim: + coords["sample"] = xr.DataArray( + np.arange(n_sample, dtype=np.float32), dims=("sample",) + ) + # variable without the ensemble dimension is useful for checking + # broadcast behavior + data_vars["var_no_ensemble_dim"] = xr.DataArray( + np.random.randn(len(time), n_lat, n_lon).astype(np.float32), + dims=("time", "lat", "lon"), + ) + # set values to sample index for testing convenience + sample_index_values = np.broadcast_to( + np.arange(n_sample).reshape(1, n_sample, 1, 1), # shape [1, ns, 1, 1], + (len(time), n_sample, n_lat, n_lon), + ) + data_vars["var_matches_sample_index"] = ( + xr.zeros_like(data_vars["foo"]) + sample_index_values + ) + + ds = xr.Dataset(data_vars=data_vars, coords=coords) + filename = tmpdir / f"{first.strftime('%Y%m%d%H')}.nc" + ds.to_netcdf( + filename, + unlimited_dims=["time"], + format="NETCDF4", + ) + filenames.append(filename) + + initial_condition_names = () + start_indices = _get_cumulative_timesteps(_get_raw_times(filenames, "netcdf4")) + if write_extra_vars: + variable_names = VariableNames( + time_dependent_names=(*var_names, "varying_scalar_var"), + time_invariant_names=("constant_var", "constant_scalar_var"), + initial_condition_names=initial_condition_names, + ) + else: + variable_names = VariableNames( + time_dependent_names=var_names, + time_invariant_names=(), + initial_condition_names=initial_condition_names, + ) + return MockData(tmpdir, obs_times, start_times, start_indices, variable_names) + + +def get_mock_monthly_netcdfs( + tmp_path_factory, + dirname, + with_nans=False, + end_date="2003-06", + var_names=["foo", "bar"], + write_extra_vars=True, + add_ensemble_dim=False, +) -> MockData: + return _get_data( + tmp_path_factory, + dirname, + start=MOCK_DATA_START_DATE, + end=end_date, + 
file_freq="MS", + step_freq=MOCK_DATA_FREQ, + calendar="standard", + with_nans=with_nans, + var_names=var_names, + write_extra_vars=write_extra_vars, + add_ensemble_dim=add_ensemble_dim, + ) + + +@pytest.fixture(scope="session") +def mock_monthly_netcdfs(tmp_path_factory) -> MockData: + return get_mock_monthly_netcdfs(tmp_path_factory, "month") + + +@pytest.fixture(scope="session") +def mock_monthly_netcdfs_another_source(tmp_path_factory) -> MockData: + return get_mock_monthly_netcdfs( + tmp_path_factory, "month_another_source", var_names=["baz", "qux"] + ) + + +@pytest.fixture(scope="session") +def mock_monthly_netcdfs_another_source_diff_time(tmp_path_factory) -> MockData: + return get_mock_monthly_netcdfs( + tmp_path_factory, + "month_another_source", + end_date="2003-08", + var_names=["baz", "qux"], + write_extra_vars=False, + ) + + +@pytest.fixture(scope="session") +def mock_monthly_netcdfs_with_nans(tmp_path_factory) -> MockData: + return get_mock_monthly_netcdfs(tmp_path_factory, "month_with_nans", with_nans=True) + + +@pytest.fixture(scope="session") +def mock_monthly_netcdfs_ensemble_dim(tmp_path_factory) -> MockData: + return get_mock_monthly_netcdfs( + tmp_path_factory, + "month_with_ensemble_dim", + add_ensemble_dim=True, + var_names=["foo", "bar", "var_no_ensemble_dim", "var_matches_sample_index"], + ) + + +@pytest.fixture(scope="session") +def mock_monthly_zarr_ensemble_dim( + tmp_path_factory, mock_monthly_netcdfs_ensemble_dim +) -> MockData: + zarr_parent = tmp_path_factory.mktemp("zarr") + filename = "data.zarr" + data = load_files_without_dask( + mock_monthly_netcdfs_ensemble_dim.tmpdir.glob("*.nc") + ) + data.to_zarr(zarr_parent / filename) + return MockData( + zarr_parent, + mock_monthly_netcdfs_ensemble_dim.obs_times, + mock_monthly_netcdfs_ensemble_dim.start_times, + mock_monthly_netcdfs_ensemble_dim.start_indices, + mock_monthly_netcdfs_ensemble_dim.var_names, + ) + + +def load_files_without_dask(files, engine="netcdf4") -> xr.Dataset: + """Load a sequence of files without dask, concatenating along the time dimension. + + We load the data from the files into memory to ensure Datasets are properly closed, + since xarray cannot concatenate Datasets lazily without dask anyway. This should be + acceptable for the small datasets we construct for test purposes. 
+    """
+    datasets = []
+    for file in sorted(files):
+        with xr.open_dataset(
+            file,
+            decode_times=CFDatetimeCoder(use_cftime=True),
+            decode_timedelta=False,
+            engine=engine,
+        ) as ds:
+            datasets.append(ds.load())
+    return xr.concat(datasets, dim="time", data_vars="minimal", coords="minimal")
+
+
+@pytest.fixture(scope="session")
+def mock_monthly_zarr(tmp_path_factory, mock_monthly_netcdfs) -> MockData:
+    zarr_parent = tmp_path_factory.mktemp("zarr")
+    filename = "data.zarr"
+    data = load_files_without_dask(mock_monthly_netcdfs.tmpdir.glob("*.nc"))
+    data.to_zarr(zarr_parent / filename)
+    return MockData(
+        zarr_parent,
+        mock_monthly_netcdfs.obs_times,
+        mock_monthly_netcdfs.start_times,
+        mock_monthly_netcdfs.start_indices,
+        mock_monthly_netcdfs.var_names,
+    )
+
+
+@pytest.fixture(scope="session")
+def mock_monthly_zarr_with_nans(
+    tmp_path_factory, mock_monthly_netcdfs_with_nans
+) -> MockData:
+    zarr_parent = tmp_path_factory.mktemp("zarr")
+    filename = "data.zarr"
+    data = load_files_without_dask(mock_monthly_netcdfs_with_nans.tmpdir.glob("*.nc"))
+    data.to_zarr(zarr_parent / filename)
+    return MockData(
+        zarr_parent,
+        mock_monthly_netcdfs_with_nans.obs_times,
+        mock_monthly_netcdfs_with_nans.start_times,
+        mock_monthly_netcdfs_with_nans.start_indices,
+        mock_monthly_netcdfs_with_nans.var_names,
+    )
+
+
+@pytest.fixture(scope="session")
+def mock_yearly_netcdfs(tmp_path_factory):
+    return _get_data(
+        tmp_path_factory,
+        "yearly",
+        start="1999",
+        end="2001",
+        file_freq="YS",
+        step_freq="1D",
+        calendar="noleap",
+    )
+
+
+# TODO: Run this via a Bash script; for now the test is run manually:
+# 1. Get an interactive node on PM.
+# 2. Run `srun -n 4 pytest test_xarray_sp_dist.py`.
+def test_concat_of_XarrayConcat_w_spatial_parallel(mock_monthly_netcdfs):
+    # Use a fixed random seed so that every rank generates identical mock data.
+    init_seed(333)
+    mock_data: MockData = mock_monthly_netcdfs
+
+    n_timesteps = 5
+    names = mock_data.var_names.all_names
+    # First, build a reference dataset without domain decomposition.
+    with Distributed.non_distributed():
+        config_ref = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(None, 4))
+        ref, _ = get_dataset([config_ref], names, n_timesteps)
+        niters = len(ref)
+        tensor_refs = []
+        for i in range(niters):
+            ref_t, _, _ = ref[i]
+            for var in ref_t:
+                reft = ref_t[var]
+                # NOTE: clone because the underlying buffer is overwritten on
+                # the next __getitem__ call.
+                tensor_refs.append(reft.clone())
+
+    # Then read the same samples through the spatially decomposed path and
+    # gather each local crop back to the full grid for comparison.
+    dist = Distributed.get_instance()
+    w_group = dist.comm_get_group("w")
+    h_group = dist.comm_get_group("h")
+    config_c1 = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(None, 4))
+    c1, _ = get_dataset([config_c1], names, n_timesteps)
+
+    with torch.no_grad():
+        niters = len(ref)
+        j = 0
+        for i in range(niters):
+            t1, _, _ = c1[i]
+            # Iterate variables in the same order used to build tensor_refs.
+            for var in ref_t:
+                reft = tensor_refs[j]
+                j += 1
+                c1t = t1[var]
+                # NOTE: only check variables with (time, lat, lon) dims.
+                if len(c1t.shape) >= 3:
+                    # gather_helper_conv assumes tensors are distributed across GPUs.
+                    c1t = c1t.to(dist.local_rank)
+                    c1t_full = gather_helper_conv(c1t, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group)
+                    # Move back to the CPU to compare against the reference.
+                    c1t_full_cpu = c1t_full.to("cpu")
+                    err = relative_error(c1t_full_cpu, reft)
+                    if dist.local_rank == 0:
+                        print(var, f"final relative error of output: {err.item()}")
+                    this_shape = c1t_full_cpu.shape
+                    # On mismatch, print the offending entries to ease debugging.
+                    for f in range(this_shape[0]):
+                        for g in range(this_shape[1]):
+                            for k in range(this_shape[2]):
+                                diff = abs(c1t_full_cpu[f, g, k] - reft[f, g, k])
+                                if diff > 1e-12:
+                                    print(f, g, k, " : ", c1t_full_cpu[f, g, k], reft[f, g, k])
+                    assert torch.equal(c1t_full_cpu, reft)
diff --git a/fme/core/dataset/xarray.py b/fme/core/dataset/xarray.py
index 1dba23fc7..04a092012 100644
--- a/fme/core/dataset/xarray.py
+++ b/fme/core/dataset/xarray.py
@@ -19,6 +19,7 @@
 import xarray as xr
 from xarray.coding.times import CFDatetimeCoder
 
+from fme.core.distributed import Distributed
 from fme.core.coordinates import (
     DepthCoordinate,
     HorizontalCoordinates,
@@ -576,6 +578,7 @@
         self._check_isel_dimensions(first_dataset.sizes)
         self._labels = set(config.labels)
         self._infer_timestep = config.infer_timestep
+        self._dist = Distributed.get_instance()
 
     def _check_isel_dimensions(self, data_dim_sizes):
         # Horizontal dimensions are not currently supported, as the current isel code
@@ -830,19 +833,26 @@
         else:
             ds = self._open_file(file_idx)
         ds = ds.isel(**self.isel)
         tensor_dict = load_series_data(
             idx=start,
             n_steps=n_steps,
             ds=ds,
             names=self._time_dependent_names,
             final_dims=self.dims,
             final_shape=shape,
             fill_nans=self.fill_nans,
         )
         ds.close()
         del ds
+        # Crop each (time, lat, lon) tensor to this rank's local subdomain.
+        # An alternative would be to subset the dataset itself before loading
+        # (see Distributed.dataset_reshape); that path is left for future work.
+        tensor_dict_local = self._dist.get_local_tensor_dict(
+            tensor_dict, self._shape_excluding_time_after_selection
+        )
+
         for n in self._time_dependent_names:
-            arrays.setdefault(n, []).append(tensor_dict[n])
+            arrays.setdefault(n, []).append(tensor_dict_local[n])
 
         tensors: TensorDict = {}
         for n, tensor_list in arrays.items():
@@ -854,11 +864,20 @@
             ds = self._open_file(idxs[0])
         ds = ds.isel(**self.isel)
         shape = [total_steps] + self._shape_excluding_time_after_selection
         for name in self._time_invariant_names:
             variable = ds[name].variable
             if self.fill_nans is not None:
                 variable = variable.fillna(self.fill_nans.value)
-            tensors[name] = as_broadcasted_tensor(variable, self.dims, shape)
+            tensor_global = as_broadcasted_tensor(variable, self.dims, shape)
+            if len(shape) == 3:
+                # (time, lat, lon): keep only this rank's local subdomain.
+                tensors[name] = tensor_global[
+                    :, *self._dist.get_local_slices(self._shape_excluding_time_after_selection)
+                ]
+            else:
+                tensors[name] = tensor_global
         ds.close()
         del ds
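
For reference, the per-rank cropping used in get_sample_by_time_slice reduces to the slice arithmetic below. This is a minimal standalone sketch, not part of the patch: local_slices is a hypothetical stand-in for Distributed.get_local_slices, and a 2x2 (h, w) decomposition of an 8x8 grid is assumed.

    import torch
    from physicsnemo.distributed.utils import compute_split_shapes

    def local_slices(shape_hw, rank_h, rank_w, size_h=2, size_w=2):
        # Mirrors Distributed.get_local_slices: split each axis into per-rank
        # blocks, then offset by the rank's position in the (h, w) mesh.
        shapes_h = compute_split_shapes(shape_hw[0], size_h)
        shapes_w = compute_split_shapes(shape_hw[1], size_w)
        off_h = sum(shapes_h[:rank_h])
        off_w = sum(shapes_w[:rank_w])
        return (
            slice(off_h, off_h + shapes_h[rank_h]),
            slice(off_w, off_w + shapes_w[rank_w]),
        )

    full = torch.randn(5, 8, 8)  # (time, lat, lon)
    sh, sw = local_slices((8, 8), rank_h=0, rank_w=1)
    local = full[:, sh, sw]  # rank (0, 1) holds lat rows 0:4 and lon cols 4:8
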
diff --git a/fme/core/distributed.py b/fme/core/distributed.py
index 26bedbfc3..47c2c51bb 100644
--- a/fme/core/distributed.py
+++ b/fme/core/distributed.py
@@ -1,13 +1,27 @@
 import logging
 import os
 from collections.abc import Callable
-
+import contextlib
 import torch.distributed
 from torch.nn import SyncBatchNorm
 from torch.nn.functional import pad
 from torch.nn.parallel import DistributedDataParallel
 
 from fme.core.device import get_device, using_gpu, using_srun
+from fme.ace.utils import comm
+from physicsnemo.distributed.utils import compute_split_shapes
+from fme.ace.models.makani_mpu.mappings import init_gradient_reduction_hooks
+from fme.ace.models.makani_mpu.layers import DistributedMLP
+from fme.ace.models.makani_mpu.fft import DistributedRealFFT2, DistributedInverseRealFFT2
+from fme.ace.models.makani_mpu.layer_norm import DistributedInstanceNorm2d, DistributedLayerNorm
+from torch import nn
+from fme.ace.models.makani_utils.checkpoint_helpers import (
+    gather_model_state_dict as gmsd,
+    scatter_model_state_dict as smsd,
+)
+import torch_harmonics as th
+import torch_harmonics.distributed as thd
+from fme.core.dataset.test_helper import gather_helper_conv
 
 logger = logging.getLogger(__name__)
@@ -59,12 +73,38 @@ def get_instance(cls) -> "Distributed":
         return singleton
 
     def __init__(self):
-        if torch.distributed.is_available() and not torch.distributed.is_initialized():
+        h = int(os.environ.get("H_PARALLEL_SIZE", 1))
+        w = int(os.environ.get("W_PARALLEL_SIZE", 1))
+        fin = int(os.environ.get("FIN_PARALLEL_SIZE", 1))
+        fout = int(os.environ.get("FOUT_PARALLEL_SIZE", 1))
+
+        self.spatial_parallelism = False
+        if (h > 1) or (w > 1) or (fin > 1) or (fout > 1):
+            self._distributed = self._init_makani_distributed(h, w, fin, fout)
+            self.spatial_parallelism = True
+        elif torch.distributed.is_available() and not torch.distributed.is_initialized():
             self._distributed = self._init_distributed()
         else:
            self._distributed = False
         self._seed = 0
 
+    def _init_makani_distributed(self, h, w, fin, fout):
+        # comm.init sets up the world process group and all model-parallel
+        # subgroups ("h", "w", "fin", "fout") in one call.
+        comm.init(
+            model_parallel_sizes=[h, w, fin, fout],
+            model_parallel_names=["h", "w", "fin", "fout"],
+            verbose=False,
+        )
+        self.world_size = comm.get_world_size()
+        self.rank = comm.get_world_rank()
+        self.local_rank = comm.get_local_rank()
+        self._device_id = self.local_rank
+        torch.cuda.set_device(self._device_id)
+        return True
+
     def _init_distributed(self):
         if "RANK" in os.environ and not using_srun():  # we were executed with torchrun
             if using_gpu():
@@ -107,6 +147,77 @@ def _init_distributed(self):
             distributed = False
         return distributed
 
+    def comm_get_size(self, key: str):
+        return comm.get_size(key) if self.spatial_parallelism else 1
+
+    def comm_get_group(self, key: str):
+        return comm.get_group(key) if self.spatial_parallelism else 1
+
+    def comm_get_rank(self, key: str):
+        return comm.get_rank(key) if self.spatial_parallelism else 0
+
+    def scatter_model_state_dict(self, model: nn.Module, state_dict, strict=True):
+        if self.spatial_parallelism and (comm.get_size("model") > 1):
+            state_dict = smsd(model, state_dict, strict=strict)
+        return state_dict
+
+    def gather_model_state_dict(self, model: nn.Module):
+        if self.spatial_parallelism and (comm.get_size("model") > 1):
+            return gmsd(model)
+        return model.state_dict()
+
+    def get_local_shape_and_offset(self, crop_shape):
+        local_shape_h, local_shape_w = crop_shape
+        local_offset_h, local_offset_w = 0, 0
+        size_h, size_w = self.comm_get_size("h"), self.comm_get_size("w")
+        rank_h, rank_w = self.comm_get_rank("h"), self.comm_get_rank("w")
+        if size_h > 1:
+            shapes_h = compute_split_shapes(local_shape_h, size_h)
+            local_shape_h = shapes_h[rank_h]
+            local_offset_h = sum(shapes_h[:rank_h])
+        if size_w > 1:
+            shapes_w = compute_split_shapes(local_shape_w, size_w)
+            local_shape_w = shapes_w[rank_w]
+            local_offset_w = sum(shapes_w[:rank_w])
+        return local_shape_h, local_offset_h, local_shape_w, local_offset_w
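+
+    # Worked example (illustrative) for get_local_shape_and_offset and
+    # get_local_slices: with the current PhysicsNeMo helper,
+    # compute_split_shapes(8, 3) returns [3, 3, 2], so three "h" ranks own
+    # row blocks 0:3, 3:6, and 6:8, and the offset of rank r is
+    # sum(shapes[:r]). The same arithmetic applies independently along "w".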
+
+    def get_local_tensor_dict(self, tensor_dict, shape_excluding_time):
+        # Crop every (time, lat, lon) tensor to the local subdomain; tensors
+        # with other ranks (e.g. scalars) are passed through unchanged.
+        tensor_dict_local = {}
+        for n, tensor in tensor_dict.items():
+            if len(tensor.shape) == 3:
+                tensor_dict_local[n] = tensor[
+                    :, *self.get_local_slices(shape_excluding_time)
+                ]
+            else:
+                tensor_dict_local[n] = tensor
+        return tensor_dict_local
+
+    def get_local_slices(self, crop_shape):
+        shape_h, offset_h, shape_w, offset_w = self.get_local_shape_and_offset(crop_shape)
+        return (
+            slice(offset_h, offset_h + shape_h),
+            slice(offset_w, offset_w + shape_w),
+        )
+
+    def sampler_replicas(self):
+        return self.comm_get_size("batch") if self.spatial_parallelism else self.world_size
+
+    def sampler_rank(self):
+        return self.comm_get_rank("batch") if self.spatial_parallelism else self.rank
+
     def get_sampler(
         self,
         dataset: torch.utils.data.Dataset,
@@ -116,17 +227,25 @@
         return torch.utils.data.DistributedSampler(
             dataset,
             shuffle=shuffle,
-            num_replicas=self.world_size,
-            rank=self.rank,
+            num_replicas=self.sampler_replicas(),
+            rank=self.sampler_rank(),
             seed=self._seed,
             drop_last=drop_last,
         )
 
+    def check_local_batch_size(self, batch_size):
+        if batch_size % self.comm_get_size("data") != 0:
+            raise ValueError(
+                f"batch_size ({batch_size}) must be divisible by the number of "
+                f"data-parallel workers ({self.comm_get_size('data')})"
+            )
+
     def local_batch_size(self, batch_size: int) -> int:
         """
         Get the local batch size for the current process.
         """
-        return batch_size // self.world_size
+        data_parallel_size = (
+            self.comm_get_size("data") if self.spatial_parallelism else self.world_size
+        )
+        return batch_size // data_parallel_size
 
     def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
         """
@@ -238,11 +357,35 @@ def is_distributed(self) -> bool:
     def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
         """
         Wrap a model with DistributedDataParallel if running in a distributed context.
+
+        For spatial parallelism, uses custom gradient reduction hooks; for
+        standard data parallelism, uses PyTorch's DistributedDataParallel.
         """
-        if self.is_distributed() and any(p.requires_grad for p in module.parameters()):
+        # Only wrap if there are trainable parameters.
+        if not any(p.requires_grad for p in module.parameters()):
+            return DummyWrapper(module)
+
+        if self.spatial_parallelism:
+            # Use custom gradient reduction for spatial/model parallelism.
+            capture_stream = torch.cuda.Stream(device="cuda")
+            with torch.cuda.stream(capture_stream):
+                module = init_gradient_reduction_hooks(
+                    module,
+                    device=self.local_rank,
+                    reduction_buffer_count=1,
+                    broadcast_buffers=False,
+                    find_unused_parameters=False,
+                    gradient_as_bucket_view=True,
+                    static_graph=False,
+                    verbose=False,
+                )
+            capture_stream.synchronize()
+            return module
+
+        if self.is_distributed():
+            # Use standard PyTorch DDP for data parallelism.
             if using_gpu():
                 device_ids = [self._device_id]
-                output_device = [self._device_id]
+                output_device = self._device_id
             else:
                 device_ids = None
                 output_device = None
             return DistributedDataParallel(
                 module,
                 device_ids=device_ids,
                 output_device=output_device,
             )
+
+        return DummyWrapper(module)
+
+    def get_local_modes(self, inverse_transform):
+        if isinstance(inverse_transform, thd.DistributedInverseRealSHT):
+            if self.spatial_parallelism:
+                modes_lat_local = inverse_transform.l_shapes[self.comm_get_rank("h")]
+                modes_lon_local = inverse_transform.m_shapes[self.comm_get_rank("w")]
+            else:
+                modes_lat_local = inverse_transform.lmax_local
+                modes_lon_local = inverse_transform.mmax_local
         else:
-            return DummyWrapper(module)
+            modes_lat_local = inverse_transform.lmax
+            modes_lon_local = inverse_transform.mmax
+        return modes_lat_local, modes_lon_local
+
+    def get_input_out_shapes(self, forward_transform, inverse_transform):
+        if self.comm_get_size("spatial") > 1:
+            input_shape_loc = (
+                forward_transform.lat_shapes[self.comm_get_rank("h")],
+                forward_transform.lon_shapes[self.comm_get_rank("w")],
+            )
+            output_shape_loc = (
+                inverse_transform.lat_shapes[self.comm_get_rank("h")],
+                inverse_transform.lon_shapes[self.comm_get_rank("w")],
+            )
+        else:
+            input_shape_loc = (forward_transform.nlat, forward_transform.nlon)
+            output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon)
+        return input_shape_loc, output_shape_loc
+
+    def dataset_reshape(self, ds, dims, shape):
+        shape_excluding_time = (shape[1], shape[2])
+        # Get local slices for height and width.
+        slice_h, slice_w = self.get_local_slices(shape_excluding_time)
+
+        # Determine the dimension names used for latitude and longitude.
+        lat_dim = "lat" if "lat" in dims else "latitude" if "latitude" in dims else None
+        lon_dim = "lon" if "lon" in dims else "longitude" if "longitude" in dims else None
+
+        # Subset only if both horizontal dimensions are present.
+        if lat_dim is not None and lon_dim is not None:
+            ds = ds.isel(**{lat_dim: slice_h, lon_dim: slice_w})
+            shape[1] = slice_h.stop - slice_h.start
+            shape[2] = slice_w.stop - slice_w.start
+        return ds, shape
+
+    def get_mlp(self, mlp):
+        return DistributedMLP if self.spatial_parallelism else mlp
+
+    def init_thd(self, _thd):
+        # Lazily initialize torch_harmonics.distributed with the polar ("h")
+        # and azimuthal ("w") process groups.
+        if (self.comm_get_size("spatial") > 1) and (not _thd.is_initialized()):
+            polar_group = self.comm_get_group("h") if self.comm_get_size("h") > 1 else None
+            azimuth_group = self.comm_get_group("w") if self.comm_get_size("w") > 1 else None
+            _thd.init(polar_group, azimuth_group)
+        return _thd
+
+    def th_real_sht(self):
+        _thd = self.init_thd(thd)
+        return _thd.DistributedRealSHT if self.spatial_parallelism else th.RealSHT
+
+    def th_inverse_real_sht(self):
+        _thd = self.init_thd(thd)
+        return _thd.DistributedInverseRealSHT if self.spatial_parallelism else th.InverseRealSHT
+
+    def th_real_fft2(self):
+        return DistributedRealFFT2 if self.spatial_parallelism else th.RealFFT2
+
+    def th_inverse_real_fft2(self):
+        return DistributedInverseRealFFT2 if self.spatial_parallelism else th.InverseRealFFT2
+
+    def instance_norm_2d(self):
+        return DistributedInstanceNorm2d if self.spatial_parallelism else nn.InstanceNorm2d
+
+    def layer_norm(self):
+        return DistributedLayerNorm if self.spatial_parallelism else nn.LayerNorm
+
+    def set_image_shapes(self, trans_down, itrans_up, itrans):
+        if self.comm_get_size("spatial") > 1:
+            img_shape_loc = (trans_down.lat_shapes[self.comm_get_rank("h")],
+                             trans_down.lon_shapes[self.comm_get_rank("w")])
+            img_shape_eff = (itrans_up.lat_shapes[self.comm_get_rank("h")],
+                             itrans_up.lon_shapes[self.comm_get_rank("w")])
+            h_loc = itrans.lat_shapes[self.comm_get_rank("h")]
+            w_loc = itrans.lon_shapes[self.comm_get_rank("w")]
+        else:
+            img_shape_loc = (trans_down.nlat, trans_down.nlon)
+            # Mirror the distributed branch, which derives the effective
+            # shape from itrans_up.
+            img_shape_eff = (itrans_up.nlat, itrans_up.nlon)
+            h_loc = itrans.nlat
+            w_loc = itrans.nlon
+
+        return img_shape_loc, img_shape_eff, h_loc, w_loc
+
+    def gather_spatial_distributed(self, local_tensor, gather=True):
+        if gather and self.spatial_parallelism:
+            w_group = self.comm_get_group("w")
+            h_group = self.comm_get_group("h")
+            return gather_helper_conv(local_tensor, hdim=-2, wdim=-1, w_group=w_group, h_group=h_group)
+        else:
+            return local_tensor
 
     def barrier(self):
         """
@@ -278,7 +543,10 @@ def shutdown(self):
         self.barrier()
         if self._distributed:
             logger.debug(f"Shutting down rank {self.rank}")
-            torch.distributed.destroy_process_group()
+            if self.spatial_parallelism:
+                comm.cleanup()
+            else:
+                torch.distributed.destroy_process_group()
 
 
 singleton: Distributed | None = None
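
gather_spatial_distributed is the inverse of the slicing used on the data path: each rank contributes its local (lat, lon) crop and gets the full field back. A round-trip sketch (illustrative; it assumes spatial parallelism is active, that all ranks seed identically, and that the gather helper all-gathers so every rank receives the complete field):

    import torch
    from fme.core.distributed import Distributed

    dist = Distributed.get_instance()
    torch.manual_seed(0)  # identical "global" field on every rank
    full = torch.randn(2, 8, 8, device=dist.local_rank)
    local = full[:, *dist.get_local_slices((8, 8))]    # this rank's crop
    gathered = dist.gather_spatial_distributed(local)  # concat along h and w groups
    assert torch.equal(gathered, full)
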
diff --git a/fme/core/generics/trainer.py b/fme/core/generics/trainer.py
index d52c67eb3..9e20a6c1e 100644
--- a/fme/core/generics/trainer.py
+++ b/fme/core/generics/trainer.py
@@ -386,10 +386,10 @@ def train(self):
         wandb = WandB.get_instance()
         wandb.log(all_logs, step=self.num_batches_seen)
 
-        if dist.is_root():
-            if self.config.save_checkpoint:
+        # All ranks must enter save_all_checkpoints so that spatially sharded
+        # model state can be gathered collectively; only root logs and writes.
+        if self.config.save_checkpoint:
+            if dist.is_root():
                 logging.info(f"Saving checkpoints for epoch {self._epochs_trained}")
-                self.save_all_checkpoints(valid_loss, inference_error)
+            self.save_all_checkpoints(valid_loss, inference_error)
 
     def _log_first_batch_metrics(self):
         wandb = WandB.get_instance()
@@ -417,7 +417,6 @@ def train_one_epoch(self):
         )
         self.train_data.set_epoch(self._epochs_trained + 1)
         wandb = WandB.get_instance()
-        dist = Distributed.get_instance()
         names_to_log = ("batch_loss", "training_samples_per_second_on_rank_0", "lr")
         aggregator = self._aggregator_builder.get_train_aggregator()
         n_samples_seen_since_logging = 0
@@ -470,14 +469,13 @@ def train_one_epoch(self):
                 metrics_to_log = {k: metrics[k] for k in names_to_log if k in metrics}
                 logging.info(f"Step {self.num_batches_seen}: {metrics_to_log}")
                 n_samples_seen_since_logging = 0
-            if (
-                dist.is_root()
-                and self.config.checkpoint_every_n_batches > 0
+            if (
+                self.config.checkpoint_every_n_batches > 0
                 and self.num_batches_seen % self.config.checkpoint_every_n_batches == 0
             ):
                 self._save_restart_checkpoints()
                 self._last_saved_num_batches_seen = self.num_batches_seen
-        if dist.is_root() and self.num_batches_seen > self._last_saved_num_batches_seen:
+
+        if self.num_batches_seen > self._last_saved_num_batches_seen:
             self._save_restart_checkpoints()  # before incrementing epoch so we will validate after resuming  # noqa: E501
             # we will save restart checkpoints again after validation/inference
             # are recorded to wandb
@@ -582,8 +580,7 @@ def save_checkpoint(
         ema_checkpoint_path: str | None = None,
         include_optimization: bool = False,
     ):
-        if not Distributed.get_instance().is_root():
-            raise RuntimeError("Only the root process should save checkpoints")
+        # All ranks now enter save_checkpoint (state gathering is collective);
+        # only the root rank touches the filesystem.
+        dist = Distributed.get_instance()
         # save to a temporary file in case we get pre-empted during save
         temporary_location = os.path.join(
             os.path.dirname(checkpoint_path), f".{uuid.uuid4()}.tmp"
@@ -616,17 +613,17 @@ def save_checkpoint(
             # never include optimization in EMA checkpoint
             if "optimization" in ema_data:
                 ema_data.pop("optimization")
-            torch.save(ema_data, ema_temporary_location)
-            torch.save(data, temporary_location)
-            if ema_temporary_location is not None and ema_checkpoint_path is not None:
-                os.replace(ema_temporary_location, ema_checkpoint_path)
-            os.replace(temporary_location, checkpoint_path)
+            if dist.is_root():
+                torch.save(ema_data, ema_temporary_location)
+                torch.save(data, temporary_location)
+                if ema_temporary_location is not None and ema_checkpoint_path is not None:
+                    os.replace(ema_temporary_location, ema_checkpoint_path)
+                os.replace(temporary_location, checkpoint_path)
         finally:
-            if os.path.exists(temporary_location):
-                os.remove(temporary_location)
-            if ema_temporary_location is not None and os.path.exists(
-                ema_temporary_location
-            ):
+            if dist.is_root() and os.path.exists(temporary_location):
+                os.remove(temporary_location)
+            if ema_temporary_location is not None and os.path.exists(ema_temporary_location):
                 os.remove(ema_temporary_location)
 
     def restore_checkpoint(self, checkpoint_path, ema_checkpoint_path):
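
The intent of the restructured save_checkpoint is a common model-parallel pattern: the state gather is collective, so every rank must reach it, while filesystem writes remain root-only. Condensed (illustrative; save_sharded_model is a hypothetical helper, not part of the patch):

    import os
    import torch
    from fme.core.distributed import Distributed

    def save_sharded_model(module, path):
        dist = Distributed.get_instance()
        # Collective: every rank must call this, or the gather deadlocks.
        state = dist.gather_model_state_dict(module)
        if dist.is_root():
            tmp = path + ".tmp"
            torch.save({"module": state}, tmp)
            os.replace(tmp, path)  # write-then-rename survives pre-emption
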
diff --git a/fme/core/gridded_ops.py b/fme/core/gridded_ops.py
index 596faf8fa..985d5ad61 100644
--- a/fme/core/gridded_ops.py
+++ b/fme/core/gridded_ops.py
@@ -14,6 +14,7 @@
 from fme.core.mask_provider import MaskProviderABC, NullMaskProvider
 from fme.core.tensors import assert_dict_allclose
 from fme.core.typing_ import TensorDict, TensorMapping
+from fme.core.distributed import Distributed
 
 
 class GriddedOperations(abc.ABC):
@@ -293,7 +294,12 @@
                 "Area weights must be longitudinally uniform, "
                 "as assumed for zonal mean."
             )
+
+        # Keep only this rank's portion of the area weights so that weighted
+        # reductions operate on the local subdomain.
+        dist = Distributed.get_instance()
+        area_weights = area_weights[*dist.get_local_slices(area_weights.shape)]
+
         self._device_area = area_weights.to(get_device())
+        # NOTE: area_weights is constructed on the CPU, so the .to("cpu")
+        # copies below are likely no-ops and could be dropped.
         self._cpu_area = area_weights.to("cpu")
         self._device_mask_provider = mask_provider.to(get_device())
         self._cpu_mask_provider = mask_provider.to("cpu")
diff --git a/fme/core/step/single_module.py b/fme/core/step/single_module.py
index cf9648c6a..1de2021fb 100644
--- a/fme/core/step/single_module.py
+++ b/fme/core/step/single_module.py
@@ -22,6 +22,7 @@
 from fme.core.registry import CorrectorSelector, ModuleSelector
 from fme.core.step.step import StepABC, StepConfigABC, StepSelector
 from fme.core.typing_ import TensorDict, TensorMapping
+from fme.core.device import get_device, using_gpu
 
 DEFAULT_TIMESTEP = datetime.timedelta(hours=6)
 DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP)
@@ -228,8 +230,8 @@
         self._config = config
         self._no_optimization = NullOptimization()
 
-        dist = Distributed.get_instance()
-        self.module = dist.wrap_module(self.module)
+        self.dist = Distributed.get_instance()
+        self.module = self.dist.wrap_module(self.module)
 
         self._timestep = timestep
 
@@ -325,8 +328,10 @@ def get_state(self):
         Returns:
             The state of the stepper.
         """
+        # Gather spatially sharded parameters from all ranks before serializing.
+        state_dict = self.dist.gather_model_state_dict(self.module)
         return {
-            "module": self.module.state_dict(),
+            "module": state_dict,
         }
 
     def load_state(self, state: dict[str, Any]) -> None:
@@ -337,11 +342,13 @@ def load_state(self, state: dict[str, Any]) -> None:
             state: The state to load.
         """
         module = state["module"]
+        # TODO: strict loading currently fails because spatially sharded
+        # checkpoints can be missing keys such as "module.device_buffer";
+        # load with strict=False until the scatter handles all buffers.
         if "module.device_buffer" in module:
             # for backwards compatibility with old checkpoints
             del module["module.device_buffer"]
-        self.module.load_state_dict(module)
+        module = self.dist.scatter_model_state_dict(self.module, module, strict=False)
+        self.module.load_state_dict(module, strict=False)
 
 
 def step_with_adjustments(
     input: TensorMapping,
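
With these changes, save and load are symmetric under spatial parallelism: get_state gathers the full state dict, and load_state scatters each rank's shard back out. A round-trip sketch (illustrative; stepper is a hypothetical SingleModuleStep instance built with spatial parallelism enabled):

    from fme.core.distributed import Distributed

    dist = Distributed.get_instance()
    # Saving: all ranks participate; on root the result is the full state dict.
    full_state = dist.gather_model_state_dict(stepper.module)

    # Loading: each rank extracts its shard and loads it into its local module.
    local_state = dist.scatter_model_state_dict(stepper.module, full_state, strict=False)
    stepper.module.load_state_dict(local_state, strict=False)
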
diff --git a/fme/core/test_loss_sp.py b/fme/core/test_loss_sp.py
new file mode 100644
index 000000000..558710c88
--- /dev/null
+++ b/fme/core/test_loss_sp.py
@@ -0,0 +1,119 @@
+import os
+
+import numpy as np
+import pytest
+import torch
+
+from fme.ace.aggregator.one_step.reduced import MeanAggregator
+from fme.core.device import get_device
+from fme.core.distributed import Distributed
+from fme.core.gridded_ops import LatLonOperations
+from fme.core.loss import LossConfig
+
+# The two tests below form a pair: the serial test saves its inputs and the
+# resulting loss to disk, and the spatially parallel test (run separately,
+# e.g. with `srun -n 4 pytest`) reloads them and checks that the distributed
+# loss matches the serial reference.
+TEST_DATA_DIR = "testdata-loss"
+
+
+@pytest.mark.parametrize("global_mean_type", [None])
+def test_loss_builds_and_runs_wo_sp(global_mean_type):
+    nx = 8
+    ny = 8
+    torch.manual_seed(0)
+    data_tensor = torch.randn(1, 2, nx, ny, device=get_device())
+    example_data = {
+        "a": data_tensor,
+    }
+    area_weights = torch.ones(nx, ny).to(get_device()) * 5
+
+    config = LossConfig(global_mean_type=global_mean_type)
+    loss = config.build(
+        reduction="mean",
+        gridded_operations=LatLonOperations(area_weights),
+    )
+
+    x = torch.randn(1, 2, nx, ny, device=get_device())
+    y = torch.randn(1, 2, nx, ny, device=get_device())
+
+    result = loss(x, y)
+
+    aggregator = MeanAggregator(LatLonOperations(area_weights))
+    aggregator.record_batch(
+        loss=result,
+        target_data=example_data,
+        gen_data=example_data,
+        target_data_norm=example_data,
+        gen_data_norm=example_data,
+    )
+    logs = aggregator.get_logs(label="metrics")
+    # Save the reference inputs and loss for the spatially parallel test.
+    os.makedirs(TEST_DATA_DIR, exist_ok=True)
+    torch.save(area_weights, os.path.join(TEST_DATA_DIR, "area_weights.pt"))
+    torch.save(data_tensor, os.path.join(TEST_DATA_DIR, "example_data.pt"))
+    torch.save(x, os.path.join(TEST_DATA_DIR, "x.pt"))
+    torch.save(y, os.path.join(TEST_DATA_DIR, "y.pt"))
+    print("loss", logs["metrics/loss"])
+    torch.save(logs["metrics/loss"], os.path.join(TEST_DATA_DIR, "loss.pt"))
+
+
+@pytest.mark.parametrize("global_mean_type", [None])
+def test_loss_builds_and_runs_with_sp(global_mean_type):
+    os.environ["H_PARALLEL_SIZE"] = "2"
+    os.environ["W_PARALLEL_SIZE"] = "2"
+    nx = 8
+    ny = 8
+    tensor_data_host = torch.load(os.path.join(TEST_DATA_DIR, "example_data.pt"))
+    x_host = torch.load(os.path.join(TEST_DATA_DIR, "x.pt"))
+    y_host = torch.load(os.path.join(TEST_DATA_DIR, "y.pt"))
+    loss_serial = torch.load(os.path.join(TEST_DATA_DIR, "loss.pt"))
+
+    torch.manual_seed(0)
+
+    area_weights = torch.ones(nx, ny) * 5.0
+    dist = Distributed.get_instance()
+    this_shape = (tensor_data_host.shape[-2], tensor_data_host.shape[-1])
+    tensor_data_local_host = (
+        tensor_data_host[:, :, *dist.get_local_slices(this_shape)].detach().clone()
+    )
+    tensor_data_local = tensor_data_local_host.to(dist.local_rank)
+    example_data = {
+        "a": tensor_data_local,
+    }
+
+    config = LossConfig(global_mean_type=global_mean_type)
+    loss = config.build(
+        reduction="mean",
+        gridded_operations=LatLonOperations(area_weights),
+    )
+
+    this_shape_x = (x_host.shape[-2], x_host.shape[-1])
+    x_local = (
+        x_host[:, :, *dist.get_local_slices(this_shape_x)].detach().clone().to(dist.local_rank)
+    )
+    y_local = (
+        y_host[:, :, *dist.get_local_slices(this_shape_x)].detach().clone().to(dist.local_rank)
+    )
+
+    result_local = loss(x_local, y_local)
+
+    aggregator = MeanAggregator(LatLonOperations(area_weights))
+    aggregator.record_batch(
+        loss=result_local,
+        target_data=example_data,
+        gen_data=example_data,
+        target_data_norm=example_data,
+        gen_data_norm=example_data,
+    )
+
+    error_tol = 1e-13
+    logs = aggregator.get_logs(label="metrics")
+    rel_diff = np.abs(loss_serial - logs["metrics/loss"]) / loss_serial
+    assert rel_diff < error_tol
diff --git a/fme/sht_fix.py b/fme/sht_fix.py
index d1696e4a4..3f04e61b4 100644
--- a/fme/sht_fix.py
+++ b/fme/sht_fix.py
@@ -214,5 +214,5 @@ def forward(self, x: torch.Tensor):
         return x
 
 
-torch_harmonics.RealSHT = RealSHT
-torch_harmonics.InverseRealSHT = InverseRealSHT
+# Monkey-patching disabled: the SHT classes are now chosen through
+# Distributed.th_real_sht()/th_inverse_real_sht() instead.
+# torch_harmonics.RealSHT = RealSHT
+# torch_harmonics.InverseRealSHT = InverseRealSHT
diff --git a/requirements.txt b/requirements.txt
index 52dd58ca2..7a8ec516b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ torch>=2
 wandb[media]>=0.19.0
 xarray
 zarr>=3
+nvidia-physicsnemo
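
Finally, a sketch of how the new environment switches are meant to be driven (illustrative; assumes a 4-process launch, e.g. `srun -n 4`, on a 180x360 grid with a 2x2 spatial decomposition):

    import os

    # Must be set before the first Distributed.get_instance() call.
    os.environ["H_PARALLEL_SIZE"] = "2"  # ranks along latitude
    os.environ["W_PARALLEL_SIZE"] = "2"  # ranks along longitude

    from fme.core.distributed import Distributed

    dist = Distributed.get_instance()  # initializes the makani comm groups
    assert dist.spatial_parallelism
    slice_h, slice_w = dist.get_local_slices((180, 360))
    print(dist.rank, slice_h, slice_w)  # each rank's (lat, lon) subdomain
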