62 changes: 42 additions & 20 deletions src/timesnet_forecast/models/timesnet.py
@@ -365,6 +365,11 @@ def __init__(
def forward(
self, x: torch.Tensor, x_mark: torch.Tensor | None = None
) -> torch.Tensor:
if x.ndim != 3:
raise ValueError("DataEmbedding expects input shaped [B, L, C]")
if x_mark is not None and x_mark.ndim != 3:
raise ValueError("x_mark must share dimensions [B, L, T]")

value = self.value_embedding(x)
pos = self.position_embedding(x)
temporal = (
@@ -595,7 +600,7 @@ def _ensure_embedding(
id_feature_dim = self.id_embed_dim

total_per_series = 1 + static_out_dim + id_feature_dim
total_in = c_in * total_per_series
total_in_features = c_in * total_per_series

static_feature_dim = static_out_dim + id_feature_dim
if static_feature_dim > 0:
@@ -627,12 +632,13 @@ def _ensure_embedding(
self.embedding is None
or self.output_proj is None
or self.sigma_proj is None
or self._augmented_in_features != total_in
or self._per_series_feature_dim != total_per_series
or self._augmented_in_features != total_in_features
)
if rebuild_embedding:
time_arg = None if time_dim == 0 else time_dim
self.embedding = DataEmbedding(
c_in=total_in,
c_in=total_in_features,
d_model=self.d_model,
dropout=self.dropout,
time_features=time_arg,
device=x.device, dtype=x.dtype
)
self.output_dim = c_in
self._augmented_in_features = total_in
self._augmented_in_features = total_in_features
self._per_series_feature_dim = total_per_series
else:
if self.output_dim != c_in:
self.output_proj = self.output_proj.to(device=x.device, dtype=x.dtype)
self.sigma_proj = self.sigma_proj.to(device=x.device, dtype=x.dtype)
self._per_series_feature_dim = total_per_series
self._augmented_in_features = total_in_features

if total_per_series <= 1:
# A per-series LayerNorm with a single feature would zero-out the
@@ -796,28 +803,41 @@ def forward(
static_rep = static_concat.unsqueeze(1).expand(-1, time_len, -1, -1)
per_series_features.append(static_rep)

combined = (
combined_per_series = (
torch.cat(per_series_features, dim=-1)
if len(per_series_features) > 1
else per_series_features[0]
)
assert (
self.pre_embedding_norm is not None
), "pre_embedding_norm should have been initialised by _ensure_embedding"
combined = self.pre_embedding_norm(combined)
combined = combined.reshape(B, time_len, -1)
combined_per_series = self.pre_embedding_norm(combined_per_series)
combined = combined_per_series.reshape(B, time_len, -1)
if (
self._augmented_in_features is not None
and combined.size(-1) != self._augmented_in_features
):
raise RuntimeError(
"Flattened feature dimension mismatch with configured embedding"
)
combined = self.pre_embedding_dropout(combined)
features = self.embedding(combined, mark_slice) # type: ignore[arg-type]
features = self.embedding(combined, mark_slice)
if features.ndim != 3:
raise RuntimeError(
"Embedding output must have shape [B, L, d_model]"
)
if features.shape[-1] != self.d_model:
raise RuntimeError(
"Embedding output dimension mismatch with configured d_model"
)
d_model = features.size(-1)
if features.size(1) != self.input_len:
raise RuntimeError("Embedded sequence length mismatch with input_len")
feat_t = features.permute(0, 2, 1).contiguous()
feat_flat = feat_t.reshape(B * d_model, self.input_len)
feat_flat = feat_t.reshape(B * self.d_model, self.input_len)
extended = self.predict_linear(feat_flat)
extended = extended.view(B, d_model, self.input_len + self.pred_len)
extended = extended.view(B, self.d_model, self.input_len + self.pred_len)
features = extended.permute(0, 2, 1).contiguous()
total_len = features.size(1)

if self.debug_memory and features.is_cuda and torch.cuda.is_available():
mem_bytes = torch.cuda.memory_allocated(features.device)
for block in self.blocks:
object.__setattr__(block, "period_selector", self.period_selector)

preview_periods, _ = self.period_selector(features)
seq_features = features

preview_periods, _ = self.period_selector(seq_features)
if preview_periods.numel() == 0:
mu = enc_x.new_zeros(B, self._out_steps, N)
sigma = self._sigma_from_ref(mu)
return mu, sigma

for block in self.blocks:
if features.shape[-1] != self.d_model:
if seq_features.shape[-1] != self.d_model:
raise RuntimeError(
"Residual input to TimesBlock must maintain d_model channels"
)
if self.use_checkpoint:
updated = checkpoint(block, features, use_reentrant=False)
updated = checkpoint(block, seq_features, use_reentrant=False)
else:
updated = block(features)
delta = updated - features
features = features + self.residual_dropout(delta)
features = self.layer_norm(features)
target_features = features[:, -target_steps:, :].contiguous()
mu = self.output_proj(target_features) # type: ignore[operator]
updated = block(seq_features)
delta = updated - seq_features
seq_features = seq_features + self.residual_dropout(delta)
seq_features = self.layer_norm(seq_features)
target_features = seq_features[:, -target_steps:, :].contiguous()
mu = self.output_proj(target_features)
assert self.sigma_proj is not None # for type checkers
floor = self._sigma_from_ref(mu)
sigma_head = self.sigma_proj(target_features)
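The guard in `_ensure_embedding` rebuilds the lazily created embedding whenever the flattened input width or the per-series feature count no longer matches what the current modules were built for; the diff adds `_per_series_feature_dim` to that check and keeps `_augmented_in_features` in sync on the non-rebuild path. A minimal sketch of the same pattern, with illustrative names (`LazyEmbeddingHolder`, `ensure`) that are not part of the repository:

```python
import torch.nn as nn


class LazyEmbeddingHolder(nn.Module):
    """Illustrative stand-in for the lazy rebuild logic in _ensure_embedding."""

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.embedding: nn.Module | None = None
        self._augmented_in_features: int | None = None
        self._per_series_feature_dim: int | None = None

    def ensure(self, c_in: int, per_series: int) -> None:
        # Rebuild when the flattened width or the per-series feature count changes.
        total_in_features = c_in * per_series
        rebuild = (
            self.embedding is None
            or self._per_series_feature_dim != per_series
            or self._augmented_in_features != total_in_features
        )
        if rebuild:
            self.embedding = nn.Linear(total_in_features, self.d_model)
        # Keep the cached metadata in sync even when nothing was rebuilt.
        self._augmented_in_features = total_in_features
        self._per_series_feature_dim = per_series
```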
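The reshape around `predict_linear` extends the embedded sequence from `input_len` to `input_len + pred_len` by applying a linear layer along the time axis; the diff switches the flattened views to `self.d_model` so they agree with the configured model width. A minimal, self-contained sketch of that trick, with illustrative shapes:

```python
import torch
import torch.nn as nn

B, input_len, pred_len, d_model = 4, 96, 24, 32
predict_linear = nn.Linear(input_len, input_len + pred_len)

features = torch.randn(B, input_len, d_model)        # [B, L, d_model]
feat_t = features.permute(0, 2, 1).contiguous()      # [B, d_model, L]
feat_flat = feat_t.reshape(B * d_model, input_len)   # merge batch and channel dims
extended = predict_linear(feat_flat)                  # [B * d_model, L + H]
extended = extended.view(B, d_model, input_len + pred_len)
features = extended.permute(0, 2, 1).contiguous()     # [B, L + H, d_model]
assert features.shape == (B, input_len + pred_len, d_model)
```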
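The block loop applies each `TimesBlock` as a residual update: the block output is turned into a delta, passed through dropout, added back, and layer-normalised, optionally under gradient checkpointing. A minimal sketch of the same update rule, assuming `block` is any shape-preserving module rather than the repository's `TimesBlock`:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

d_model = 32
block = nn.Sequential(
    nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model)
)
residual_dropout = nn.Dropout(0.1)
layer_norm = nn.LayerNorm(d_model)
use_checkpoint = False  # trade recompute for activation memory when True

x = torch.randn(4, 120, d_model)                       # [B, L, d_model]
updated = (
    checkpoint(block, x, use_reentrant=False) if use_checkpoint else block(x)
)
delta = updated - x                                     # residual delta
x = layer_norm(x + residual_dropout(delta))             # dropped-out residual + norm
```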