Loss functions
- class hy2dl.training.loss.BaseLoss(cfg: Config)
Bases: Module
Abstract base class to ensure all losses use the same format in the forward pass.
- forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
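The contract above (subclasses override forward, callers invoke the module instance) can be illustrated with a minimal sketch. It uses a hypothetical stand-in for BaseLoss (the real class takes a Config, which is omitted here), a hypothetical MAELoss subclass, and the assumed dictionary keys y_hat and y_obs. A forward hook records each invocation to show that calling the instance runs hooks, while calling forward() directly skips them.

```python
from typing import Any

import torch
from torch import Tensor, nn


class BaseLoss(nn.Module):
    """Hypothetical stand-in for hy2dl.training.loss.BaseLoss (Config argument omitted)."""

    def forward(self, pred: dict[str, Tensor], sample: dict[str, Any]) -> Tensor:
        raise NotImplementedError


class MAELoss(BaseLoss):
    """Hypothetical subclass: mean absolute error with the shared forward signature."""

    def forward(self, pred: dict[str, Tensor], sample: dict[str, Any]) -> Tensor:
        # "y_hat" and "y_obs" are assumed key names, not confirmed hy2dl keys
        return (pred["y_hat"] - sample["y_obs"]).abs().mean()


calls = []
loss_fn = MAELoss()
# The hook appends the loss value every time the module is invoked via __call__
loss_fn.register_forward_hook(lambda module, args, output: calls.append(output.item()))

pred = {"y_hat": torch.tensor([1.0, 2.0, 3.0])}
sample = {"y_obs": torch.tensor([1.0, 2.5, 2.0])}

via_call = loss_fn(pred, sample)             # runs forward() and the registered hook
via_forward = loss_fn.forward(pred, sample)  # runs forward() only; the hook is skipped
print(via_call.item(), via_forward.item(), len(calls))  # prints: 0.5 0.5 1
```

Both paths return the same value; only the instance call triggers the hook, which is why training code should call loss_fn(pred, sample).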
- class hy2dl.training.loss.NLL(cfg: Config)
Bases: BaseLoss
Negative log-likelihood.
Calculates the negative log-likelihood (i.e. the negative log probability of the observed targets y_obs under the predicted mixture distribution), applying an optional weight to each target variable.
- forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor
- Parameters:
pred (dict[str, torch.Tensor]) – Model predictions
sample (dict[str, Any]) – Dictionary containing observed targets and other sample information
- Returns:
Value of the negative log-likelihood.
- Return type:
torch.Tensor
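A minimal sketch of a per-target weighted NLL, assuming the model outputs the parameters of a Gaussian mixture under the hypothetical keys pi, mu and sigma, and that observations sit under y_obs. The mixture family actually used by NLL is not stated here, so the Gaussian choice is an assumption. The sketch evaluates the log-likelihood with torch.distributions.MixtureSameFamily, averages over the batch, and then takes a weighted sum over targets; that reduction is also an assumption.

```python
import math
from typing import Any, Optional

import torch
from torch import Tensor
from torch.distributions import Categorical, MixtureSameFamily, Normal


def weighted_mixture_nll(
    pred: dict[str, Tensor],
    sample: dict[str, Any],
    target_weights: Optional[Tensor] = None,
) -> Tensor:
    """Hypothetical sketch of a weighted NLL under a Gaussian mixture.

    Assumed shapes: pred["pi"], pred["mu"], pred["sigma"] -> [batch, n_targets, n_components];
    sample["y_obs"] -> [batch, n_targets]. All key names are hypothetical.
    """
    mixture = MixtureSameFamily(
        Categorical(probs=pred["pi"]),                # mixture weights per (sample, target)
        Normal(loc=pred["mu"], scale=pred["sigma"]),  # one Gaussian per component
    )
    # log p(y_obs) for every sample and target -> [batch, n_targets]
    log_prob = mixture.log_prob(sample["y_obs"])
    nll_per_target = -log_prob.mean(dim=0)            # [n_targets]
    if target_weights is None:
        target_weights = torch.ones_like(nll_per_target)
    return (target_weights * nll_per_target).sum()


# Usage: three identical standard-normal components, observations at 0
batch, n_targets, n_comp = 4, 2, 3
pred = {
    "pi": torch.full((batch, n_targets, n_comp), 1.0 / n_comp),
    "mu": torch.zeros(batch, n_targets, n_comp),
    "sigma": torch.ones(batch, n_targets, n_comp),
}
sample = {"y_obs": torch.zeros(batch, n_targets)}
loss = weighted_mixture_nll(pred, sample, target_weights=torch.tensor([1.0, 0.0]))
expected = 0.5 * math.log(2 * math.pi)
print(loss.item(), expected)
```

With every component equal to a standard normal and y_obs = 0, the NLL of each target reduces to 0.5 log(2π), and a weight of 0 removes the second target from the total.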
- class hy2dl.training.loss.NSEBasinAveraged(cfg: Config)
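Bases: BaseLoss
Basin-averaged Nash–Sutcliffe Efficiency.
Loss function in which the squared errors are weighted according to the standard deviation of the observed target in each basin, so that basins with high variability do not dominate the loss. A description of this function is available at [1].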
References
- forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor
- Parameters:
pred (dict[str, torch.Tensor]) – Model predictions
sample (dict[str, Any]) – Dictionary containing observed targets and other sample information
- Returns:
Value of the basin-averaged NSE
- Return type:
torch.Tensor
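A minimal sketch of a basin-averaged NSE loss, assuming the NSE* formulation in which each squared error is divided by (s_b + ε)², where s_b is the standard deviation of the observed target in the sample's basin and ε is a small constant (0.1 here, an assumed default). The keys y_hat, y_obs and basin_std are hypothetical, and missing observations are assumed to be encoded as NaN.

```python
from typing import Any

import torch
from torch import Tensor


def basin_averaged_nse(
    pred: dict[str, Tensor], sample: dict[str, Any], eps: float = 0.1
) -> Tensor:
    """Hypothetical NSE*-style sketch (assumed formulation and key names).

    pred["y_hat"], sample["y_obs"]: [batch, seq_len]
    sample["basin_std"]: [batch], std of the observed target in each sample's basin
    """
    y_hat, y_obs = pred["y_hat"], sample["y_obs"]
    # Per-sample weight 1 / (std + eps)^2, broadcast over the time dimension
    weights = 1.0 / (sample["basin_std"].unsqueeze(-1) + eps) ** 2
    mask = ~torch.isnan(y_obs)  # True where an observation exists
    # Replace NaNs before subtracting so they cannot leak into gradients
    sq_err = (y_hat - torch.nan_to_num(y_obs)) ** 2
    # Keep only valid entries, then average them
    return (weights * sq_err)[mask].mean()


# Usage: every valid prediction is off by exactly 1
y_obs = torch.tensor([[1.0, 2.0, float("nan")], [0.0, 0.0, 0.0]])
pred = {"y_hat": torch.nan_to_num(y_obs) + 1.0}
sample = {"y_obs": y_obs, "basin_std": torch.tensor([0.9, 3.9])}
loss = basin_averaged_nse(pred, sample)
print(loss.item())  # ≈ 0.4375
```

The second basin has a larger standard deviation, so its errors of the same size contribute only 1/16 as much to the loss as those of the first basin.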
- class hy2dl.training.loss.WeightedMSE(cfg: Config)
Bases: BaseLoss
Weighted Mean Squared Error.
Calculates the MSE between simulated and observed targets, applying an optional weight to each target variable.
- forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor
- Parameters:
pred (dict[str, torch.Tensor]) – Model predictions
sample (dict[str, Any]) – Dictionary containing observed targets and other sample information
- Returns:
Value of the weighted MSE.
- Return type:
torch.Tensor
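A minimal sketch of the weighted MSE, assuming predictions and observations of shape [batch, seq_len, n_targets] under the hypothetical keys y_hat and y_obs, NaN for missing observations, and a weighted sum of per-target MSEs with weights defaulting to one. The exact reduction used by WeightedMSE may differ.

```python
from typing import Any, Optional

import torch
from torch import Tensor


def weighted_mse(
    pred: dict[str, Tensor],
    sample: dict[str, Any],
    target_weights: Optional[Tensor] = None,
) -> Tensor:
    """Hypothetical sketch: per-target MSE combined as a weighted sum.

    pred["y_hat"], sample["y_obs"]: [batch, seq_len, n_targets] (assumed keys and shapes)
    """
    y_hat, y_obs = pred["y_hat"], sample["y_obs"]
    mask = ~torch.isnan(y_obs)
    # Zero out the squared error wherever the observation is missing
    sq_err = (y_hat - torch.nan_to_num(y_obs)) ** 2 * mask
    # Mean over the valid entries of each target -> [n_targets]
    mse_per_target = sq_err.sum(dim=(0, 1)) / mask.sum(dim=(0, 1)).clamp(min=1)
    if target_weights is None:
        target_weights = torch.ones_like(mse_per_target)
    return (target_weights * mse_per_target).sum()


# Usage: target 0 is off by 1 everywhere, target 1 is off by 2 everywhere
y_obs = torch.zeros(2, 3, 2)
y_hat = torch.stack([torch.ones(2, 3), 2 * torch.ones(2, 3)], dim=-1)
loss = weighted_mse({"y_hat": y_hat}, {"y_obs": y_obs}, torch.tensor([1.0, 0.5]))
print(loss.item())  # 3.0
```

Computing the MSE per target before weighting keeps each weight interpretable as the relative importance of one target variable, independent of how many valid observations that target has.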