import lightning.pytorch as pl
import numpy as np
# import pandas as pd
# import ppx
# import seaborn as sns
import torch
import torch.nn as nn
from simba.core.models.transformers.spectrum_transformer_encoder_custom import (
SpectrumTransformerEncoderCustom,
)
from simba.logger_setup import logger
class FixedLinearRegression(nn.Module):
    """Frozen linear map that sums its input over the last dimension.

    The weight is a ``(1, d_model)`` tensor of ones and the bias a single
    zero; neither requires gradients. For an input of shape
    ``(..., d_model)`` the output has shape ``(..., 1)`` and equals the sum
    of the input along its last axis.
    """

    def __init__(self, d_model):
        super().__init__()
        # Registered as (frozen) nn.Parameters rather than buffers, so they
        # still show up in named_parameters() and in the state_dict.
        self.weight = nn.Parameter(torch.ones(1, d_model), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, x):
        projected = x @ self.weight.T
        return projected + self.bias
class Embedder(pl.LightningModule):
    """Siamese model that predicts the similarity of a pair of spectra.

    Both spectra of a pair are encoded by one shared
    ``SpectrumTransformerEncoderCustom``; the first output position of each
    encoding is kept and passed through a ReLU. The two embeddings are then
    combined into one predicted similarity per pair, which is trained against
    ``batch["similarity"]`` with a mean-squared-error loss.
    """

    def __init__(
        self,
        d_model,
        n_layers,
        dropout=0.1,
        weights=None,
        lr=None,
        use_element_wise=True,
        use_cosine_distance=True,
        use_adduct=False,
        categorical_adducts=False,
        adduct_mass_map="",
        use_ce=False,
        use_ion_activation=False,
        use_ion_method=False,
    ):
        """Build the shared spectrum encoder and the similarity heads.

        Args:
            d_model: Width of the encoder output and of the MLP head layers.
            n_layers: Number of layers of the spectrum encoder.
            dropout: Dropout probability for the encoder and the MLP head.
            weights: Stored as ``self.weights``; not read by this class.
            lr: Stored as ``self.lr``; not read by this class.
            use_element_wise: Stored as ``self.use_element_wise``; not read
                by this class.
            use_cosine_distance: If True, the prediction is the cosine
                similarity of the two embeddings. Otherwise the embeddings are
                summed and passed through a Linear-Dropout-ReLU-Linear head.
            use_adduct: Forwarded to the encoder; also makes ``forward`` read
                ``ionmode_*`` and ``adduct_mass_*`` from the batch.
            categorical_adducts: Forwarded to the encoder.
            adduct_mass_map: Forwarded to the encoder.
            use_ce: Forwarded to the encoder; also makes ``forward`` read
                ``ce_*`` from the batch.
            use_ion_activation: Forwarded to the encoder; also makes
                ``forward`` read ``ion_activation_*`` from the batch.
            use_ion_method: Forwarded to the encoder; also makes ``forward``
                read ``ion_method_*`` from the batch.
        """
        super().__init__()
        self.weights = weights
        self.use_element_wise = use_element_wise
        # Keep the module creation order below stable: it fixes the RNG draw
        # order of parameter initialisation and the state_dict key order.
        self.linear = nn.Linear(d_model, d_model)
        self.linear_regression = nn.Linear(d_model, 1)
        self.fixed_linear_regression = FixedLinearRegression(d_model)
        self.relu = nn.ReLU()
        self.use_adduct = use_adduct
        self.use_ce = use_ce
        self.use_ion_activation = use_ion_activation
        self.use_ion_method = use_ion_method
        self.spectrum_encoder = SpectrumTransformerEncoderCustom(
            d_model=d_model,
            n_layers=n_layers,
            dropout=dropout,
            use_adduct=use_adduct,
            categorical_adducts=categorical_adducts,
            adduct_mass_map=adduct_mass_map,
            use_ce=use_ce,
            use_ion_activation=use_ion_activation,
            use_ion_method=use_ion_method,
        )
        self.regression_loss = nn.MSELoss()
        self.dropout = nn.Dropout(p=dropout)
        self.train_loss_list = []
        self.val_loss_list = []
        self.lr = lr
        self.use_cosine_distance = use_cosine_distance
        if self.use_cosine_distance:
            # NOTE(review): linear_cosine is created but never used in
            # forward(); kept so existing checkpoints still match.
            self.linear_cosine = nn.Linear(d_model, d_model)
        # Parameter-free, so defining it unconditionally does not change the
        # state_dict.
        self.cosine_similarity = nn.CosineSimilarity(dim=1)
        # True -> nn.CosineSimilarity; False -> the manual normalised
        # element-wise product summed by fixed_linear_regression.
        self.use_cosine_library = True
        if self.use_ce:
            # Log once at construction instead of on every forward pass.
            logger.info("Using CE in the model")

    def normalized_dot_product(self, a, b):
        """Return the dot product of ``a`` and ``b`` after L2-normalising both
        along the last dimension (i.e. their cosine similarity)."""
        a_norm = torch.nn.functional.normalize(a, p=2, dim=-1)
        b_norm = torch.nn.functional.normalize(b, p=2, dim=-1)
        return torch.sum(a_norm * b_norm, dim=-1)

    def _encode_spectrum(self, batch, suffix):
        """Encode spectrum ``suffix`` ("0" or "1") of every pair in ``batch``.

        Returns the ReLU of the encoder output at position 0 of dim 1.
        """
        kwargs = {
            "precursor_mass": batch[f"precursor_mass_{suffix}"].float(),
            "precursor_charge": batch[f"precursor_charge_{suffix}"].float(),
        }
        # Optional metadata, enabled by the matching constructor flags.
        if self.use_adduct:
            kwargs["ionmode"] = batch[f"ionmode_{suffix}"].float()
            kwargs["adduct_mass"] = batch[f"adduct_mass_{suffix}"].float()
        if self.use_ce:
            kwargs["ce"] = batch[f"ce_{suffix}"].float()
        if self.use_ion_activation:
            kwargs["ion_activation"] = batch[f"ion_activation_{suffix}"].float()
        if self.use_ion_method:
            kwargs["ion_method"] = batch[f"ion_method_{suffix}"].float()
        encoded, _ = self.spectrum_encoder(
            mz_array=batch[f"mz_{suffix}"].float(),
            intensity_array=batch[f"intensity_{suffix}"].float(),
            **kwargs,
        )
        # Keep only the first output position as the spectrum embedding
        # (presumably a summary token -- confirm against the encoder).
        return self.relu(encoded[:, 0, :])

    def forward(self, batch):
        """Predict one similarity score per spectrum pair.

        Args:
            batch: Mapping with ``mz_0``/``mz_1``, ``intensity_0``/
                ``intensity_1``, ``precursor_mass_0``/``precursor_mass_1``
                and ``precursor_charge_0``/``precursor_charge_1``, plus the
                optional metadata keys enabled in ``__init__``.

        Returns:
            Tensor of shape ``(batch_size, 1)``.
        """
        # Spectrum 0 is encoded before spectrum 1, as before, so encoder
        # dropout consumes RNG in the same order.
        emb0 = self._encode_spectrum(batch, "0")
        emb1 = self._encode_spectrum(batch, "1")

        if not self.use_cosine_distance:
            # MLP regression head on the summed embeddings.
            emb = self.linear(emb0 + emb1)
            emb = self.dropout(emb)
            emb = self.relu(emb)
            return self.linear_regression(emb)

        if self.use_cosine_library:
            return self.cosine_similarity(emb0, emb1).reshape(-1, 1)

        # Manual cosine: normalised element-wise product, summed over the
        # last dim by the frozen all-ones layer. Norms are clamped like
        # nn.CosineSimilarity so all-zero (post-ReLU) embeddings give 0
        # instead of NaN.
        eps = 1e-8
        emb0_l2 = torch.norm(emb0, p=2, dim=-1, keepdim=True).clamp_min(eps)
        emb1_l2 = torch.norm(emb1, p=2, dim=-1, keepdim=True).clamp_min(eps)
        emb = (emb0 * emb1) / (emb0_l2 * emb1_l2)
        return self.fixed_linear_regression(emb)

    def step(self, batch, batch_idx, threshold=0.5):
        """Return the MSE loss between predicted and target similarities.

        ``batch_idx`` and ``threshold`` are accepted for interface
        compatibility and are not used.
        """
        prediction = self(batch)
        # as_tensor avoids torch.tensor()'s copy-and-warn when the collate
        # function already produced a tensor.
        target = torch.as_tensor(batch["similarity"], device=self.device)
        return self.regression_loss(
            prediction.float(), target.reshape(-1, 1).float()
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss."""
        loss = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        loss = self.step(batch, batch_idx)
        self.log("validation_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return the raw predicted similarities for ``batch``."""
        return self(batch)

    def load_weights(self):
        """Return ``{name: numpy copy}`` for every named parameter.

        Copies are taken on the CPU, so this also works for GPU-resident
        models and the snapshot is unaffected by later in-place updates.
        """
        return {
            name: param.detach().cpu().numpy().copy()
            for name, param in self.named_parameters()
        }

    def load_pretrained_maldi_embedder(self, model_path):
        """Copy every matching entry of a checkpoint's ``state_dict`` into
        this model.

        Keys present only in the checkpoint, or only in this model, are
        ignored.

        Raises:
            ValueError: if the probe layer checked by ``are_weights_changed``
                is identical before and after loading.
        """
        original_weights = self.load_weights()
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        pretrained_state = checkpoint["state_dict"]
        # state_dict() tensors share storage with the live parameters, so
        # copy_ into them updates the model in place. Build the dict once.
        own_state = self.state_dict()
        for key, value in pretrained_state.items():
            if key in own_state:
                own_state[key].copy_(value)
        new_weights = self.load_weights()
        # are_weights_changed returns True when the probe layer is UNCHANGED.
        if self.are_weights_changed(original_weights, new_weights):
            raise ValueError("ERROR!!!: Error loading Maldi model")
        logger.info("Correctly loaded pretrained Maldi Model")

    def are_weights_changed(
        self,
        original_weights,
        new_weights,
        layer_test="spectrum_encoder.transformer_encoder.layers.0.norm2.bias",
    ):
        """Return True if ``layer_test`` is identical in both snapshots.

        NOTE: despite its name, a True result means the weights did NOT
        change. The semantics are kept as-is for existing callers.
        """
        return np.array_equal(original_weights[layer_test], new_weights[layer_test])

    def set_freeze_layers(self, layer_names_to_freeze, freeze):
        """Set ``requires_grad`` on parameters by substring match.

        Parameters whose name contains any entry of ``layer_names_to_freeze``
        get ``requires_grad = not freeze``. All other parameters are made
        trainable, except the frozen-by-design ``fixed_linear_regression``
        parameters, which are left untouched.

        Args:
            layer_names_to_freeze: Iterable of name substrings. A single
                string is treated as one substring.
            freeze: True to freeze the matched parameters, False to unfreeze.
        """
        if isinstance(layer_names_to_freeze, str):
            patterns = (layer_names_to_freeze,)
        else:
            # Materialise once: a generator would be exhausted after the
            # first parameter.
            patterns = tuple(layer_names_to_freeze)
        for name, param in self.named_parameters():
            if any(pattern in name for pattern in patterns):
                param.requires_grad = not freeze
            elif name.startswith("fixed_linear_regression."):
                continue
            else:
                param.requires_grad = True

    def get_maldi_embedder_keys(self, model_path):
        """Return the ``state_dict`` keys stored in the checkpoint at
        ``model_path``."""
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        return checkpoint["state_dict"].keys()

    def get_all_keys(self):
        """Return the keys of this model's ``state_dict``."""
        return self.state_dict().keys()