CudaLSTM
- class hy2dl.modelzoo.cudalstm.CudaLSTM(cfg: Config)
Bases: Module
LSTM model.
This class implements an LSTM layer. If config.model == “cudalstm”, the LSTM layer is followed by a linear head, which maps the hidden states produced by the LSTM into predictions. Otherwise, it is assumed the model is being used as part of a larger architecture, and only the hidden states are returned.
The LSTM layer can operate either in a standard mode (hindcast only) or forecast mode. In forecast mode, the LSTM cell rolls out continuously through both the hindcast and forecast periods using specific embedding layers for each case.
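The forecast-mode rollout described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the hy2dl implementation: the layer names (`hindcast_embedding`, `forecast_embedding`), the use of plain linear layers as embeddings, and all sizes are hypothetical. It only shows the idea of embedding each period separately and letting a single LSTM run over the joined sequence.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, hindcast_len, forecast_len = 4, 30, 7
n_hindcast_features, n_forecast_features = 5, 3
embedding_size, hidden_size = 8, 16

# One embedding layer per period (assumed to be simple linear maps here),
# so that both periods share the same input width for the LSTM.
hindcast_embedding = nn.Linear(n_hindcast_features, embedding_size)
forecast_embedding = nn.Linear(n_forecast_features, embedding_size)
lstm = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, batch_first=True)

x_hindcast = torch.randn(batch_size, hindcast_len, n_hindcast_features)
x_forecast = torch.randn(batch_size, forecast_len, n_forecast_features)

# Embed each period, then join along the sequence dimension (dim=1).
x = torch.cat(
    [hindcast_embedding(x_hindcast), forecast_embedding(x_forecast)], dim=1
)

# A single continuous LSTM pass over hindcast + forecast time steps.
hidden_states, _ = lstm(x)
```

Because the LSTM state is carried across the boundary, the forecast steps are conditioned on everything seen during the hindcast period.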
- Parameters:
cfg (Config) – Configuration object containing model hyperparameters and settings.
- forward(sample: dict[str, Any]) → dict[str, Tensor]
Forward pass of the LSTM network.
Processes hindcast features, and optionally concatenates forecast features along the sequence dimension, before passing them through the LSTM and linear head.
- Parameters:
sample (dict[str, Any]) – Dictionary with the different variables that will be used in the forward pass. See hy2dl.datasetzoo.basedataset.Basedataset.__getitems__() for details.
- Returns:
y_hat: model predictions, shape (B, N, T)
hs: hidden states of LSTM cell, shape (B, N, cfg.hidden_size)
- Return type:
dict[str, torch.Tensor]
Notes
Shape abbreviations used:
- B: batch size
- N: length of the target sequence, based on the predict_last_n configuration argument
- T: number of target variables
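The documented output shapes can be checked with a small stand-alone sketch. It assumes that the last N hidden states are kept and passed through a linear head; how hy2dl performs this step internally is not stated here, so the slicing, the layer names, and all sizes below are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes: B = batch size, N = predict_last_n, T = number of targets.
B, seq_len, n_inputs, hidden_size = 4, 37, 8, 16
N, T = 7, 2

lstm = nn.LSTM(input_size=n_inputs, hidden_size=hidden_size, batch_first=True)
head = nn.Linear(hidden_size, T)  # linear head mapping hidden states to targets

x = torch.randn(B, seq_len, n_inputs)
hs_all, _ = lstm(x)        # (B, seq_len, hidden_size)

# Assumption: only the last N time steps are used for prediction.
hs = hs_all[:, -N:, :]     # (B, N, hidden_size)
y_hat = head(hs)           # (B, N, T)

out = {"y_hat": y_hat, "hs": hs}
```

The returned dictionary mirrors the documented keys: `y_hat` with shape (B, N, T) and `hs` with shape (B, N, hidden_size).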