diff --git a/tests/test_distributed_sht.py b/tests/test_distributed_sht.py index 4ada3c4f..19939f8d 100644 --- a/tests/test_distributed_sht.py +++ b/tests/test_distributed_sht.py @@ -111,26 +111,53 @@ def _gather_helper_bwd(self, tensor, transform_dist): @parameterized.expand( [ - [64, 128, 32, 8, "equiangular", False, 1e-7,1e-9], - [64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], - [65, 128, 1, 10, "equiangular", False, 1e-6, 1e-6], - [65, 128, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], - [4, 8, 1, 10, "equiangular", False, 1e-6, 1e-6], - [64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], - [64, 128, 1, 10, "equiangular", True, 1e-6, 1e-6], - [65, 128, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], + # lmax automatically determined + # Scalar SHT + [32, 64, None, 32, 8, "equiangular", False, 1e-7,1e-9], + [32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [33, 64, None, 1, 10, "equiangular", False, 1e-6, 1e-6], + [33, 64, None, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], + [ 4, 8, None, 1, 10, "equiangular", False, 1e-6, 1e-6], + # Vector SHT + [32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [32, 64, None, 1, 10, "equiangular", True, 1e-6, 1e-6], + [33, 64, None, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], + # downsampling: + # Scalar SHT + [32, 64, 8, 32, 8, "equiangular", False, 1e-7,1e-9], + [32, 64, 8, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [33, 64, 9, 1, 10, "equiangular", False, 1e-6, 1e-6], + [33, 64, 8, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], + # Vector SHT + [32, 64, 8, 32, 8, "equiangular", True, 1e-7,1e-9], + [32, 64, 8, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [33, 64, 9, 1, 10, "equiangular", True, 1e-6, 1e-6], + [33, 64, 8, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], + # upsampling + # Scalar SHT + [32, 64, 64, 32, 8, "equiangular", False, 1e-7,1e-9], + [32, 64, 64, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [33, 64, 65, 1, 10, "equiangular", False, 1e-6, 1e-6], + [33, 64, 64, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], + # Vector SHT + [32, 64, 64, 32, 8, "equiangular", True, 1e-7,1e-9], + [32, 64, 64, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [33, 64, 65, 1, 10, "equiangular", True, 1e-6, 1e-6], + [33, 64, 64, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], + + ], skip_on_empty=True ) - def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, atol, rtol, verbose=False): + def test_distributed_sht(self, nlat, nlon, lmax, batch_size, num_chan, grid, vector, atol, rtol, verbose=False): set_seed(333) @@ -138,11 +165,11 @@ def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, a # set up handles if vector: - forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) - forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) + forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) + forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) else: - forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device) - forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device) + forward_transform_local = th.RealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) + forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) # create tensors if vector: @@ -187,38 +214,63 @@ def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, a @parameterized.expand( [ - [64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", False, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], - [65, 128, 1, 10, "equiangular", False, 1e-6, 1e-6], - [65, 128, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], - [64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], - [64, 128, 32, 8, "equiangular", True, 1e-7, 1e-9], - [64, 128, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], - [65, 128, 1, 10, "equiangular", True, 1e-6, 1e-6], - [65, 128, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], + # lmax automatically determined + # Scalar SHT + [32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", False, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [33, 64, None, 1, 10, "equiangular", False, 1e-6, 1e-6], + [33, 64, None, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], + # Vector SHT + [32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, None, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [33, 64, None, 1, 10, "equiangular", True, 1e-6, 1e-6], + [33, 64, None, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], + # downsampling (SHT is upsampling) + # Scalar SHT + [32, 64, 64, 32, 8, "equiangular", False, 1e-7, 1e-9], + [32, 64, 64, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [33, 64, 65, 1, 10, "equiangular", False, 1e-6, 1e-6], + [33, 64, 64, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], + # Vector SHT + [32, 64, 64, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, 64, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [33, 64, 65, 1, 10, "equiangular", True, 1e-6, 1e-6], + [33, 64, 64, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], + # upsampling (SHT is downsampling) + # Scalar SHT + [32, 64, 8, 32, 8, "equiangular", False, 1e-7, 1e-9], + [32, 64, 8, 32, 8, "legendre-gauss", False, 1e-7, 1e-9], + [33, 64, 9, 1, 10, "equiangular", False, 1e-6, 1e-6], + [33, 64, 8, 1, 10, "legendre-gauss", False, 1e-6, 1e-6], + # Vector SHT + [32, 64, 8, 32, 8, "equiangular", True, 1e-7, 1e-9], + [32, 64, 8, 32, 8, "legendre-gauss", True, 1e-7, 1e-9], + [33, 64, 9, 1, 10, "equiangular", True, 1e-6, 1e-6], + [33, 64, 8, 1, 10, "legendre-gauss", True, 1e-6, 1e-6], ], skip_on_empty=True ) - def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, atol, rtol, verbose=True): + def test_distributed_isht(self, nlat, nlon, lmax, batch_size, num_chan, grid, vector, atol, rtol, verbose=True): set_seed(333) B, C, H, W = batch_size, num_chan, nlat, nlon if vector: - forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) - backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) - backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) + forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) + backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) + backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) else: - forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device) - backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device) - backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device) + forward_transform_local = th.RealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) + backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) + backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, lmax=lmax, mmax=lmax, grid=grid).to(self.device) # create tensors if vector: diff --git a/tests/test_sht.py b/tests/test_sht.py index 7991c4c1..fab5986b 100644 --- a/tests/test_sht.py +++ b/tests/test_sht.py @@ -219,22 +219,14 @@ def test_device_instantiation(self, nlat, nlon, norm, grid, atol, rtol, verbose= set_seed(333) - if grid == "equiangular": - mmax = nlat // 2 - elif grid == "lobatto": - mmax = nlat - 1 - else: - mmax = nlat - lmax = mmax - # init on cpu - sht_host = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm) - isht_host = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm) + sht_host = th.RealSHT(nlat, nlon, grid=grid, norm=norm) + isht_host = th.InverseRealSHT(nlat, nlon, grid=grid, norm=norm) # init on device with torch.device(self.device): - sht_device = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm) - isht_device = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm) + sht_device = th.RealSHT(nlat, nlon, grid=grid, norm=norm) + isht_device = th.InverseRealSHT(nlat, nlon, grid=grid, norm=norm) self.assertTrue(compare_tensors(f"sht weights", sht_host.weights.cpu(), sht_device.weights.cpu(), atol=atol, rtol=rtol, verbose=verbose)) self.assertTrue(compare_tensors(f"isht weights", isht_host.pct.cpu(), isht_device.pct.cpu(), atol=atol, rtol=rtol, verbose=verbose)) diff --git a/torch_harmonics/distributed/distributed_sht.py b/torch_harmonics/distributed/distributed_sht.py index cf48aa91..98311728 100644 --- a/torch_harmonics/distributed/distributed_sht.py +++ b/torch_harmonics/distributed/distributed_sht.py @@ -38,6 +38,7 @@ from torch_harmonics.truncation import truncate_sht from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly +from torch_harmonics.fft import rfft, irfft from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim @@ -144,10 +145,7 @@ def forward(self, x: torch.Tensor): x = distributed_transpose_azimuth(x, (-3, -1), self.lon_shapes) # apply real fft in the longitudinal direction: make sure to truncate to nlon - x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward") - - # truncate - x = x[..., :self.mmax] + x = 2.0 * torch.pi * rfft(x, nmodes=self.mmax, dim=-1, norm="forward") # transpose: after this, m is split and c is local if self.comm_size_azimuth > 1: @@ -300,13 +298,8 @@ def forward(self, x: torch.Tensor): if self.comm_size_azimuth > 1: x = distributed_transpose_azimuth(x, (-3, -1), self.m_shapes) - # set DCT and nyquist frequencies to 0: - x[..., 0].imag = 0.0 - if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]): - x[..., self.nlon // 2].imag = 0.0 - # apply the inverse (real) FFT - x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + x = irfft(x, n=self.nlon, dim=-1, norm="forward") # transpose: after this, m is split and channels are local if self.comm_size_azimuth > 1: @@ -423,10 +416,7 @@ def forward(self, x: torch.Tensor): x = distributed_transpose_azimuth(x, (-4, -1), self.lon_shapes) # apply real fft in the longitudinal direction: make sure to truncate to nlon - x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward") - - # truncate - x = x[..., :self.mmax] + x = 2.0 * torch.pi * rfft(x, nmodes=self.mmax, dim=-1, norm="forward") # transpose: after this, m is split and c is local if self.comm_size_azimuth > 1: @@ -601,13 +591,8 @@ def forward(self, x: torch.Tensor): if self.comm_size_azimuth > 1: x = distributed_transpose_azimuth(x, (-4, -1), self.m_shapes) - # set DCT and nyquist frequencies to zero - x[..., 0].imag = 0.0 - if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]): - x[..., self.nlon // 2].imag = 0.0 - # apply the inverse (real) FFT - x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + x = irfft(x, n=self.nlon, dim=-1, norm="forward") # transpose: after this, m is split and channels are local if self.comm_size_azimuth > 1: diff --git a/torch_harmonics/fft.py b/torch_harmonics/fft.py new file mode 100644 index 00000000..6c9f5000 --- /dev/null +++ b/torch_harmonics/fft.py @@ -0,0 +1,86 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2026 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Optional + +import torch +import torch.fft as fft +import torch.nn as nn + + +def _pad_dim_right(x: torch.Tensor, dim: int, target_size: int, value: float = 0.0) -> torch.Tensor: + """Pad tensor along a single dimension to target_size (right-side only).""" + ndim = x.ndim + dim = dim if dim >= 0 else ndim + dim + pad_amount = target_size - x.shape[dim] + # F.pad expects (left, right) for last dim, then second-to-last, etc. + pad_spec = [0] * (2 * ndim) + pad_spec[(ndim - 1 - dim) * 2 + 1] = pad_amount + return nn.functional.pad(x, tuple(pad_spec), value=value) + + +def rfft(x: torch.Tensor, nmodes: Optional[int] = None, dim: int = -1, **kwargs) -> torch.Tensor: + """ + Real FFT with the correct padding behavior. + If nmodes is given and larger than x.size(dim), x is zero-padded along dim before FFT. + """ + + if "n" in kwargs: + raise ValueError("The 'n' argument is not allowed. Use 'nmodes' instead.") + + x = fft.rfft(x, dim=dim, **kwargs) + + if nmodes is not None and nmodes > x.shape[dim]: + x = _pad_dim_right(x, dim, nmodes, value=0.0) + elif nmodes is not None and nmodes < x.shape[dim]: + x = x.narrow(dim, 0, nmodes) + + return x + +def irfft(x: torch.Tensor, n: Optional[int] = None, dim: int = -1, **kwargs) -> torch.Tensor: + """ + Torch version of IRFFT handles paddign and truncation correctly. + This routine only applies Hermitian symmetry to avoid artifacts which occur depending on the backend. + """ + + if n is None: + n = 2 * (x.size(dim) - 1) + + # ensure that imaginary part of 0 and nyquist components are zero + # this is important because not all backend algorithms provided through the + # irfft interface ensure that + x[..., 0].imag = 0.0 + if (n % 2 == 0) and (n // 2 < x.size(dim)): + x[..., n // 2].imag = 0.0 + + x = fft.irfft(x, n=n, dim=dim, **kwargs) + + return x \ No newline at end of file diff --git a/torch_harmonics/sht.py b/torch_harmonics/sht.py index 2af8dba3..38765a7f 100644 --- a/torch_harmonics/sht.py +++ b/torch_harmonics/sht.py @@ -36,6 +36,7 @@ from torch_harmonics.truncation import truncate_sht from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly +from torch_harmonics.fft import rfft, irfft class RealSHT(nn.Module): @@ -120,7 +121,7 @@ def forward(self, x: torch.Tensor): assert x.shape[-1] == self.nlon # apply real fft in the longitudinal direction - x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + x = 2.0 * torch.pi * rfft(x, nmodes=self.mmax, dim=-1, norm="forward") # do the Legendre-Gauss quadrature x = torch.view_as_real(x) @@ -235,14 +236,7 @@ def forward(self, x: torch.Tensor): # apply the inverse (real) FFT x = torch.view_as_complex(xs) - # ensure that imaginary part of 0 and nyquist components are zero - # this is important because not all backend algorithms provided through the - # irfft interface ensure that - x[..., 0].imag = 0.0 - if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax): - x[..., self.nlon // 2].imag = 0.0 - - x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + x = irfft(x, n=self.nlon, dim=-1, norm="forward") return x @@ -334,7 +328,7 @@ def forward(self, x: torch.Tensor): assert x.shape[-1] == self.nlon # apply real fft in the longitudinal direction - x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + x = 2.0 * torch.pi * rfft(x, nmodes=self.mmax, dim=-1, norm="forward") # do the Legendre-Gauss quadrature x = torch.view_as_real(x) @@ -468,13 +462,6 @@ def forward(self, x: torch.Tensor): # apply the inverse (real) FFT x = torch.view_as_complex(xs) - # ensure that imaginary part of 0 and nyquist components are zero - # this is important because not all backend algorithms provided through the - # irfft interface ensure that - x[..., 0].imag = 0.0 - if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax): - x[..., self.nlon // 2].imag = 0.0 - - x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + x = irfft(x, n=self.nlon, dim=-1, norm="forward") return x