Sampler
- class hy2dl.utils.sampler.GaugeBatchSampler(valid_samples: ndarray, batch_size: int)
Bases: Sampler
Produces batches of indices that belong to the same gauge_id.
This sampler ensures that all samples in a batch come from the same gauge_id. This is necessary during evaluation (validation/testing), where the data for each gauge must be processed separately.
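The grouping idea can be sketched in plain NumPy. The helper gauge_batches below is hypothetical and not part of hy2dl. It assumes the samples are already sorted by gauge, so each gauge occupies one contiguous block of indices. It also assumes each block is cut into consecutive chunks of at most batch_size, with a possibly shorter final chunk per gauge. The actual GaugeBatchSampler may differ in such details.

```python
import numpy as np


def gauge_batches(gauge_ids: np.ndarray, batch_size: int) -> list[list[int]]:
    """Hypothetical helper: split sample indices into batches that never mix gauges.

    Assumes gauge_ids is sorted, so every gauge forms one contiguous block.
    Each block is cut into chunks of at most batch_size indices.
    """
    batches: list[list[int]] = []
    # A block starts at index 0 and wherever the gauge ID changes.
    starts = np.flatnonzero(np.r_[True, gauge_ids[1:] != gauge_ids[:-1]])
    ends = np.r_[starts[1:], len(gauge_ids)]
    for start, end in zip(starts, ends):
        # Chunk this gauge's index range; the last chunk may be shorter.
        for chunk_start in range(start, end, batch_size):
            batches.append(list(range(chunk_start, min(chunk_start + batch_size, end))))
    return batches


gauge_ids = np.array(["A", "A", "A", "B", "B"], dtype=object)
print(gauge_batches(gauge_ids, batch_size=2))
# → [[0, 1], [2], [3, 4]]
```

Gauge "A" has three samples, so with batch_size=2 it yields a full batch and a one-sample batch. Its last index is never padded with samples from gauge "B".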
- Parameters:
valid_samples (numpy.ndarray) – A 1D structured numpy array containing the valid samples. Fields are:
  - ‘gauge_id’ (object): The ID of the basin/gauge.
  - ‘date’ (datetime64[ns]): The timestamp of the sample.
  - ‘source’ (object): The data source (e.g., ‘obs’ or ‘fc’).
batch_size (int) – Size of each batch.
Note: valid_samples must be sorted by 'gauge_id' and then by 'date', i.e. in the order given by np.sort(self.valid_samples, order=["gauge_id", "date"]).
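A minimal sketch of building a valid_samples array with the fields listed above and sorting it as the note requires. Only the field names and types come from this page. The gauge IDs, dates and sources are made-up illustration data.

```python
import numpy as np

# Structured dtype with the three fields described above.
sample_dtype = np.dtype(
    [("gauge_id", object), ("date", "datetime64[ns]"), ("source", object)]
)

# Hypothetical, deliberately unsorted samples for two gauges.
valid_samples = np.array(
    [
        ("gauge_B", np.datetime64("2020-01-01"), "obs"),
        ("gauge_A", np.datetime64("2020-01-02"), "obs"),
        ("gauge_A", np.datetime64("2020-01-01"), "fc"),
        ("gauge_B", np.datetime64("2019-12-31"), "obs"),
    ],
    dtype=sample_dtype,
)

# Sort by gauge first, then chronologically within each gauge.
valid_samples = np.sort(valid_samples, order=["gauge_id", "date"])

print(valid_samples["gauge_id"].tolist())
# → ['gauge_A', 'gauge_A', 'gauge_B', 'gauge_B']
print(np.datetime_as_string(valid_samples["date"], unit="D").tolist())
# → ['2020-01-01', '2020-01-02', '2019-12-31', '2020-01-01']
```

After sorting, each gauge occupies one contiguous, chronologically ordered block. This is what allows batches to be cut per gauge without mixing basins.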