Loss functions

class hy2dl.training.loss.BaseLoss(cfg: Config)

Bases: Module

Abstract base class that ensures all losses share the same forward-pass signature.

forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards (e.g. loss_fn(pred, sample)) rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.
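The following minimal sketch illustrates the shared forward(pred, sample) format and the Note about hooks. It does not import hy2dl: BaseLossSketch is a plain torch.nn.Module standing in for BaseLoss, and MAELoss as well as the dictionary keys "y_hat" and "y_obs" are hypothetical names chosen for illustration.

```python
import torch
from torch import nn


class BaseLossSketch(nn.Module):
    """Hypothetical stand-in for BaseLoss: fixes the forward(pred, sample) -> Tensor format."""

    def forward(self, pred: dict[str, torch.Tensor], sample: dict[str, object]) -> torch.Tensor:
        raise NotImplementedError("Subclasses must implement forward().")


class MAELoss(BaseLossSketch):
    """Hypothetical example loss: mean absolute error."""

    def forward(self, pred: dict[str, torch.Tensor], sample: dict[str, object]) -> torch.Tensor:
        y_hat = pred["y_hat"]    # hypothetical key holding the model prediction
        y_obs = sample["y_obs"]  # hypothetical key holding the observed target
        return (y_hat - y_obs).abs().mean()


loss_fn = MAELoss()
hook_calls = []
# Forward hooks fire only when the module instance itself is called.
loss_fn.register_forward_hook(lambda module, args, output: hook_calls.append(output.item()))

pred = {"y_hat": torch.tensor([1.0, 2.0, 3.0])}
sample = {"y_obs": torch.tensor([1.0, 1.0, 1.0])}

via_instance = loss_fn(pred, sample)         # runs the registered hook
via_forward = loss_fn.forward(pred, sample)  # bypasses the hook
```

Both calls return the same value, but only the call through the instance appends to hook_calls.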

class hy2dl.training.loss.NLL(cfg: Config)

Bases: BaseLoss

Negative log-likelihood.

Calculates the negative log-likelihood (i.e. the negative log probability of y_obs under the predicted mixture distribution), applying an optional weight to each target variable.
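Written out under two assumptions that this page does not state (Gaussian mixture components, and weights applied as a weighted sum of per-target mean NLLs), the loss over N samples, T target variables and K mixture components is:

```latex
\mathcal{L}_{\mathrm{NLL}} = -\sum_{t=1}^{T} w_t \, \frac{1}{N} \sum_{n=1}^{N} \log \sum_{k=1}^{K} \pi_{n,t,k} \, \mathcal{N}\!\left(y^{\mathrm{obs}}_{n,t} \mid \mu_{n,t,k}, \sigma^{2}_{n,t,k}\right)
```

Here π, μ and σ are the predicted mixture weights, means and standard deviations, and w_t is the weight of target t.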

forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor

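A possible PyTorch implementation of a weighted Gaussian-mixture NLL, written as a standalone function rather than as hy2dl's NLL.forward. The function name mixture_nll, the keys "mu", "sigma" and "pi", the tensor shapes and the Gaussian component family are all assumptions; the mixture log-density is evaluated with torch.distributions.MixtureSameFamily.

```python
import math

import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal


def mixture_nll(pred: dict[str, torch.Tensor], y_obs: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Weighted Gaussian-mixture negative log-likelihood (illustrative sketch).

    pred["mu"], pred["sigma"], pred["pi"]: shape [N, T, K]; y_obs: shape [N, T]; weights: shape [T].
    """
    mixture = MixtureSameFamily(
        Categorical(probs=pred["pi"]),      # mixing weights over the K components
        Normal(pred["mu"], pred["sigma"]),  # one Gaussian per component
    )
    log_p = mixture.log_prob(y_obs)         # log-density of each observation, shape [N, T]
    nll_per_target = -log_p.mean(dim=0)     # average over samples, shape [T]
    return (weights * nll_per_target).sum()  # weighted sum over target variables


# Single-component sanity check: a standard normal evaluated at its mean.
pred = {"mu": torch.zeros(4, 1, 1), "sigma": torch.ones(4, 1, 1), "pi": torch.ones(4, 1, 1)}
loss = mixture_nll(pred, torch.zeros(4, 1), weights=torch.tensor([1.0]))
expected = 0.5 * math.log(2 * math.pi)  # -log N(0 | 0, 1)
```

With a single component the mixture reduces to an ordinary Gaussian, which makes the result easy to check by hand.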

class hy2dl.training.loss.NSEBasinAveraged(cfg: Config)

Bases: BaseLoss

Basin-averaged Nash–Sutcliffe Efficiency.

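Loss function in which the squared error of each prediction is weighted by the inverse squared standard deviation of the observed target in the corresponding basin, so that basins with high variability do not dominate the loss. A description of this function is available at [1].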
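One common form of this loss, assumed here (including the averaging over all valid observations and the stabilizing constant ε), is:

```latex
\mathcal{L}_{\mathrm{NSE}} = \frac{1}{|\mathcal{V}|} \sum_{(n,t) \in \mathcal{V}} \frac{\left(\hat{y}_{n,t} - y_{n,t}\right)^{2}}{\left(s_{b(n)} + \epsilon\right)^{2}}
```

where V is the set of non-missing observations, s_b(n) is the standard deviation of the observed target in the basin b(n) of sample n, and ε is a small constant that keeps the weight finite for basins with near-zero variance.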

References

forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor
Parameters:
  • pred (dict[str, torch.Tensor]) – Model predictions

  • sample (dict[str, Any]) – Dictionary containing observed targets and other sample information

Returns:

Value of the basin-averaged NSE.

Return type:

torch.Tensor
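A sketch of the basin-averaged NSE as a standalone function, not hy2dl's NSEBasinAveraged.forward. It assumes the caller supplies, for every sample, the standard deviation of the observed target in that sample's basin (basin_std), and it uses a hypothetical default eps=0.1. Missing observations (NaN) are dropped before the error is computed, so they contribute neither to the loss nor to the gradient.

```python
import torch


def nse_basin_averaged(
    y_hat: torch.Tensor, y_obs: torch.Tensor, basin_std: torch.Tensor, eps: float = 0.1
) -> torch.Tensor:
    """Basin-averaged NSE loss (illustrative sketch).

    y_hat, y_obs: shape [N, T]; basin_std: shape [N, 1], one value per sample's basin.
    """
    mask = ~torch.isnan(y_obs)                                 # valid observations only
    weights = (1.0 / (basin_std + eps) ** 2).expand_as(y_obs)  # [N, 1] -> [N, T]
    squared_error = (y_hat[mask] - y_obs[mask]) ** 2           # mask first: no NaN gradients
    return (weights[mask] * squared_error).mean()


# Two samples from basins with standard deviations 0.9 and 1.9; one observation is missing.
y_hat = torch.tensor([[1.0, 2.0], [3.0, 5.0]], requires_grad=True)
y_obs = torch.tensor([[0.0, 0.0], [1.0, float("nan")]])
basin_std = torch.tensor([[0.9], [1.9]])

loss = nse_basin_averaged(y_hat, y_obs, basin_std)
loss.backward()
```

Masking before subtracting matters: computing the squared error first and masking afterwards would still propagate NaN values into y_hat.grad.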

class hy2dl.training.loss.WeightedMSE(cfg: Config)

Bases: BaseLoss

Weighted Mean Squared Error.

Calculates the MSE between simulated and observed targets, applying an optional weight to each target variable.
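Assuming the weights are applied as a weighted sum of per-target MSEs, each computed over the valid (non-missing) observations of that target, the loss reads:

```latex
\mathcal{L}_{\mathrm{MSE}} = \sum_{t=1}^{T} w_t \, \frac{1}{N_t} \sum_{n \in \mathcal{V}_t} \left(\hat{y}_{n,t} - y_{n,t}\right)^{2}
```

where V_t is the set of samples with a valid observation of target t, N_t = |V_t|, and w_t is the weight of target t.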

forward(pred: dict[str, Tensor], sample: dict[str, Any]) → Tensor
Parameters:
  • pred (dict[str, torch.Tensor]) – Model predictions

  • sample (dict[str, Any]) – Dictionary containing observed targets and other sample information

Returns:

Value of the weighted MSE.

Return type:

torch.Tensor
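A minimal sketch of a weighted MSE as a standalone function, not hy2dl's WeightedMSE.forward. The function name weighted_mse, the [N, T] tensor layout and the per-target weights tensor are assumptions; missing observations (NaN) are skipped target by target.

```python
import torch


def weighted_mse(y_hat: torch.Tensor, y_obs: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Weighted MSE over T target variables (illustrative sketch).

    y_hat, y_obs: shape [N, T]; weights: shape [T], one weight per target variable.
    """
    loss = y_hat.new_zeros(())
    for t in range(y_obs.shape[-1]):
        mask = ~torch.isnan(y_obs[:, t])  # drop missing observations of target t
        if not mask.any():                # no valid data for this target
            continue
        mse_t = ((y_hat[:, t][mask] - y_obs[:, t][mask]) ** 2).mean()
        loss = loss + weights[t] * mse_t  # weight this target's MSE
    return loss


y_hat = torch.tensor([[1.0, 0.0], [3.0, 2.0]])
y_obs = torch.tensor([[0.0, 0.0], [1.0, float("nan")]])
loss = weighted_mse(y_hat, y_obs, weights=torch.tensor([0.5, 2.0]))
```

Skipping targets without valid observations avoids the NaN that .mean() returns on an empty tensor.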