LSTM mixture density networks
- class hy2dl.modelzoo.lstmmdn.LSTMMDN(cfg: Config)
Bases: Module
LSTM with Mixture Density Network (MDN) head layer.
This class implements an LSTM layer followed by an MDN head, which maps the hidden states produced by the LSTM into the parameters of a mixture distribution.
- Parameters:
cfg (Config) – Configuration object containing model hyperparameters and settings.
Notation
----------
B (-)
L (-)
N (-)
K (-)
T (-)
S (-)
Q (-)
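As a rough illustration of what an MDN head does, the sketch below maps hidden states to mixture weights and component parameters. It is not the hy2dl implementation: it uses NumPy instead of PyTorch, assumes Gaussian components, and the function name `mdn_head`, the keys `loc` and `scale`, and the hidden size `H` are hypothetical. It also assumes that B, N, K and T stand for batch size, predicted time steps, mixture components and targets.

```python
import numpy as np

def mdn_head(hidden, w_logit, w_loc, w_scale):
    """Map hidden states [B, N, H] to mixture parameters of shape [B, N, K, T].

    Hypothetical helper: each weight tensor has shape [H, K, T].
    """
    # Linear projections: [B, N, H] x [H, K, T] -> [B, N, K, T]
    logits = np.einsum("bnh,hkt->bnkt", hidden, w_logit)
    loc = np.einsum("bnh,hkt->bnkt", hidden, w_loc)
    raw_scale = np.einsum("bnh,hkt->bnkt", hidden, w_scale)

    # Softmax over the component axis K so the weights sum to one
    logits = logits - logits.max(axis=2, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=2, keepdims=True)

    # Softplus keeps the scale strictly positive
    scale = np.logaddexp(0.0, raw_scale) + 1e-6

    return {"params": {"loc": loc, "scale": scale}, "weights": weights}

# Toy dimensions (assumed values, for illustration only)
rng = np.random.default_rng(0)
B, N, H, K, T = 2, 3, 8, 4, 1
hidden = rng.normal(size=(B, N, H))
out = mdn_head(
    hidden,
    rng.normal(size=(H, K, T)),
    rng.normal(size=(H, K, T)),
    rng.normal(size=(H, K, T)),
)
```

The softmax and softplus are common choices for constraining mixture weights and scales; the actual class may use different activations or a different component family.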
- forward(sample)
Forward pass of LSTM-MDN
Processes hindcast features, and optionally concatenates forecast features along the sequence dimension, before passing them through the LSTM. The LSTM output is then mapped by the MDN head to the parameters and weights of the mixture distribution of predictions.
- Parameters:
sample (dict[str, Any]) – Dictionary with the different variables that will be used in the forward pass. See hy2dl.datasetzoo.basedataset.Basedataset.__getitems__() for details.
- Returns:
Dictionary containing:
- 'y_hat': expected value of the mixture distribution, shape [B, N, T]
- 'params': dict of distribution parameters, shape [B, N, K, T]
- 'weights': mixture weights, shape [B, N, K, T]
- Return type:
dict
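The relation between 'y_hat' and the other two entries can be sketched as follows. This is a minimal NumPy example, not the library's code: it assumes components whose mean is given directly by a location parameter (as for a Gaussian), and the function name `mixture_mean` and the variable `loc` are hypothetical. Under that assumption, the expected value of the mixture is the weighted sum of the component means over the K axis.

```python
import numpy as np

def mixture_mean(loc, weights):
    """Expected value of a mixture: sum over K of weights * loc.

    Both inputs have shape [B, N, K, T]; the result has shape [B, N, T].
    """
    return (weights * loc).sum(axis=2)

# One sample, one time step, two components, one target
loc = np.array([0.0, 10.0]).reshape(1, 1, 2, 1)
weights = np.array([0.25, 0.75]).reshape(1, 1, 2, 1)

y_hat = mixture_mean(loc, weights)  # 0.25 * 0 + 0.75 * 10 = 7.5
```

For component families whose mean is not equal to a single location parameter, the per-component mean would have to be computed first and then weighted in the same way.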