# Source code for simba.core.models.transformers.CustomDatasetEncoder

from torch.utils.data import Dataset

from simba.core.models.transformers.augmentation import Augmentation


class CustomDatasetEncoder(Dataset):
    """Map-style dataset over a dict of index-aligned columns.

    Item ``idx`` is a dict holding ``data[k][idx]`` for every key ``k``.
    That dict is passed through ``Augmentation.normalize_intensities`` with
    ``intensity_labels=["intensity"]`` before it is returned.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Args:
            data: Mapping from field name to an indexable column (e.g. a
                NumPy array, tensor, or list). Columns are assumed, but not
                checked, to share the same length along their first axis.
        """
        self.data = data
        # Snapshot the key order once so every sample lists fields in the
        # same order. Keys added to ``data`` later are not picked up.
        self.keys = list(self.data.keys())

    def __len__(self):
        """Return the number of samples, taken from the first column.

        An empty ``data`` mapping yields 0 instead of raising ``IndexError``.
        Columns exposing ``shape`` use ``shape[0]``. Other columns (e.g. plain
        lists) fall back to ``len()``.
        """
        if not self.keys:
            return 0
        first = self.data[self.keys[0]]
        shape = getattr(first, "shape", None)
        return shape[0] if shape is not None else len(first)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of per-field values.

        The dict is post-processed by ``Augmentation.normalize_intensities``
        on the ``"intensity"`` field before being returned.
        """
        samples = {k: self.data[k][idx] for k in self.keys}
        samples = Augmentation.normalize_intensities(
            samples, intensity_labels=["intensity"]
        )
        return samples