import random
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm
from simba.core.models.transformers.augmentation import Augmentation
class CustomDatasetMultitasking(Dataset):
    """Pairs dataset yielding two spectra plus two pairwise targets.

    Each row of ``your_dict`` describes one pair of *unique* entries
    (``index_unique_0`` / ``index_unique_1``) together with two targets,
    ``ed`` and ``mces``. A unique entry is mapped to one or more rows of the
    per-spectrum arrays (``mz``, ``intensity``, ``precursor_mass``, ...)
    through ``df_smiles.loc[unique_index, "indexes"]``.

    In training mode a random spectrum is drawn for each side of the pair,
    fingerprints of odd unique indices are zeroed, and with probability
    ``prob_aug`` the sample is passed through ``Augmentation.augment``.
    Every sample returned by ``__getitem__`` is passed through
    ``Augmentation.normalize_intensities``.
    """

    def __init__(
        self,
        your_dict,
        training=False,
        prob_aug=1.0,
        mz=None,
        intensity=None,
        precursor_mass=None,
        precursor_charge=None,
        df_smiles=None,
        use_fingerprints=False,
        fingerprint_0=None,
        max_num_peaks=None,
        use_adduct=False,
        ionmode=None,
        adduct_mass=None,
        use_ce=False,
        ce=None,
    ):
        """Store the pair table and the per-spectrum lookup arrays.

        Args:
            your_dict: mapping of column name -> per-pair array. Columns are
                assumed to share one length (``__len__`` uses the first one).
                Must contain ``index_unique_0``, ``index_unique_1``, ``ed``
                and ``mces``.
            training: if True, sample spectra randomly, mask fingerprints of
                odd unique indices and apply augmentation.
            prob_aug: probability of augmenting a training sample.
            mz, intensity, precursor_mass, precursor_charge: per-spectrum
                arrays indexed by original spectrum index.
            df_smiles: DataFrame indexed by unique index whose ``indexes``
                column holds the original spectrum indices of that entry.
            use_fingerprints, fingerprint_0: enable and provide fingerprints.
            max_num_peaks: peak budget forwarded to ``Augmentation.augment``.
            use_adduct, ionmode, adduct_mass: enable and provide adduct data.
            use_ce, ce: enable and provide ``ce`` values (presumably
                collision energy -- confirm).
        """
        self.data = your_dict
        self.keys = list(your_dict.keys())
        self.training = training
        self.prob_aug = prob_aug
        self.mz = mz
        self.intensity = intensity
        self.precursor_mass = precursor_mass
        self.precursor_charge = precursor_charge
        self.df_smiles = df_smiles  # df with rows: smiles, indexes
        self.use_fingerprints = use_fingerprints
        self.use_adduct = use_adduct
        self.use_ce = use_ce
        # Optional inputs are stored unconditionally so every attribute exists
        # regardless of the feature flags. In particular ``max_num_peaks`` is
        # read by the training augmentation path even without fingerprints.
        self.fingerprint_0 = fingerprint_0
        self.max_num_peaks = max_num_peaks
        self.ionmode = ionmode
        self.adduct_mass = adduct_mass
        self.ce = ce

    def __len__(self):
        """Return the number of pairs (length of the first column, 0 if none)."""
        if not self.keys:
            return 0
        return len(self.data[self.keys[0]])

    @staticmethod
    def _as_index(value):
        """Convert a stored unique index (scalar or 1-element array) to int.

        Avoids ``int(ndarray)`` on arrays with ndim > 0, which NumPy >= 1.25
        deprecates. For multi-element inputs the first element is used.
        """
        return int(np.asarray(value).reshape(-1)[0])

    def get_original_dictionary(self, max_num_peaks=100):
        """Materialize every pair into dense, pre-allocated arrays.

        Unlike ``__getitem__`` this is deterministic: both sides of each pair
        use the *first* original spectrum listed for their unique index, and
        no augmentation or normalization is applied.

        Args:
            max_num_peaks: width of the ``mz_*`` / ``intensity_*`` arrays;
                every stored peak row must broadcast to this width.

        Returns:
            dict mapping ``mz_0``, ``intensity_0``, ``mz_1``, ``intensity_1``,
            ``ed``, ``mces``, ``precursor_mass_{0,1}``,
            ``precursor_charge_{0,1}`` (int32) and, depending on the flags,
            ``ionmode_{0,1}``, ``adduct_mass_{0,1}``, ``ce_{0,1}`` and
            ``fingerprint_0`` (int32, width 2048) to arrays with one row per
            pair.
        """
        len_data = len(self)

        def zeros(width=1, dtype=np.float32):
            return np.zeros((len_data, width), dtype=dtype)

        dictionary = {
            "mz_0": zeros(max_num_peaks),
            "intensity_0": zeros(max_num_peaks),
            "mz_1": zeros(max_num_peaks),
            "intensity_1": zeros(max_num_peaks),
            "ed": zeros(),
            "mces": zeros(),
            "precursor_mass_0": zeros(),
            "precursor_charge_0": zeros(dtype=np.int32),
            "precursor_mass_1": zeros(),
            "precursor_charge_1": zeros(dtype=np.int32),
        }
        # Per-spectrum sources copied into "<name>_0" / "<name>_1" for each pair.
        paired_sources = [
            ("mz", self.mz),
            ("intensity", self.intensity),
            ("precursor_mass", self.precursor_mass),
            ("precursor_charge", self.precursor_charge),
        ]
        if self.use_adduct:
            for key in ("ionmode_0", "ionmode_1", "adduct_mass_0", "adduct_mass_1"):
                dictionary[key] = zeros()
            paired_sources += [
                ("ionmode", self.ionmode),
                ("adduct_mass", self.adduct_mass),
            ]
        if self.use_ce:
            dictionary["ce_0"] = zeros()
            dictionary["ce_1"] = zeros()
            paired_sources.append(("ce", self.ce))
        if self.use_fingerprints:
            print("Defining fingerprints ...")
            dictionary["fingerprint_0"] = zeros(2048, dtype=np.int32)

        if len_data == 0:
            return dictionary

        # Only these columns are needed; avoid rebuilding a dict of every
        # column on each row.
        unique_col_0 = self.data["index_unique_0"]
        unique_col_1 = self.data["index_unique_1"]
        ed_col = self.data["ed"]
        mces_col = self.data["mces"]

        for idx in tqdm(range(len_data)):
            # NOTE(review): side 1 uses the first spectrum here, while
            # __getitem__ in eval mode uses the last one -- confirm intent.
            original_0 = self.df_smiles.loc[
                self._as_index(unique_col_0[idx]), "indexes"
            ][0]
            original_1 = self.df_smiles.loc[
                self._as_index(unique_col_1[idx]), "indexes"
            ][0]
            for name, source in paired_sources:
                # Values pass through float32 before being written into the
                # pre-allocated (possibly int32) destination arrays.
                dictionary[f"{name}_0"][idx] = np.asarray(
                    source[original_0], dtype=np.float32
                )
                dictionary[f"{name}_1"][idx] = np.asarray(
                    source[original_1], dtype=np.float32
                )
            dictionary["ed"][idx] = np.asarray(ed_col[idx], dtype=np.float32)
            dictionary["mces"][idx] = np.asarray(mces_col[idx], dtype=np.float32)
            if self.use_fingerprints:
                # NOTE(review): indexed by the original spectrum index here but
                # by the unique index in __getitem__ -- confirm which one
                # fingerprint_0 expects.
                dictionary["fingerprint_0"][idx] = np.asarray(
                    self.fingerprint_0[original_0], dtype=np.float32
                )
        return dictionary

    def __getitem__(self, idx):
        """Return pair ``idx`` as a dict of float32 arrays.

        The dict holds ``mz_{0,1}``, ``intensity_{0,1}``,
        ``precursor_mass_{0,1}``, ``precursor_charge_{0,1}``, ``ed`` and
        ``mces``, plus ``fingerprint_0``, ``ionmode_{0,1}``,
        ``adduct_mass_{0,1}`` and ``ce_{0,1}`` when the corresponding flags
        are enabled. The result is passed through
        ``Augmentation.normalize_intensities`` and, in training mode with
        probability ``prob_aug``, first through ``Augmentation.augment``.
        """
        unique_0 = self._as_index(self.data["index_unique_0"][idx])
        unique_1 = self._as_index(self.data["index_unique_1"][idx])
        candidates_0 = self.df_smiles.loc[unique_0, "indexes"]
        candidates_1 = self.df_smiles.loc[unique_1, "indexes"]
        if self.training:
            # Draw a random spectrum for each side of the pair.
            original_0 = random.choice(candidates_0)
            original_1 = random.choice(candidates_1)
        else:
            # Deterministic choice: first spectrum for side 0, last for side 1
            # (presumably so identical unique entries can still yield two
            # different spectra -- confirm).
            original_0 = candidates_0[0]
            original_1 = candidates_1[-1]

        spectrum_sample = {
            "mz_0": self.mz[original_0].astype(np.float32),
            "intensity_0": self.intensity[original_0].astype(np.float32),
            "mz_1": self.mz[original_1].astype(np.float32),
            "intensity_1": self.intensity[original_1].astype(np.float32),
            "precursor_mass_0": self.precursor_mass[original_0].astype(np.float32),
            "precursor_mass_1": self.precursor_mass[original_1].astype(np.float32),
            "precursor_charge_0": self.precursor_charge[original_0].astype(
                np.float32
            ),
            "precursor_charge_1": self.precursor_charge[original_1].astype(
                np.float32
            ),
            "ed": self.data["ed"][idx].astype(np.float32),
            "mces": self.data["mces"][idx].astype(np.float32),
        }
        if self.use_fingerprints:
            # NOTE(review): indexed by the unique index here but by the
            # original spectrum index in get_original_dictionary -- confirm.
            fingerprint = self.fingerprint_0[unique_0].astype(np.float32)
            if self.training and unique_0 % 2 != 0:
                # Hide the fingerprint (all zeros) for odd unique indices
                # during training.
                fingerprint = np.zeros_like(fingerprint)
            spectrum_sample["fingerprint_0"] = fingerprint
        if self.use_adduct:
            spectrum_sample["ionmode_0"] = self.ionmode[original_0].astype(np.float32)
            spectrum_sample["ionmode_1"] = self.ionmode[original_1].astype(np.float32)
            spectrum_sample["adduct_mass_0"] = self.adduct_mass[original_0].astype(
                np.float32
            )
            spectrum_sample["adduct_mass_1"] = self.adduct_mass[original_1].astype(
                np.float32
            )
        if self.use_ce:
            spectrum_sample["ce_0"] = self.ce[original_0].astype(np.float32)
            spectrum_sample["ce_1"] = self.ce[original_1].astype(np.float32)
        if self.training and random.random() < self.prob_aug:
            spectrum_sample = Augmentation.augment(
                spectrum_sample, max_num_peaks=self.max_num_peaks
            )
        return Augmentation.normalize_intensities(spectrum_sample)