From 2f4dcde090a94b8f46439d98f41bf1c40c10559f Mon Sep 17 00:00:00 2001 From: Wei Xu Date: Thu, 2 Oct 2025 17:36:21 -0700 Subject: [PATCH 1/2] diffusion_algorithm.py and diffusion_model.py This PR includes basid interface for diffusion-like models. Currently, only FlowMatching is included. Other methods such as score matching or mean flow can be added later if necessary. Several NN models (DiT, AdaNet, MLPNet, SimpleMLPNet) for modeling the velocit/score/noise are also included. --- .pre-commit-config.yaml | 4 +- alf/algorithms/diffusion_algorithm.py | 338 ++++++++++++++++++ alf/algorithms/diffusion_algorithm_test.py | 220 ++++++++++++ alf/algorithms/diffusion_model.py | 385 +++++++++++++++++++++ 4 files changed, 945 insertions(+), 2 deletions(-) create mode 100644 alf/algorithms/diffusion_algorithm.py create mode 100644 alf/algorithms/diffusion_algorithm_test.py create mode 100644 alf/algorithms/diffusion_model.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71456ed22..0322a5353 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.4.0 hooks: - id: codespell - args: [ "--skip", "*.hook", "--ignore-words-list", "ans,nd,Bu,astroid,hart" ] + args: [ "--skip", "*.hook", "--ignore-words-list", "ans,nd,Bu,astroid,hart", "--ignore-multiline-regex", "codespell:ignore-begin.*codespell:ignore-end" ] diff --git a/alf/algorithms/diffusion_algorithm.py b/alf/algorithms/diffusion_algorithm.py new file mode 100644 index 000000000..c708a1cbe --- /dev/null +++ b/alf/algorithms/diffusion_algorithm.py @@ -0,0 +1,338 @@ +# Copyright (c) 2025 Horizon Robotics and ALF Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Diffusion-like generative models driven by SDEs.""" + +import math +import torch +import alf +from alf.utils import summary_utils +from alf.utils.dist_utils import TruncatedNormal +from alf.nest import get_nest_batch_size +from alf.utils.common import expand_dims_as + + +class SDE: + r"""Base stochastic differential equation interface. + + .. math:: + + dx = -\beta(t) x dt + g(t) dw + + Sub-classes provide closed-form coefficients for specific SDE families. + + """ + + def pt0(self, t): + """Return the closed-form :math:`p(x_t | x_0)` parameters. + + Args: + t: Tensor of time values at which to evaluate the marginal. + + Returns: + Tuple ``(alpha, sigma)`` representing the linear coefficient and + standard deviation of the conditional distribution. + :math:`p(x_t | x_0) = Normal(alpha * x_0, sigma^2 I)`. + """ + raise NotImplementedError + + def pt0_dot(self, t): + """Return the derivatives of pt0 with respect to time. + + Args: + t: Tensor of time values to differentiate at. + + Returns: + Tuple ``(alpha_dot, sigma_dot)`` corresponding to the time + derivatives of the parameters of :math:`p(x_t | x_0)`. + """ + raise NotImplementedError + + def diffusion_coeff(self, t): + r"""Return the drift and diffusion coefficients :math:`(\beta, g)`. + + Args: + t: Tensor of time values at which to evaluate the coefficients. + + Returns: + Tuple ``(beta, g)`` describing the SDE drift and diffusion terms. + """ + raise NotImplementedError + + +class OTSDE(SDE): + # codespell:ignore-begin + r"""Optimal transport SDE with closed-form linear coefficients. + + This corresponds to commonly used flow matching model ("FM/OT" in https://arxiv.org/abs/2210.02747). + + .. math:: + + dx = -\frac{1}{1-t} x dt + \sqrt{\frac{2t}{1-t}} dw + + """ + + # codespell:ignore-end + + def pt0(self, t): + return 1 - t, t + + def pt0_dot(self, t): + return -1, 1 + + def diffusion_coeff(self, t): + return 1 / (1 - t), (2 * t / (1 - t)).sqrt() + + +class RTSDE(SDE): + r"""SDE with trigonometric drift and diffusion. + + .. math:: + + dx = - \frac{\pi}{2}\tan(\frac{\pi}{2}t)xdt + \sqrt{\pi\tan(\frac{\pi}{2}t)} dw + + """ + + def pt0(self, t): + angle = 0.5 * math.pi * t + return torch.cos(angle), torch.sin(angle) + + def pt0_dot(self, t): + angle = 0.5 * math.pi * t + return -0.5 * math.pi * torch.sin(angle), 0.5 * math.pi * torch.cos( + angle) + + def diffusion_coeff(self, t): + beta = 0.5 * math.pi * torch.tan(0.5 * math.pi * t) + return 0.5 * beta, (2 * beta)**0.5 + + +class SDEGenerator(torch.nn.Module): + """Base class for diffusion-like generators driven by an SDE.""" + + def __init__(self, + input_spec, + output_spec, + model_ctor, + sde: SDE, + mean_flow=False, + time_sampler=torch.rand, + steps=5): + """Initialize the generator and underlying neural model. + + Args: + input_spec: Spec describing conditional inputs. + output_spec: Spec for generated data. + model_ctor: Factory creating the predictor network. The callable + should accept ``(input_spec, output_spec, mean_flow)`` and + return an :class:`alf.networks.Network` that consumes + ``(x_t, inputs, t[, h])``. + sde: Stochastic differential equation describing the generative + process. + mean_flow: If ``True``, the model will receive an additional input + encoding the look-ahead horizon for mean flow training. + time_sampler: Callable that samples the training time ``t``. + steps: Number of Euler steps used during sampling. + """ + super().__init__() + self._model = model_ctor(input_spec, output_spec, mean_flow=mean_flow) + self._sde = sde + self._output_spec = output_spec + self._steps = steps + if isinstance(output_spec, alf.BoundedTensorSpec): + self._min = torch.tensor(output_spec.minimum) + self._max = torch.tensor(output_spec.maximum) + self._time_sampler = time_sampler + + def calc_loss(self, inputs, samples, f_neg_energy=None, sample_mask=None): + """Compute per-sample losses for training. + + Args: + inputs: Conditional input nest consumed by the model during + training. + samples: Observed ``x_0`` tensors. When ``None`` the implementation + will draw ``x_0`` from the prior distribution. + f_neg_energy: Optional callable ``f(x0, inputs) -> Tensor`` that + returns per-sample negative energies used to importance weight + the loss. + sample_mask: Optional broadcastable tensor used to zero out losses + of masked samples. + """ + raise NotImplementedError + + def sample(self, inputs, steps): + """Generate samples given the conditioning inputs. + + Args: + inputs: Conditional inputs used during generation. + steps: Number of Euler integration steps to use. + + Returns: + Generated samples matching ``output_spec``. + """ + raise NotImplementedError + + @property + def state_spec(self): + return () + + def _apply_sample_mask(self, diff, sample_mask): + """Apply a mask to the loss tensor if provided. + + Args: + diff: Tensor containing per-element loss values. + sample_mask: Optional mask tensor broadcastable to ``diff``. + + Returns: + Masked loss tensor. + """ + if sample_mask is not None: + if sample_mask.ndim == 1: + sample_mask = expand_dims_as(sample_mask, diff) + diff = diff * sample_mask + return diff + + def _calc_weights(self, x0, alpha, sigma, inputs, f_neg_energy): + """Compute importance weights based on the energy function. + + Args: + x0: Tensor of ``x_0`` samples. + alpha: Scaling factor from ``p(x_t | x_0)``. + sigma: Standard deviation from ``p(x_t | x_0)``. + inputs: Conditional inputs associated with each ``x_0`` sample. + f_neg_energy: Callable returning negative energies for importance + weighting. + + Returns: + Weights representing the likelihood of each ``x_0`` under the + energy function. + """ + with torch.no_grad(): + neg_energy = f_neg_energy(x0, inputs) + return neg_energy.exp() + + +def expand_dims(x, ndim): + """Reshape ``x`` to add ``ndim`` singleton dimensions at the end. + + Args: + x: Tensor to reshape. + ndim: Number of singleton dimensions to append. + + Returns: + Reshaped tensor with additional singleton dimensions. + """ + return x.reshape(x.shape[0], *((1, ) * ndim)) + + +@alf.configurable +class FlowMatching(SDEGenerator): + """Flow matching objective that regresses the true velocity field.""" + + def _get_x0_xt_x1(self, samples, batch_size, alpha, sigma): + """Sample ``(x_0, x_t, x_1)`` triplets for flow matching. + + Args: + samples: Optional tensor of ground-truth ``x_0`` values. + batch_size: Number of samples to draw. + alpha: Scaling factor from the SDE marginal. + sigma: Standard deviation from the SDE marginal. + + Returns: + Tuple ``(x0, xt, x1)`` consistent with the SDE marginals. + """ + if samples is None: + if isinstance(self._output_spec, alf.BoundedTensorSpec): + # For bounded data, we use a p1(.) that is uniform within the bounds. + xt = self._output_spec.sample((batch_size, )) + # x1 = self._output_spec.randn((batch_size,)) + + x1_max = (xt - alpha * self._min) / sigma + x1_min = (xt - alpha * self._max) / sigma + dist = TruncatedNormal(loc=torch.zeros_like(x1_min), + scale=torch.ones_like(x1_min), + lower_bound=x1_min, + upper_bound=x1_max) + x1 = dist.sample() + + # x1_max = x1_max.minimum(self._max) + # x1_min = x1_min.maximum(self._min) + # x1 = torch.rand((batch_size,) + self._output_spec.shape) * (x1_max - x1_min) + x1_min + + else: + xt = self._output_spec.randn((batch_size, )) + x1 = self._output_spec.randn((batch_size, )) + x0 = (xt - sigma * x1) / alpha + else: + x0 = samples + x1 = torch.randn((batch_size, ) + self._output_spec.shape, + device=samples.device) + xt = alpha * x0 + sigma * x1 + + return x0, xt, x1 + + def calc_loss(self, inputs, samples, f_neg_energy=None, sample_mask=None): + """Compute the squared error between predicted and true velocities. + + Args: + inputs: Conditional input nest. + samples: Optional tensor of ground-truth ``x_0`` values. + f_neg_energy: Optional energy function for importance weighting. + sample_mask: Optional mask applied to the per-element losses. + + Returns: + Tensor of per-sample velocity regression losses. + """ + leaf = alf.nest.extract_any_leaf_from_nest(inputs) + batch_size = leaf.shape[0] + device = leaf.device + t = self._time_sampler(batch_size, device=device) + alpha, sigma = self._sde.pt0(expand_dims(t, self._output_spec.ndim)) + x0, xt, x1 = self._get_x0_xt_x1(samples, batch_size, alpha, sigma) + vt = self._model((xt, inputs, t))[0] + alpha_dot, sigma_dot = self._sde.pt0_dot( + expand_dims(t, self._output_spec.ndim)) + cvt = alpha_dot * x0 + sigma_dot * x1 + diff = (vt - cvt)**2 + diff = self._apply_sample_mask(diff, sample_mask) + loss = diff.reshape(batch_size, -1).sum(-1) + + if f_neg_energy is not None: + p0 = self._calc_weights(x0, alpha, sigma, inputs, f_neg_energy) + loss = p0 * loss + + return loss + + def sample(self, inputs): + """Sample by integrating the learned velocity field. + + Args: + inputs: Conditional inputs used for generation. + + Returns: + Tensor of generated samples. + """ + batch_size = get_nest_batch_size(inputs) + if isinstance(self._output_spec, alf.BoundedTensorSpec): + noise = self._output_spec.sample((batch_size, )) + else: + noise = self._output_spec.randn((batch_size, )) + noise = noise * self._sde.pt0(torch.tensor(1))[1] + dt = 1 / self._steps + with torch.no_grad(): + for step in range(0, self._steps): + t = torch.full((batch_size, ), 1 - step * dt) + vt = self._model((noise, inputs, t))[0] + noise = noise - vt * dt + + return noise diff --git a/alf/algorithms/diffusion_algorithm_test.py b/alf/algorithms/diffusion_algorithm_test.py new file mode 100644 index 000000000..265fa18b6 --- /dev/null +++ b/alf/algorithms/diffusion_algorithm_test.py @@ -0,0 +1,220 @@ +# Copyright (c) 2025 Horizon Robotics and ALF Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from alf.algorithms.diffusion_algorithm import * +from alf.algorithms.diffusion_model import SimpleMLPNet +from functools import partial + +import math +import torch +from functools import partial +import matplotlib.pyplot as plt + +torch.set_default_device('cuda') + + +def jointplot(x, y, bins=30, figure_size=(8, 8)): + """Draw scatter plot with marginal histograms aligned to the axes + """ + # scatter + marginal histograms + fig = plt.figure(figsize=figure_size) + gs = fig.add_gridspec(4, 4, wspace=0.05, hspace=0.05) + + # Main scatter plot + ax_scatter = fig.add_subplot(gs[1:4, 0:3]) + ax_scatter.plot(x, y, '.', alpha=0.5) + + # X histogram (above scatter, share x-axis) + ax_histx = fig.add_subplot(gs[0, 0:3], sharex=ax_scatter) + ax_histx.hist(x, bins=bins, color="gray") + plt.setp(ax_histx.get_xticklabels(), visible=False) # hide x labels here + + # Y histogram (to the right of scatter, share y-axis) + ax_histy = fig.add_subplot(gs[1:4, 3], sharey=ax_scatter) + ax_histy.hist(y, bins=bins, orientation='horizontal', color="gray") + plt.setp(ax_histy.get_yticklabels(), visible=False) # hide y labels here + + ax_histx.set_ylabel("count") + ax_histy.set_xlabel("count") + + plt.show() + + +def pdist(A, B): + """Distance between each pair of the two collections of inputs. + + Args: + A: (b, n, d) + B: (b, m, d) + Returns: + pairwise distances (b,n,m) + """ + A2 = (A * A).sum(dim=2, keepdim=True) # (b,n,1) + B2 = (B * B).sum(dim=2, keepdim=True) # (b,m,1) + # bmm: (B,n,d) x (B,d,m) -> (B,n,m) + M = torch.bmm(A, B.transpose(1, 2)) + D2 = A2 + B2.transpose(1, 2) - 2.0 * M + return D2.clamp_min_(0.0).sqrt_() + + +def energy_stat(sample_x, sample_y, size): + # https://en.wikipedia.org/wiki/Energy_distance#Testing_for_equal_distributions + # pairwise Euclidean norms + def _pdist(A, B): + # return (((A[:,None,:]-B[None,:,:])**2).sum(-1)).sqrt() + return pdist(A.unsqueeze(0), B.unsqueeze(0)).squeeze(0) + + X = sample_x(size) + Y = sample_y(size) + d_xy = _pdist(X, Y).mean() + d_xx = _pdist(X, X).mean() + d_yy = _pdist(Y, Y).mean() + return (2 * d_xy - d_xx - d_yy) / d_yy + + +class GMM: + name = 'gmm' + mu1 = torch.tensor([-0.5, -0.5]) + std1 = 0.25 + mu2 = torch.tensor([0.5, 0.5]) + std2 = 0.125 + prob1 = 0.3 + + def neg_energy(self, x, _): + logp1 = math.log(self.prob1) - 2 * math.log(self.std1) - 0.5 * (( + (x - self.mu1) / self.std1)**2).sum(-1) + logp2 = math.log(1 - self.prob1) - 2 * math.log(self.std2) - 0.5 * (( + (x - self.mu2) / self.std2)**2).sum(-1) + logp = torch.stack([logp1, logp2], dim=-1) + return logp.logsumexp(dim=-1) + + def sample(self, n): + r = torch.rand(n) + e = torch.randn(n, 2) + mu = torch.where(r.unsqueeze(-1) < self.prob1, self.mu1, self.mu2) + std = torch.where(r < self.prob1, self.std1, self.std2) + return mu + e * std.unsqueeze(-1) + + +def sample_f(n, generator): + inputs = torch.full((n, 1), 1.0) + return generator.sample(inputs) + + +def train(generator, dist, sample_based, batch_size=1024, ema=0.99): + if ema > 0: + averager = torch.optim.swa_utils.AveragedModel( + generator, + multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(ema)) + ema_generator = averager.module + else: + ema_generator = generator + + optimizer = torch.optim.Adam(generator.parameters(), lr=6e-4) + warmup_iters = 16 + warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.001, + end_factor=1.0, + total_iters=warmup_iters, + ) + main_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, + total_iters=1000, + factor=1.0) + lr_schedule = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, main_scheduler], + milestones=[warmup_iters]) + for i in range(1000): + samples = dist.sample(batch_size) if sample_based else None + f_neg_energy = dist.neg_energy if not sample_based else None + loss = generator.calc_loss(torch.full((batch_size, 1), 1.0), samples, + f_neg_energy).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_schedule.step() + if ema > 0: + averager.update_parameters(generator) + + e_stat = energy_stat(partial(sample_f, generator=ema_generator), + dist.sample, 10000) + return ema_generator, e_stat + + +def run_setting(setting): + generator = setting['generator'](input_spec=alf.TensorSpec((1, )), + output_spec=alf.TensorSpec((2, )), + model_ctor=SimpleMLPNet, + sde=setting['sde'], + steps=setting.get('steps', 20)) + name = setting['name'] + sample_based = setting['sample_based'] + ema = setting.get('ema', 0.0) + print('Running', name) + e_stats = [] + repeat = 5 + + while len(e_stats) < repeat: + dist = GMM() + model, e_stat = train(generator, + dist, + sample_based, + batch_size=1024, + ema=ema) + print('train', len(e_stats), dist.name, 'energy_stat', e_stat.item()) + if e_stat.isfinite(): + e_stats.append(e_stat.item()) + + e_stats = torch.tensor(e_stats) + print('energy stat mean:', + e_stats.mean().item(), "std:", + e_stats.std().item()) + + x = sample_f(2000, generator).cpu().numpy() + jointplot(x[:, 0], x[:, 1], bins=50) + plt.savefig(f'{setting["name"]}.png') + return model, e_stats.mean().item(), e_stats.std().item() + + +settings = [ + dict(name='ot_sample_fm_ema', + sde=OTSDE(), + generator=FlowMatching, + sample_based=True, + ema=0.99), + dict(name='rt_sample_fm_ema', + sde=RTSDE(), + generator=FlowMatching, + sample_based=True, + ema=0.99), + dict(name='ot_fm_ema', + sde=OTSDE(), + generator=FlowMatching, + sample_based=False, + ema=0.99), + dict(name='rt_fm_ema', + sde=RTSDE(), + generator=FlowMatching, + sample_based=False, + ema=0.99), +] + +if __name__ == '__main__': + results = [] + for setting in settings: + results.append(run_setting(setting)) + for setting, result in zip(settings, results): + model, e_stat_mean, e_stat_std = result + print(setting['name'], e_stat_mean, e_stat_std) diff --git a/alf/algorithms/diffusion_model.py b/alf/algorithms/diffusion_model.py new file mode 100644 index 000000000..77d45491e --- /dev/null +++ b/alf/algorithms/diffusion_model.py @@ -0,0 +1,385 @@ +# Copyright (c) 2025 Horizon Robotics and ALF Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import alf + + +class Concat(torch.nn.Module): + """Module that concatenates a sequence of tensors along the last axis.""" + + def forward(self, x): + return torch.cat(x, dim=-1) + + +# Timestep embedding used in the DDPM++ and ADM architectures. +class PositionalEmbedding(torch.nn.Module): + """Positional time embedding using deterministic sinusoidal features.""" + + def __init__(self, num_channels, max_positions=10000, endpoint=False): + """Create a positional embedding module. + + Args: + num_channels: Total number of output channels for the embedding. + max_positions: Maximum number of positions used to scale the + frequencies of the sinusoidal embedding. + endpoint: Whether the highest frequency should reach the endpoint + ``1 / max_positions`` exactly. + """ + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange(start=0, + end=self.num_channels // 2, + dtype=torch.float32, + device=x.device) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions)**freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + + +# Timestep embedding used in the NCSN++ architecture. +class FourierEmbedding(torch.nn.Module): + """Random Fourier feature based time embedding.""" + + def __init__(self, num_channels, scale=16): + """Create a Fourier embedding module with random frequencies. + + Args: + num_channels: Total number of output channels for the embedding. + scale: Standard deviation used when sampling the random base + frequencies. + """ + super().__init__() + self.register_buffer('freqs', torch.randn(num_channels // 2) * scale) + + def forward(self, x): + x = x.ger((2 * math.pi * self.freqs).to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + + +class MLPNet(alf.networks.Network): + """Multi-layer perceptron used to predict scores or velocities for SDEs.""" + + def __init__(self, + input_spec, + output_spec, + mean_flow=False, + hidden_dim=256, + time_embedding_type='positional'): + """Construct the MLP used for score or velocity prediction. + + Args: + input_spec: Specification of the conditional input tensor. + output_spec: Specification describing the generated tensor. + mean_flow: Whether the network receives an additional mean-flow + horizon input. + hidden_dim: Feature dimension used throughout the hidden layers and + embeddings. + time_embedding_type: Chooses between ``'positional'`` and + ``'fourier'`` time embeddings. + """ + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec(())) + if mean_flow: + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec( + ())) + (alf.TensorSpec(()), ) + super().__init__(input_tensor_spec) + self._mean_flow = mean_flow + k = 4 if mean_flow else 3 + self._model = torch.nn.Sequential( + Concat(), + torch.nn.Linear(k * hidden_dim, hidden_dim), + torch.nn.GELU(), + torch.nn.Linear(hidden_dim, hidden_dim), + torch.nn.GELU(), + torch.nn.Linear(hidden_dim, hidden_dim), + torch.nn.GELU(), + torch.nn.Linear(hidden_dim, output_spec.numel), + ) + self._time_embedding = (PositionalEmbedding( + num_channels=hidden_dim, endpoint=True) if time_embedding_type + == 'positional' else FourierEmbedding( + num_channels=hidden_dim)) + self._cond_embedding = torch.nn.Linear(in_features=input_spec.numel, + out_features=hidden_dim, + bias=False) + self._x_embedding = torch.nn.Linear(in_features=output_spec.numel, + out_features=hidden_dim, + bias=False) + + def forward(self, inputs, state=()): + x, cond, t = inputs[:3] + h = inputs[3] if self._mean_flow else None + embeddings = [self._x_embedding(x), self._time_embedding(t)] + if h is not None: + embeddings.append(self._time_embedding(h)) + embeddings.append(self._cond_embedding(cond)) + x = self._model(embeddings) + return x, state + + +class DiTBlock(torch.nn.Module): + """Transformer block with adaptive layer normalization conditioning.""" + + def __init__(self, d_model, d_ff, cond_dim, num_heads): + """Initialize the DiT block. + + Args: + d_model: Transformer hidden size. + d_ff: Hidden size of the feed-forward network inside the block. + cond_dim: Dimensionality of the conditioning vector. + num_heads: Number of attention heads. + """ + super().__init__() + self._norm1 = torch.nn.LayerNorm(d_model, elementwise_affine=False) + self._attn = torch.nn.MultiheadAttention(d_model, + num_heads=num_heads, + batch_first=True) + self._norm2 = torch.nn.LayerNorm(d_model, elementwise_affine=False) + self._fc1 = alf.layers.FC(d_model, + d_ff, + activation=torch.nn.functional.silu) + self._fc2 = alf.layers.FC(d_ff, d_model) + self._cond_mlp = torch.nn.Sequential( + # torch.nn.SiLU(), + alf.layers.FC(cond_dim, 6 * d_model, use_bias=True)) + + def forward(self, inputs): + x, cond = inputs + scale1, shift1, gate1, scale2, shift2, gate2 = self._cond_mlp( + cond).unsqueeze(1).chunk(6, dim=-1) + h = self._norm1(x) + h = torch.addcmul(shift1, h, 1 + scale1) + attn_output, _ = self._attn(h, h, h) + x.addcmul(attn_output, gate1) + h = self._norm2(x) + h = torch.addcmul(shift2, h, 1 + scale2) + h = self._fc1(h) + h = self._fc2(h) + return x.addcmul(h, gate2) + + +class DiT(alf.networks.Network): + """Diffusion Transformer architecture for sequence-shaped outputs.""" + + def __init__(self, + input_spec, + output_spec, + d_model=128, + num_heads=4, + num_blocks=2, + time_embedding_dim=32, + mean_flow=False): + """Create a DiT network tailored for ALF tensor specs. + + Args: + input_spec: Specification of conditioning inputs. + output_spec: Specification of the generated tensor shaped as a + sequence. + d_model: Transformer hidden size. + num_heads: Number of attention heads in each block. + num_blocks: Number of stacked transformer blocks. + time_embedding_dim: Dimensionality of sinusoidal time embeddings. + mean_flow: Whether the network receives the extra mean-flow + horizon input. + """ + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec(())) + if mean_flow: + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec( + ())) + (alf.TensorSpec(()), ) + super().__init__(input_tensor_spec) + + self._blocks = torch.nn.ModuleList() + assert output_spec.ndim == 2, "DiT only supports 2D output" + self._in_proj = alf.layers.FC(output_spec.shape[1], d_model) + length = output_spec.shape[0] + self._pe = torch.nn.Parameter(torch.zeros(1, length, d_model)) + self._out_proj = alf.layers.FC(d_model, output_spec.shape[1]) + cond_dim = sum(spec.numel for spec in alf.nest.flatten(input_spec)) + cond_dim += time_embedding_dim * (2 if mean_flow else 1) + self._time_embedding = PositionalEmbedding( + num_channels=time_embedding_dim, endpoint=True) + self._cond_mlp = torch.nn.Sequential( + alf.layers.FC(cond_dim, + d_model, + activation=torch.nn.functional.silu), + alf.layers.FC(d_model, + d_model, + activation=torch.nn.functional.silu), + ) + for _ in range(num_blocks): + self._blocks.append( + DiTBlock(d_model, 4 * d_model, d_model, num_heads)) + + def forward(self, inputs, state=()): + x, cond, t = inputs[:3] + assert x.ndim == 3, "DiT only supports 3D input" + h = inputs[3] if len(inputs) == 4 else None + embeddings = alf.nest.flatten(cond) + embeddings.append(self._time_embedding(t)) + if h is not None: + embeddings.append(self._time_embedding(h)) + cond = torch.cat(embeddings, dim=-1) + cond = self._cond_mlp(cond) + + x = self._in_proj(x) + self._pe + for block in self._blocks: + x = block((x, cond)) + x = self._out_proj(x) + return x, state + + +class AdaLnBlock(torch.nn.Module): + """Residual block with adaptive layer normalization conditioning.""" + + def __init__(self, in_dim, out_dim, hidden_dim, cond_dim): + """Configure the adaptive layer normalization block. + + Args: + in_dim: Size of the input feature dimension. + out_dim: Size of the output feature dimension. + hidden_dim: Hidden dimension for the internal MLP. + cond_dim: Dimensionality of the conditioning vector applied to AdaLN. + """ + super().__init__() + self._norm = torch.nn.LayerNorm(in_dim, elementwise_affine=False) + self._fc1 = alf.layers.FC(in_dim, + hidden_dim, + activation=torch.nn.functional.silu) + self._fc2 = alf.layers.FC(hidden_dim, out_dim) + self._ada = torch.nn.Sequential( + # torch.nn.SiLU(), + alf.layers.FC(cond_dim, 3 * in_dim, use_bias=True)) + + def forward(self, inputs): + x, cond = inputs + h = self._norm(x) + scale, shift, gate = self._ada(cond).chunk(3, dim=-1) + h = h * (1 + scale) + shift + h = self._fc1(h) + h = self._fc2(h) + return x + h * gate + + +class AdaNet(alf.networks.Network): + """Fully-connected network with AdaLN blocks for diffusion modeling.""" + + def __init__(self, + input_spec, + output_spec, + d_model=256, + num_blocks=2, + time_embedding_dim=32, + mean_flow=False): + """Create an AdaNet model for diffusion-based generation. + + Args: + input_spec: Specification of conditioning inputs. + output_spec: Specification of the generated tensor. + d_model: Hidden size of the AdaLN blocks. + num_blocks: Number of stacked AdaLN residual blocks. + time_embedding_dim: Dimensionality of sinusoidal embeddings. + mean_flow: Whether to include the additional mean-flow horizon. + """ + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec(())) + if mean_flow: + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec( + ())) + (alf.TensorSpec(()), ) + super().__init__(input_tensor_spec) + + self._blocks = torch.nn.ModuleList() + self._in_proj = alf.layers.FC(output_spec.numel, d_model) + self._out_proj = alf.layers.FC(d_model, output_spec.numel) + cond_dim = sum(spec.numel for spec in alf.nest.flatten(input_spec)) + cond_dim += time_embedding_dim * (2 if mean_flow else 1) + self._time_embedding = PositionalEmbedding( + num_channels=time_embedding_dim, endpoint=True) + self._cond_mlp = torch.nn.Sequential( + alf.layers.FC(cond_dim, + d_model, + activation=torch.nn.functional.silu), + alf.layers.FC(d_model, + d_model, + activation=torch.nn.functional.silu), + ) + for _ in range(num_blocks): + self._blocks.append( + AdaLnBlock(d_model, d_model, 4 * d_model, d_model)) + + def forward(self, inputs, state=()): + x, cond, t = inputs[:3] + x_shape = x.shape + x = x.reshape(x.shape[0], -1) + h = inputs[3] if len(inputs) == 4 else None + embeddings = alf.nest.flatten(cond) + embeddings.append(self._time_embedding(t)) + if h is not None: + embeddings.append(self._time_embedding(h)) + cond = torch.cat(embeddings, dim=-1) + cond = self._cond_mlp(cond) + + x = self._in_proj(x) + for block in self._blocks: + x = block((x, cond)) + x = self._out_proj(x) + x = x.reshape(*x_shape) + return x, state + + +class SimpleMLPNet(alf.networks.Network): + """Compact MLP baseline for score or velocity prediction.""" + + def __init__(self, input_spec, output_spec, mean_flow=False): + """Construct a simple baseline MLP network. + + Args: + input_spec: Specification of conditioning inputs. + output_spec: Specification of the generated tensor. + mean_flow: Whether to include a mean-flow horizon input. + """ + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec(())) + if mean_flow: + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec( + ())) + (alf.TensorSpec(()), ) + super().__init__(input_tensor_spec) + activation = torch.nn.GELU + + self._model = torch.nn.Sequential( + Concat(), + torch.nn.Linear( + output_spec.numel + input_spec.numel + (2 if mean_flow else 1), + 256), + activation(), + torch.nn.Linear(256, 256), + activation(), + torch.nn.Linear(256, 256), + activation(), + torch.nn.Linear(256, output_spec.numel), + ) + + def forward(self, inputs, state=()): + x, cond, t = inputs[:3] + h = inputs[3] if len(inputs) == 4 else None + embeddings = [x, cond, t.unsqueeze(-1)] + if h is not None: + embeddings.append(h.unsqueeze(-1)) + x = self._model(embeddings) + return x, state From fb71cf2a4d21007755e95e584d435892f042b6e1 Mon Sep 17 00:00:00 2001 From: Wei Xu Date: Wed, 8 Oct 2025 14:29:45 -0700 Subject: [PATCH 2/2] Fix codespell issue --- alf/algorithms/agent.py | 2 +- alf/algorithms/mcts_algorithm.py | 3 +++ alf/algorithms/muzero_representation_learner.py | 3 +++ alf/algorithms/taac_algorithm.py | 2 +- alf/summary/render.py | 2 +- alf/utils/losses.py | 4 ++-- 6 files changed, 11 insertions(+), 5 deletions(-) diff --git a/alf/algorithms/agent.py b/alf/algorithms/agent.py index 804a24ea7..d997c06e6 100644 --- a/alf/algorithms/agent.py +++ b/alf/algorithms/agent.py @@ -506,7 +506,7 @@ def preprocess_experience(self, root_inputs, rollout_info, batch_info): def summarize_rollout(self, experience): """First call ``RLAlgorithm.summarize_rollout()`` to summarize basic - rollout statisics. If the rl algorithm has overridden this function, + rollout statistics. If the rl algorithm has overridden this function, then also call its customized version. """ super(Agent, self).summarize_rollout(experience) diff --git a/alf/algorithms/mcts_algorithm.py b/alf/algorithms/mcts_algorithm.py index 291e7789a..7b5aeccc4 100644 --- a/alf/algorithms/mcts_algorithm.py +++ b/alf/algorithms/mcts_algorithm.py @@ -227,6 +227,7 @@ def _add_node(name: str, properties: dict): @alf.configurable class MCTSAlgorithm(OffPolicyAlgorithm): + # codespell:ignore-begin r"""Monte-Carlo Tree Search algorithm. The code largely follows the pseudocode of @@ -300,6 +301,8 @@ class MCTSAlgorithm(OffPolicyAlgorithm): extend these k' paths are most promising according to the UCB scores. """ + # codespell:ignore-end + def __init__( self, observation_spec, diff --git a/alf/algorithms/muzero_representation_learner.py b/alf/algorithms/muzero_representation_learner.py index f4bb97d45..4d9ce3a50 100644 --- a/alf/algorithms/muzero_representation_learner.py +++ b/alf/algorithms/muzero_representation_learner.py @@ -60,6 +60,7 @@ @alf.configurable class MuzeroRepresentationImpl(OffPolicyAlgorithm): + # codespell:ignore-begin """MuZero-style Representation Learner. MuZero is described in the paper: @@ -85,6 +86,8 @@ class MuzeroRepresentationImpl(OffPolicyAlgorithm): """ + # codespell:ignore-end + def __init__( self, observation_spec, diff --git a/alf/algorithms/taac_algorithm.py b/alf/algorithms/taac_algorithm.py index dc7dbe86f..2c6aa6905 100644 --- a/alf/algorithms/taac_algorithm.py +++ b/alf/algorithms/taac_algorithm.py @@ -230,7 +230,7 @@ class TaacAlgorithmBase(OffPolicyAlgorithm): In a nutsell, for inference TAAC adds a second stage that chooses between a candidate trajectory :math:`\hat{\tau}` output by an SAC actor and the previous trajectory :math:`\tau^-`. For policy evaluation, TAAC uses a compare-through Q - operator for TD backup by re-using state-action sequences that have shared + operator for TD backup by reusing state-action sequences that have shared actions between rollout and training. For policy improvement, the new actor gradient is approximated by multiplying a scaling factor to the :math:`\frac{\partial Q}{\partial a}` term in the original SAC’s actor diff --git a/alf/summary/render.py b/alf/summary/render.py index 5b0653fb4..94e179d26 100644 --- a/alf/summary/render.py +++ b/alf/summary/render.py @@ -266,7 +266,7 @@ def is_rendering_enabled(): def _rendering_wrapper(rendering_func): """A wrapper function to gate the rendering function based on if rendering is enabled, and if yes generate a scoped rendering identifier before - calling the rendering function. It re-uses the scope stack in ``alf.summary.summary_ops.py``. + calling the rendering function. It reuses the scope stack in ``alf.summary.summary_ops.py``. """ @functools.wraps(rendering_func) diff --git a/alf/utils/losses.py b/alf/utils/losses.py index 68521b75e..809f4bc27 100644 --- a/alf/utils/losses.py +++ b/alf/utils/losses.py @@ -129,7 +129,7 @@ def iqn_huber_loss(value: torch.Tensor, is between this and the target. target: the time-major tensor for return, this is used as the target for computing the loss. - next_delta_tau: the sampled increments of the probability for the input + next_delta_tau: the sampled increments of the probability for the input of the quantile function of the target critics. fixed_tau: the fixed increments of probability, for non iqn style quantile regression. @@ -166,7 +166,7 @@ def iqn_huber_loss(value: torch.Tensor, error = loss_fn(diff) if iqn_tau: if diff.ndim - tau_hat.ndim > 1: - # For multidimentional reward: + # For multidimensional reward: # diff is of shape [T or T-1, B, reward_dim, n_quantiles, n_quantiles] # while tau_hat and next_delta_tau have shape [T or T-1, B, n_quantiles] tau_hat = tau_hat.unsqueeze(-2)