From f668631753897b75cedfc6982ef52b703d23c7e3 Mon Sep 17 00:00:00 2001
From: Castorp <50649074+ShinDongWoon@users.noreply.github.com>
Date: Wed, 24 Sep 2025 21:28:43 +0900
Subject: [PATCH] Restore 3D TimesNet flow and adjust embedding inputs

---
 src/timesnet_forecast/models/timesnet.py | 62 ++++++++++++++++--------
 1 file changed, 42 insertions(+), 20 deletions(-)

diff --git a/src/timesnet_forecast/models/timesnet.py b/src/timesnet_forecast/models/timesnet.py
index 5b60f30..0a80f14 100644
--- a/src/timesnet_forecast/models/timesnet.py
+++ b/src/timesnet_forecast/models/timesnet.py
@@ -365,6 +365,11 @@ def __init__(
     def forward(
         self, x: torch.Tensor, x_mark: torch.Tensor | None = None
     ) -> torch.Tensor:
+        if x.ndim != 3:
+            raise ValueError("DataEmbedding expects input shaped [B, L, C]")
+        if x_mark is not None and x_mark.ndim != 3:
+            raise ValueError("x_mark must share dimensions [B, L, T]")
+
         value = self.value_embedding(x)
         pos = self.position_embedding(x)
         temporal = (
@@ -595,7 +600,7 @@ def _ensure_embedding(
         id_feature_dim = self.id_embed_dim
 
         total_per_series = 1 + static_out_dim + id_feature_dim
-        total_in = c_in * total_per_series
+        total_in_features = c_in * total_per_series
 
         static_feature_dim = static_out_dim + id_feature_dim
         if static_feature_dim > 0:
@@ -627,12 +632,13 @@ def _ensure_embedding(
             self.embedding is None
             or self.output_proj is None
             or self.sigma_proj is None
-            or self._augmented_in_features != total_in
+            or self._per_series_feature_dim != total_per_series
+            or self._augmented_in_features != total_in_features
         )
         if rebuild_embedding:
             time_arg = None if time_dim == 0 else time_dim
             self.embedding = DataEmbedding(
-                c_in=total_in,
+                c_in=total_in_features,
                 d_model=self.d_model,
                 dropout=self.dropout,
                 time_features=time_arg,
@@ -646,7 +652,7 @@ def _ensure_embedding(
                 device=x.device, dtype=x.dtype
             )
             self.output_dim = c_in
-            self._augmented_in_features = total_in
+            self._augmented_in_features = total_in_features
             self._per_series_feature_dim = total_per_series
         else:
             if self.output_dim != c_in:
@@ -660,6 +666,7 @@ def _ensure_embedding(
             self.output_proj = self.output_proj.to(device=x.device, dtype=x.dtype)
             self.sigma_proj = self.sigma_proj.to(device=x.device, dtype=x.dtype)
             self._per_series_feature_dim = total_per_series
+            self._augmented_in_features = total_in_features
 
         if total_per_series <= 1:
             # A per-series LayerNorm with a single feature would zero-out the
@@ -796,7 +803,7 @@ def forward(
             static_rep = static_concat.unsqueeze(1).expand(-1, time_len, -1, -1)
             per_series_features.append(static_rep)
 
-        combined = (
+        combined_per_series = (
             torch.cat(per_series_features, dim=-1)
             if len(per_series_features) > 1
             else per_series_features[0]
@@ -804,20 +811,33 @@ def forward(
         )
         assert (
             self.pre_embedding_norm is not None
         ), "pre_embedding_norm should have been initialised by _ensure_embedding"
-        combined = self.pre_embedding_norm(combined)
-        combined = combined.reshape(B, time_len, -1)
+        combined_per_series = self.pre_embedding_norm(combined_per_series)
+        combined = combined_per_series.reshape(B, time_len, -1)
+        if (
+            self._augmented_in_features is not None
+            and combined.size(-1) != self._augmented_in_features
+        ):
+            raise RuntimeError(
+                "Flattened feature dimension mismatch with configured embedding"
+            )
         combined = self.pre_embedding_dropout(combined)
-        features = self.embedding(combined, mark_slice)  # type: ignore[arg-type]
+        features = self.embedding(combined, mark_slice)
+        if features.ndim != 3:
+            raise RuntimeError(
+                "Embedding output must have shape [B, L, d_model]"
+            )
         if features.shape[-1] != self.d_model:
             raise RuntimeError(
                 "Embedding output dimension mismatch with configured d_model"
             )
-        d_model = features.size(-1)
+        if features.size(1) != self.input_len:
+            raise RuntimeError("Embedded sequence length mismatch with input_len")
         feat_t = features.permute(0, 2, 1).contiguous()
-        feat_flat = feat_t.reshape(B * d_model, self.input_len)
+        feat_flat = feat_t.reshape(B * self.d_model, self.input_len)
         extended = self.predict_linear(feat_flat)
-        extended = extended.view(B, d_model, self.input_len + self.pred_len)
+        extended = extended.view(B, self.d_model, self.input_len + self.pred_len)
         features = extended.permute(0, 2, 1).contiguous()
+        total_len = features.size(1)
         if self.debug_memory and features.is_cuda and torch.cuda.is_available():
             mem_bytes = torch.cuda.memory_allocated(features.device)
@@ -831,26 +851,28 @@ def forward(
         for block in self.blocks:
             object.__setattr__(block, "period_selector", self.period_selector)
 
-        preview_periods, _ = self.period_selector(features)
+        seq_features = features
+
+        preview_periods, _ = self.period_selector(seq_features)
         if preview_periods.numel() == 0:
             mu = enc_x.new_zeros(B, self._out_steps, N)
             sigma = self._sigma_from_ref(mu)
             return mu, sigma
 
         for block in self.blocks:
-            if features.shape[-1] != self.d_model:
+            if seq_features.shape[-1] != self.d_model:
                 raise RuntimeError(
                     "Residual input to TimesBlock must maintain d_model channels"
                 )
             if self.use_checkpoint:
-                updated = checkpoint(block, features, use_reentrant=False)
+                updated = checkpoint(block, seq_features, use_reentrant=False)
             else:
-                updated = block(features)
-            delta = updated - features
-            features = features + self.residual_dropout(delta)
-            features = self.layer_norm(features)
-        target_features = features[:, -target_steps:, :].contiguous()
-        mu = self.output_proj(target_features)  # type: ignore[operator]
+                updated = block(seq_features)
+            delta = updated - seq_features
+            seq_features = seq_features + self.residual_dropout(delta)
+            seq_features = self.layer_norm(seq_features)
+        target_features = seq_features[:, -target_steps:, :].contiguous()
+        mu = self.output_proj(target_features)
         assert self.sigma_proj is not None  # for type checkers
         floor = self._sigma_from_ref(mu)
         sigma_head = self.sigma_proj(target_features)
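
Illustrative sketch (not part of the patch): a minimal, self-contained view of the 3D flow this change restores, assuming DataEmbedding maps a flattened [B, L, C*F] input to [B, L, d_model] and predict_linear extends the time axis from input_len to input_len + pred_len. The nn.Linear stand-ins and the sizes B, L, C, F below are illustrative assumptions, not the model's actual definitions.

import torch
import torch.nn as nn

B, L, C, F = 4, 96, 7, 3        # batch, input_len, number of series, per-series feature dim
pred_len, d_model = 24, 64

per_series = torch.randn(B, L, C, F)            # normalised per-series features [B, L, C, F]
combined = per_series.reshape(B, L, C * F)      # flatten series/feature dims for the embedding

embedding = nn.Linear(C * F, d_model)           # stand-in for DataEmbedding
predict_linear = nn.Linear(L, L + pred_len)     # linear extension along the time axis

features = embedding(combined)                              # [B, L, d_model]
feat_flat = features.permute(0, 2, 1).reshape(B * d_model, L)
extended = predict_linear(feat_flat).view(B, d_model, L + pred_len)
features = extended.permute(0, 2, 1).contiguous()           # [B, L + pred_len, d_model]
assert features.shape == (B, L + pred_len, d_model)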