BaseTester

class hy2dl.evaluation.basetester.BaseTester(cfg: Config, evaluation_dataset: BaseDataset)

Bases: object

Class to process and store evaluation results.

This class is subclassed by the concrete evaluators (e.g. simulation_evaluator, forecast_evaluator), which use it to produce and store evaluation results.

Parameters:
  • cfg (Config) – Configuration object containing model hyperparameters and settings.

  • evaluation_dataset (BaseDataset) – Dataset used for evaluation.
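The inheritance pattern described above can be sketched in plain Python. This is a toy mimic, not hy2dl code. `ToyBaseTester`, `ToySimulationTester` and the `_collect_results` hook are hypothetical names chosen for illustration, and the real subclasses may divide the work differently.

```python
from abc import ABC, abstractmethod


class ToyBaseTester(ABC):
    """Hypothetical stand-in for BaseTester: shared storage logic, subclass-specific results."""

    def __init__(self, cfg: dict, evaluation_dataset: list):
        self.cfg = cfg
        self.dataset = evaluation_dataset
        self.results: dict = {}

    @abstractmethod
    def _collect_results(self, model) -> dict:
        """Hypothetical hook: each subclass decides how predictions are organised."""

    def evaluate_model(self, model) -> dict:
        # Shared step: run the subclass hook and keep the results.
        # (The documented class writes them to a zarr store; this toy keeps them in memory.)
        self.results = self._collect_results(model)
        return self.results


class ToySimulationTester(ToyBaseTester):
    """Hypothetical simulation-style subclass: one prediction per sample."""

    def _collect_results(self, model) -> dict:
        return {"y_sim": [model(x) for x in self.dataset]}


tester = ToySimulationTester(cfg={}, evaluation_dataset=[1.0, 2.0, 3.0])
print(tester.evaluate_model(lambda x: 2 * x))  # {'y_sim': [2.0, 4.0, 6.0]}
```

Keeping the storage step in the base class and only the result collection in the subclass is what lets simulation and forecast evaluators share one evaluation entry point.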

evaluate_model(model: Module)

Evaluate the model and store the results in a zarr file.

Parameters:
  • model (torch.nn.Module) – Model to evaluate.

validate_model(model: Module, epoch: int, filter_mask: DataArray = None)

Validate the model every cfg.validate_every epochs and calculate the validation metric.

Parameters:
  • model (torch.nn.Module) – Model to evaluate.

  • epoch (int) – Current epoch number.

  • forecast_mode (bool) – True if the dataset is from a forecast model (with lead_time dimension), False if from a simulation model.

  • filter_mask (xr.DataArray, optional) – Boolean DataArray to filter values during evaluation. Expected dimensions (gauge_id, date).
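A boolean filter_mask with the expected (gauge_id, date) dimensions can be built directly with xarray. The sketch below assumes that True marks values to keep and that the mask is aligned with the evaluation data by coordinate labels; the gauge IDs, dates and the `obs` array are made-up placeholders, and how validate_model applies the mask internally is not shown here.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Placeholder coordinates: two gauges, four days (values are made up).
gauge_ids = ["gauge_A", "gauge_B"]
dates = pd.date_range("2000-01-01", periods=4, freq="D")
coords = {"gauge_id": gauge_ids, "date": dates}

# Placeholder observations with dims (gauge_id, date).
obs = xr.DataArray(
    np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),
    dims=("gauge_id", "date"),
    coords=coords,
)

# Boolean mask with the documented dims (gauge_id, date).
# Assumption: True = keep the value, False = exclude it from the metric.
# Here the first two days of gauge_A are excluded.
filter_mask = xr.DataArray(
    np.array([[False, False, True, True], [True, True, True, True]]),
    dims=("gauge_id", "date"),
    coords=coords,
)

# Masked-out entries become NaN, so NaN-aware reductions skip them.
filtered = obs.where(filter_mask)
print(filtered.mean(dim="date").values)  # [3.5 6.5]
```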