BaseTester
- class hy2dl.evaluation.basetester.BaseTester(cfg: Config, evaluation_dataset: BaseDataset)
Bases: object
Class to process and store evaluation results.
Evaluator subclasses (e.g. simulation_evaluator, forecast_evaluator) inherit from this class to produce and store their evaluation results.
- Parameters:
cfg (Config) – Configuration object containing model hyperparameters and settings.
evaluation_dataset (BaseDataset) – Dataset used for evaluation.
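The subclassing relationship described above can be pictured with a minimal, self-contained sketch. It does not import hy2dl. `BaseTesterSketch`, `SimulationTesterSketch` and the `_predict` hook are hypothetical stand-ins assumed for illustration, not the library's actual internals, and a plain dict and list take the place of `Config` and `BaseDataset`. The sketch only shows the general pattern: the base class holds `cfg` and `evaluation_dataset` and drives evaluation, while each subclass supplies its own prediction logic.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List


class BaseTesterSketch(ABC):
    """Hypothetical stand-in for BaseTester: stores shared state and delegates prediction."""

    def __init__(self, cfg: Dict[str, Any], evaluation_dataset: List[float]):
        # Shared state every evaluator subclass relies on.
        self.cfg = cfg
        self.evaluation_dataset = evaluation_dataset

    def evaluate_model(self, model: Callable[[float], float]) -> Dict[str, List[float]]:
        # Shared driver: subclasses only decide how predictions are produced.
        return {"y_sim": self._predict(model)}

    @abstractmethod
    def _predict(self, model: Callable[[float], float]) -> List[float]:
        """Hypothetical hook implemented by each evaluator subclass."""


class SimulationTesterSketch(BaseTesterSketch):
    """Hypothetical simulation-style subclass: one prediction per sample."""

    def _predict(self, model: Callable[[float], float]) -> List[float]:
        return [model(x) for x in self.evaluation_dataset]


tester = SimulationTesterSketch(cfg={"validate_every": 2}, evaluation_dataset=[1.0, 2.0, 3.0])
results = tester.evaluate_model(model=lambda x: 2 * x)
print(results)
```

Because `_predict` is abstract, the base sketch cannot be instantiated on its own, which mirrors the idea that concrete evaluators are the classes actually used.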
- evaluate_model(model: Module)
Evaluate the model and store the results in a zarr file.
- Parameters:
model (torch.nn.Module) – Model to evaluate.
- validate_model(model: Module, epoch: int, filter_mask: DataArray = None)
Validate the model every cfg.validate_every epochs and calculate the validation metric.
- Parameters:
model (torch.nn.Module) – Model to evaluate.
epoch (int) – Current epoch number.
forecast_mode (bool) – True if the dataset is from a forecast model (with lead_time dimension), False if from a simulation model.
filter_mask (xr.DataArray, optional) – Boolean DataArray to filter values during evaluation. Expected dimensions (gauge_id, date).
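The expected shape of `filter_mask` and its effect on a metric can be illustrated with a small xarray sketch. The data, gauge names, `lead_time` values and the `masked_nse` helper are all hypothetical, and the Nash-Sutcliffe efficiency (NSE) is used only as an example metric, since this section does not state which validation metric `validate_model` computes. The mask has dimensions (gauge_id, date), and `DataArray.where` broadcasts it across the extra `lead_time` dimension that forecast-style outputs carry.

```python
import numpy as np
import pandas as pd
import xarray as xr

gauges = ["gauge_A", "gauge_B"]
dates = pd.date_range("2000-01-01", periods=4, freq="D")

# Hypothetical observations with dims (gauge_id, date).
y_obs = xr.DataArray(
    np.array([[1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]]),
    dims=("gauge_id", "date"),
    coords={"gauge_id": gauges, "date": dates},
)

# Hypothetical forecast-style simulation with an extra lead_time dimension:
# lead_time 1 matches the observations exactly, lead_time 2 is biased by +1.
offsets = xr.DataArray([0.0, 1.0], dims="lead_time", coords={"lead_time": [1, 2]})
y_sim = y_obs + offsets  # dims (gauge_id, date, lead_time)

# Boolean filter mask with the expected dims (gauge_id, date); False entries are excluded.
filter_mask = xr.DataArray(
    np.array([[False, True, True, True], [True, True, False, True]]),
    dims=("gauge_id", "date"),
    coords={"gauge_id": gauges, "date": dates},
)


def masked_nse(obs: xr.DataArray, sim: xr.DataArray, mask: xr.DataArray) -> xr.DataArray:
    """Hypothetical helper: NSE over the date dimension, using only entries where mask is True."""
    obs_m = obs.where(mask)  # masked-out entries become NaN
    sim_m = sim.where(mask)  # mask broadcasts over lead_time
    num = ((sim_m - obs_m) ** 2).sum(dim="date")  # NaNs are skipped for float data
    den = ((obs_m - obs_m.mean(dim="date")) ** 2).sum(dim="date")
    return 1.0 - num / den


nse = masked_nse(y_obs, y_sim, filter_mask)  # dims (gauge_id, lead_time)
print(nse.round(3))
```

Entries where the mask is False become NaN and are skipped by the float `sum` and `mean` reductions, so each gauge is scored only on its selected dates. In this toy data the masked NSE for gauge_A at lead_time 2 is -0.5, whereas scoring all four dates would give 0.2.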