# Source code for simba.core.models.transformers.embedder

import lightning.pytorch as pl
import numpy as np

# import pandas as pd
# import ppx
# import seaborn as sns
import torch
import torch.nn as nn

from simba.core.models.transformers.spectrum_transformer_encoder_custom import (
    SpectrumTransformerEncoderCustom,
)
from simba.logger_setup import logger


class FixedLinearRegression(nn.Module):
    """Non-trainable linear map that sums the features of its input.

    The weight is a ``(1, d_model)`` tensor of ones and the bias a single
    zero. Both are registered as parameters with gradients disabled, so
    ``forward`` computes ``x @ ones(d_model, 1) + 0``: the sum over the last
    dimension, kept as a trailing axis of size 1.

    Args:
        d_model: Size of the last dimension of the inputs.
    """

    def __init__(self, d_model):
        super().__init__()
        # Frozen nn.Parameters (not buffers), so they still appear in
        # named_parameters() and state_dict().
        self.weight = nn.Parameter(torch.ones(1, d_model), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, x):
        """Return ``x @ weight.T + bias``, i.e. the last-axis sum with keepdim."""
        projection = x @ self.weight.T
        return projection + self.bias


class Embedder(pl.LightningModule):
    """Similarity model trained on pairs of spectra.

    Each batch describes two spectra per pair (batch keys suffixed ``_0`` and
    ``_1``). Both are embedded with the same ``SpectrumTransformerEncoderCustom``
    and a similarity score is regressed for the pair. Training minimizes the
    MSE against ``batch["similarity"]``.
    """

    def __init__(
        self,
        d_model,
        n_layers,
        dropout=0.1,
        weights=None,
        lr=None,
        use_element_wise=True,
        use_cosine_distance=True,
        use_adduct=False,
        categorical_adducts=False,
        adduct_mass_map="",
        use_ce=False,
        use_ion_activation=False,
        use_ion_method=False,
    ):
        """Initialize the Embedder.

        Args:
            d_model: Embedding width, passed to the encoder and used for the
                projection heads.
            n_layers: Number of encoder layers (passed to the encoder).
            dropout: Dropout probability for the encoder and the regression head.
            weights: Stored as ``self.weights``; not read by this class.
            lr: Adam learning rate. If ``None``, Adam's own default is used.
            use_element_wise: Stored as ``self.use_element_wise``; not read by
                this class.
            use_cosine_distance: If True, the pair score is the cosine
                similarity of the two embeddings. Otherwise the embeddings are
                summed and fed to a small MLP head.
            use_adduct, categorical_adducts, adduct_mass_map, use_ce,
            use_ion_activation, use_ion_method: Forwarded to the encoder. The
                ``use_*`` flags also select which optional batch fields
                ``forward`` passes to it.
        """
        super().__init__()
        # NOTE: the order in which submodules are created below is kept as-is
        # on purpose: it fixes parameter-initialization RNG order and the
        # key order of state_dict().
        self.weights = weights
        self.use_element_wise = use_element_wise
        # MLP head used when use_cosine_distance is False.
        self.linear = nn.Linear(d_model, d_model)
        self.linear_regression = nn.Linear(d_model, 1)
        # Frozen all-ones projection: sums features (manual cosine branch).
        self.fixed_linear_regression = FixedLinearRegression(d_model)
        self.relu = nn.ReLU()
        self.use_adduct = use_adduct
        self.use_ce = use_ce
        self.use_ion_activation = use_ion_activation
        self.use_ion_method = use_ion_method
        # Shared encoder applied to both spectra of a pair.
        self.spectrum_encoder = SpectrumTransformerEncoderCustom(
            d_model=d_model,
            n_layers=n_layers,
            dropout=dropout,
            use_adduct=use_adduct,
            categorical_adducts=categorical_adducts,
            adduct_mass_map=adduct_mass_map,
            use_ce=use_ce,
            use_ion_activation=use_ion_activation,
            use_ion_method=use_ion_method,
        )
        self.regression_loss = nn.MSELoss()
        self.dropout = nn.Dropout(p=dropout)
        self.train_loss_list = []
        self.val_loss_list = []
        self.lr = lr
        self.use_cosine_distance = use_cosine_distance
        if self.use_cosine_distance:
            # Not used by forward(); kept so existing checkpoints still match.
            self.linear_cosine = nn.Linear(d_model, d_model)
        self.cosine_similarity = nn.CosineSimilarity(dim=1)
        # Always True: forward() uses nn.CosineSimilarity, not the manual path.
        self.use_cosine_library = True

        if self.use_ce:
            # Logged once here rather than on every forward() call.
            logger.info("Using CE in the model")

    def normalized_dot_product(self, a, b):
        """Return the dot product of ``a`` and ``b`` after L2-normalizing the last axis."""
        a_norm = torch.nn.functional.normalize(a, p=2, dim=-1)
        b_norm = torch.nn.functional.normalize(b, p=2, dim=-1)
        return torch.sum(a_norm * b_norm, dim=-1)

    def _encoder_kwargs(self, batch, suffix):
        """Collect the metadata keyword arguments for one side (``"_0"``/``"_1"``) of the pair."""
        kwargs = {
            "precursor_mass": batch[f"precursor_mass{suffix}"].float(),
            "precursor_charge": batch[f"precursor_charge{suffix}"].float(),
        }
        # Optional fields are only read when the matching flag is enabled.
        if self.use_adduct:
            kwargs["ionmode"] = batch[f"ionmode{suffix}"].float()
            kwargs["adduct_mass"] = batch[f"adduct_mass{suffix}"].float()
        if self.use_ce:
            kwargs["ce"] = batch[f"ce{suffix}"].float()
        if self.use_ion_activation:
            kwargs["ion_activation"] = batch[f"ion_activation{suffix}"].float()
        if self.use_ion_method:
            kwargs["ion_method"] = batch[f"ion_method{suffix}"].float()
        return kwargs

    def _embed(self, batch, suffix):
        """Encode one side of the pair and return its ReLU'd first-position embedding."""
        emb, _ = self.spectrum_encoder(
            mz_array=batch[f"mz{suffix}"].float(),
            intensity_array=batch[f"intensity{suffix}"].float(),
            **self._encoder_kwargs(batch, suffix),
        )
        # Keep only position 0 along dim 1 of the encoder output.
        return self.relu(emb[:, 0, :])

    def forward(self, batch):
        """The inference pass.

        Reads ``mz_*``, ``intensity_*``, ``precursor_mass_*`` and
        ``precursor_charge_*`` (plus optional fields per the ``use_*`` flags)
        for both sides of each pair. Returns one score per pair as a column
        tensor.
        """
        emb0 = self._embed(batch, "_0")
        emb1 = self._embed(batch, "_1")

        if self.use_cosine_distance:
            if self.use_cosine_library:
                # CosineSimilarity(dim=1) yields a 1-D tensor; make it a column.
                return self.cosine_similarity(emb0, emb1).reshape(-1, 1)
            # Manual cosine: elementwise product divided by both L2 norms,
            # then summed over features by the frozen all-ones projection.
            emb0_l2 = torch.norm(emb0, p=2, dim=-1, keepdim=True)
            emb1_l2 = torch.norm(emb1, p=2, dim=-1, keepdim=True)
            emb = (emb0 * emb1) / (emb0_l2 * emb1_l2)
            return self.fixed_linear_regression(emb)

        emb = emb0 + emb1
        emb = self.linear(emb)
        emb = self.dropout(emb)
        emb = self.relu(emb)
        return self.linear_regression(emb)

    def step(self, batch, batch_idx, threshold=0.5):
        """A training/validation step: MSE between predictions and ``batch["similarity"]``.

        ``batch_idx`` and ``threshold`` are accepted for API compatibility and
        are not used.
        """
        spec = self(batch)
        # as_tensor avoids an extra copy (and torch's copy-construct warning)
        # when the similarity is already a tensor; lists/arrays still work.
        target = torch.as_tensor(batch["similarity"], device=self.device)
        # reshape (not view) also handles non-contiguous inputs.
        loss = self.regression_loss(spec.float(), target.reshape(-1, 1).float())
        return loss.float()

    def training_step(self, batch, batch_idx):
        """A training step."""
        loss = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """A validation step."""
        loss = self.step(batch, batch_idx)
        self.log("validation_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """A predict step: return the raw model output for ``batch``."""
        return self(batch)

    def configure_optimizers(self):
        """Configure the optimizer for training.

        Uses Adam over all parameters. When ``lr`` is ``None`` the keyword is
        omitted so Adam's own default applies, instead of failing on
        ``lr=None``.
        """
        optimizer_kwargs = {} if self.lr is None else {"lr": self.lr}
        return torch.optim.Adam(self.parameters(), **optimizer_kwargs)

    def load_weights(self):
        """Return a ``{name: numpy array}`` snapshot of every named parameter.

        Despite its name this does not load anything. It copies the current
        values to host memory so they can be compared later; the explicit
        ``.copy()`` keeps the snapshot independent of later in-place updates.
        """
        return {
            name: param.detach().cpu().numpy().copy()
            for name, param in self.named_parameters()
        }

    def load_pretrained_maldi_embedder(self, model_path):
        """Copy matching weights from the checkpoint at ``model_path`` into this model.

        Every key of ``checkpoint["state_dict"]`` that also exists in this
        model's ``state_dict`` is copied in place; other keys are ignored.

        Raises:
            ValueError: If the probe parameter checked by
                ``are_weights_changed`` is identical before and after loading.
        """
        original_weights = self.load_weights()

        # NOTE: torch.load deserializes with pickle; only load trusted files.
        checkpoint = torch.load(model_path, map_location="cpu")
        checkpoint_state = checkpoint["state_dict"]

        # Build the state dict once. Its tensors share storage with the
        # model's parameters/buffers, so copy_ updates the model in place.
        own_state = self.state_dict()
        with torch.no_grad():
            for key, value in checkpoint_state.items():
                if key in own_state:
                    own_state[key].copy_(value)

        new_weights = self.load_weights()
        # are_weights_changed() returns True when the probe is *unchanged*,
        # so "not ..." means the load actually modified the weights.
        if not self.are_weights_changed(original_weights, new_weights):
            logger.info("Correctly loaded pretrained Maldi Model")
        else:
            raise ValueError("ERROR!!!: Error loading Maldi model")

    def are_weights_changed(
        self,
        original_weights,
        new_weights,
        layer_test="spectrum_encoder.transformer_encoder.layers.0.norm2.bias",
    ):
        """Compare one probe parameter between two ``load_weights`` snapshots.

        NOTE: despite its name, this returns ``True`` when the probe parameter
        is *identical* in both snapshots (unchanged) and ``False`` when it
        differs. The return value is kept as-is for existing callers.
        """
        return np.array_equal(original_weights[layer_test], new_weights[layer_test])

    def set_freeze_layers(self, layer_names_to_freeze, freeze):
        """Freeze or unfreeze parameters selected by name substring.

        Args:
            layer_names_to_freeze: Iterable of substrings (a single string is
                also accepted). A parameter matches when any substring occurs
                in its name.
            freeze: If True, matching parameters get ``requires_grad=False``;
                otherwise they get ``requires_grad=True``.

        Non-matching parameters are made trainable, except those of
        ``fixed_linear_regression``. That layer is frozen by design and is
        only changed when explicitly matched.
        """
        if isinstance(layer_names_to_freeze, str):
            # A bare string would otherwise be iterated character by character.
            layer_names_to_freeze = (layer_names_to_freeze,)
        # Materialize once so a generator is not exhausted after the first parameter.
        patterns = tuple(layer_names_to_freeze)

        for name, param in self.named_parameters():
            if any(pattern in name for pattern in patterns):
                param.requires_grad = not freeze
            elif not name.startswith("fixed_linear_regression."):
                param.requires_grad = True

    def get_maldi_embedder_keys(self, model_path):
        """Return the keys of ``checkpoint["state_dict"]`` stored at ``model_path``."""
        # NOTE: torch.load deserializes with pickle; only load trusted files.
        checkpoint = torch.load(model_path, map_location="cpu")
        return checkpoint["state_dict"].keys()

    def get_all_keys(self):
        """Return the keys of this model's ``state_dict``."""
        return self.state_dict().keys()