Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 97 additions & 45 deletions tests/test_distributed_sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,38 +111,65 @@ def _gather_helper_bwd(self, tensor, transform_dist):

@parameterized.expand(
[
[64, 128, 32, 8, "equiangular", False, 1e-7,1e-9],
[64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[65, 128, 1, 10, "equiangular", False, 1e-6, 1e-6],
[65, 128, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
[4, 8, 1, 10, "equiangular", False, 1e-6, 1e-6],
[64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[64, 128, 1, 10, "equiangular", True, 1e-6, 1e-6],
[65, 128, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],
# lmax automatically determined
# Scalar SHT
[32, 64, None, 32, 8, "equiangular", False, 1e-7,1e-9],
[32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[33, 64, None, 1, 10, "equiangular", False, 1e-6, 1e-6],
[33, 64, None, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
[ 4, 8, None, 1, 10, "equiangular", False, 1e-6, 1e-6],
# Vector SHT
[32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[32, 64, None, 1, 10, "equiangular", True, 1e-6, 1e-6],
[33, 64, None, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],
# downsampling:
# Scalar SHT
[32, 64, 8, 32, 8, "equiangular", False, 1e-7,1e-9],
[32, 64, 8, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[33, 64, 9, 1, 10, "equiangular", False, 1e-6, 1e-6],
[33, 64, 8, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
# Vector SHT
[32, 64, 8, 32, 8, "equiangular", True, 1e-7,1e-9],
[32, 64, 8, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[33, 64, 9, 1, 10, "equiangular", True, 1e-6, 1e-6],
[33, 64, 8, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],
# upsampling
# Scalar SHT
[32, 64, 64, 32, 8, "equiangular", False, 1e-7,1e-9],
[32, 64, 64, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[33, 64, 65, 1, 10, "equiangular", False, 1e-6, 1e-6],
[33, 64, 64, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
# Vector SHT
[32, 64, 64, 32, 8, "equiangular", True, 1e-7,1e-9],
[32, 64, 64, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[33, 64, 65, 1, 10, "equiangular", True, 1e-6, 1e-6],
[33, 64, 64, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],


], skip_on_empty=True
)
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, atol, rtol, verbose=False):
def test_distributed_sht(self, nlat, nlon, lmax, batch_size, num_chan, grid, vector, atol, rtol, verbose=False):

set_seed(333)

B, C, H, W = batch_size, num_chan, nlat, nlon

# set up handles
if vector:
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
else:
forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)

# create tensors
if vector:
Expand Down Expand Up @@ -187,38 +214,63 @@ def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, a

@parameterized.expand(
[
[64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[65, 128, 1, 10, "equiangular", False, 1e-6, 1e-6],
[65, 128, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
[64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9],
[64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[65, 128, 1, 10, "equiangular", True, 1e-6, 1e-6],
[65, 128, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],
# lmax automatically determined
# Scalar SHT
[32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[33, 64, None, 1, 10, "equiangular", False, 1e-6, 1e-6],
[33, 64, None, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
# Vector SHT
[32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[33, 64, None, 1, 10, "equiangular", True, 1e-6, 1e-6],
[33, 64, None, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],
# downsampling (SHT is upsampling)
# Scalar SHT
[32, 64, 64, 32, 8, "equiangular", False, 1e-7, 1e-9],
[32, 64, 64, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[33, 64, 65, 1, 10, "equiangular", False, 1e-6, 1e-6],
[33, 64, 64, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
# Vector SHT
[32, 64, 64, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, 64, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[33, 64, 65, 1, 10, "equiangular", True, 1e-6, 1e-6],
[33, 64, 64, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],
# upsampling (SHT is downsampling)
# Scalar SHT
[32, 64, 8, 32, 8, "equiangular", False, 1e-7, 1e-9],
[32, 64, 8, 32, 8, "legendre-gauss", False, 1e-7, 1e-9],
[33, 64, 9, 1, 10, "equiangular", False, 1e-6, 1e-6],
[33, 64, 8, 1, 10, "legendre-gauss", False, 1e-6, 1e-6],
# Vector SHT
[32, 64, 8, 32, 8, "equiangular", True, 1e-7, 1e-9],
[32, 64, 8, 32, 8, "legendre-gauss", True, 1e-7, 1e-9],
[33, 64, 9, 1, 10, "equiangular", True, 1e-6, 1e-6],
[33, 64, 8, 1, 10, "legendre-gauss", True, 1e-6, 1e-6],
], skip_on_empty=True
)
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, atol, rtol, verbose=True):
def test_distributed_isht(self, nlat, nlon, lmax, batch_size, num_chan, grid, vector, atol, rtol, verbose=True):

set_seed(333)

B, C, H, W = batch_size, num_chan, nlat, nlon

if vector:
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
else:
forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device)

# create tensors
if vector:
Expand Down
16 changes: 4 additions & 12 deletions tests/test_sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,14 @@ def test_device_instantiation(self, nlat, nlon, norm, grid, atol, rtol, verbose=

set_seed(333)

if grid == "equiangular":
mmax = nlat // 2
elif grid == "lobatto":
mmax = nlat - 1
else:
mmax = nlat
lmax = mmax

# init on cpu
sht_host = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
isht_host = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
sht_host = th.RealSHT(nlat, nlon, grid=grid, norm=norm)
isht_host = th.InverseRealSHT(nlat, nlon, grid=grid, norm=norm)

# init on device
with torch.device(self.device):
sht_device = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
isht_device = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
sht_device = th.RealSHT(nlat, nlon, grid=grid, norm=norm)
isht_device = th.InverseRealSHT(nlat, nlon, grid=grid, norm=norm)

self.assertTrue(compare_tensors(f"sht weights", sht_host.weights.cpu(), sht_device.weights.cpu(), atol=atol, rtol=rtol, verbose=verbose))
self.assertTrue(compare_tensors(f"isht weights", isht_host.pct.cpu(), isht_device.pct.cpu(), atol=atol, rtol=rtol, verbose=verbose))
Expand Down
25 changes: 5 additions & 20 deletions torch_harmonics/distributed/distributed_sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from torch_harmonics.truncation import truncate_sht
from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.fft import rfft, irfft
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
Expand Down Expand Up @@ -144,10 +145,7 @@ def forward(self, x: torch.Tensor):
x = distributed_transpose_azimuth(x, (-3, -1), self.lon_shapes)

# apply real fft in the longitudinal direction: make sure to truncate to nlon
x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

# truncate
x = x[..., :self.mmax]
x = 2.0 * torch.pi * rfft(x, nmodes=self.mmax, dim=-1, norm="forward")

# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
Expand Down Expand Up @@ -300,13 +298,8 @@ def forward(self, x: torch.Tensor):
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth(x, (-3, -1), self.m_shapes)

# set DCT and nyquist frequencies to 0:
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0

# apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
x = irfft(x, n=self.nlon, dim=-1, norm="forward")

# transpose: after this, m is split and channels are local
if self.comm_size_azimuth > 1:
Expand Down Expand Up @@ -423,10 +416,7 @@ def forward(self, x: torch.Tensor):
x = distributed_transpose_azimuth(x, (-4, -1), self.lon_shapes)

# apply real fft in the longitudinal direction: make sure to truncate to nlon
x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

# truncate
x = x[..., :self.mmax]
x = 2.0 * torch.pi * rfft(x, nmodes=self.mmax, dim=-1, norm="forward")

# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
Expand Down Expand Up @@ -601,13 +591,8 @@ def forward(self, x: torch.Tensor):
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth(x, (-4, -1), self.m_shapes)

# set DCT and nyquist frequencies to zero
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0

# apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
x = irfft(x, n=self.nlon, dim=-1, norm="forward")

# transpose: after this, m is split and channels are local
if self.comm_size_azimuth > 1:
Expand Down
86 changes: 86 additions & 0 deletions torch_harmonics/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2026 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import Optional

import torch
import torch.fft as fft
import torch.nn as nn


def _pad_dim_right(x: torch.Tensor, dim: int, target_size: int, value: float = 0.0) -> torch.Tensor:
"""Pad tensor along a single dimension to target_size (right-side only)."""
ndim = x.ndim
dim = dim if dim >= 0 else ndim + dim
pad_amount = target_size - x.shape[dim]
# F.pad expects (left, right) for last dim, then second-to-last, etc.
pad_spec = [0] * (2 * ndim)
pad_spec[(ndim - 1 - dim) * 2 + 1] = pad_amount
return nn.functional.pad(x, tuple(pad_spec), value=value)


def rfft(x: torch.Tensor, nmodes: Optional[int] = None, dim: int = -1, **kwargs) -> torch.Tensor:
    """
    Real FFT with explicit control over the number of retained frequency modes.

    Computes ``torch.fft.rfft`` along ``dim`` and then adjusts the mode count:
    if ``nmodes`` exceeds the number of computed modes, the spectrum is
    zero-padded on the right; if it is smaller, the spectrum is truncated to
    the leading ``nmodes`` modes. Extra keyword arguments (e.g. ``norm``) are
    forwarded to ``torch.fft.rfft``; the ``n`` argument is rejected because
    ``nmodes`` controls the output size instead.
    """

    if "n" in kwargs:
        raise ValueError("The 'n' argument is not allowed. Use 'nmodes' instead.")

    out = fft.rfft(x, dim=dim, **kwargs)

    if nmodes is None or nmodes == out.shape[dim]:
        return out

    if nmodes < out.shape[dim]:
        # keep only the leading nmodes frequency modes
        return out.narrow(dim, 0, nmodes)

    # zero-pad on the right along `dim` up to nmodes;
    # nn.functional.pad takes (left, right) pairs starting from the last dim
    axis = dim % out.ndim
    pad_spec = [0] * (2 * out.ndim)
    pad_spec[(out.ndim - 1 - axis) * 2 + 1] = nmodes - out.shape[axis]
    return nn.functional.pad(out, tuple(pad_spec), value=0.0)

def irfft(x: torch.Tensor, n: Optional[int] = None, dim: int = -1, **kwargs) -> torch.Tensor:
    """
    Inverse real FFT that handles padding and truncation correctly.

    Before delegating to ``torch.fft.irfft``, this routine explicitly zeroes the
    imaginary parts of the DC (m=0) and Nyquist modes along ``dim``. Enforcing
    Hermitian symmetry here is important because not all backend algorithms
    reached through the irfft interface do so, which can cause artifacts.

    Note: the symmetry fix mutates the input tensor in place.

    Parameters
    ----------
    x: complex tensor of spectral coefficients
    n: output signal length along ``dim``; defaults to ``2 * (x.size(dim) - 1)``
    dim: dimension along which to transform
    kwargs: forwarded to ``torch.fft.irfft`` (e.g. ``norm``)
    """

    if n is None:
        n = 2 * (x.size(dim) - 1)

    # ensure that the imaginary parts of the DC and Nyquist components are zero.
    # This is important because not all backend algorithms provided through the
    # irfft interface ensure that. Select along `dim` rather than the last axis
    # (indexing with `x[..., 0]` would zero the wrong axis whenever dim != -1).
    x.imag.select(dim, 0).zero_()
    if (n % 2 == 0) and (n // 2 < x.size(dim)):
        x.imag.select(dim, n // 2).zero_()

    x = fft.irfft(x, n=n, dim=dim, **kwargs)

    return x
Loading