Module library.classes.generators

Expand source code
import os
import tensorflow as tf
import numpy as np
from scipy.ndimage import gaussian_filter
import sys
import linecache
from library.parser import get_cg_at_datasets
from library.static.topologies import DOPC_CG_NAME_TO_TYPE_MAP, DOPC_BEAD_TYPE_NAME_IDS, DOPC_ELEMENT_TYPE_NAME_IDS
from Bio.PDB.PDBIO import Select
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB.PDBParser import PDBParser
from library.static.vector_mappings import DOPC_AT_MAPPING, DOPC_CG_MAPPING
import matplotlib.pyplot as plt
import random

PADDING_X = 1  # Zero-padding rows before and after the atom/bead axis (first matrix axis of a sample)
PADDING_Y = 1  # Zero-padding columns before and after the x/y/z axis (second matrix axis of a sample)

AVERAGE_SIMULATION_BOX_SIZE = (191.05054545, 191.05054545, 111.67836364)  # Average box size (x, y, z); meant as fallback when a box size cannot be found in the csv file. In this file it is the default box of fix_pbc
PBC_CUTOFF = 10.0  # If a bond is longer than this, it is considered to be on the other side of the simulation box (also the default per-axis cutoff of fix_pbc)
BOX_SCALE_FACTOR = 10.0  # Relative bond vectors are divided by this factor before they are written to the matrix
CG_BOUNDING_BOX_RELATION_PATH = "data/box_sizes_cg.csv"  # Path to the csv file that contains the bounding box sizes for the cg residues (one residue per line)
AT_BOUNDING_BOX_RELATION_PATH = "data/box_sizes_at.csv"  # Same for the atomistic residues. This is generated by the generate_molecule_data_fast.py script
NEIGHBOURHOOD_PATH = "data/neighborhoods.csv"  # Path to the csv file with the neighbour residue indices. This is generated by the generate_molecule_data_fast.py script
ABSOLUTE_POSITION_EXTRA_SCALE = 1.1  # Extra factor to scale the box around absolute positions so that they stay inside [-1, 1]. NOTE(review): only referenced in a commented-out expression in the visible code
ABSOLUT_POSITION_SCALE = 200.0      # Absolute positions are divided by this factor before they are written to the matrix


def fix_pbc(vector, box_size=None, cutoff=None):
    """
    This function fixes the periodic boundary conditions of a vector. Every component that lies outside of
    [-cutoff, cutoff] is shifted by one box length towards zero. This is repeated until all components are
    inside of the cutoff.

    The vector is modified in place and also returned.

    Args:
        vector (vector): The vector that should be fixed. Must support item access and item assignment for the indices 0, 1 and 2.
        box_size (vector, optional): The box size of the simulation. Defaults to AVERAGE_SIMULATION_BOX_SIZE.
        cutoff (list, optional): Cutoff per axis to apply the PBC fix to. Defaults to [PBC_CUTOFF, PBC_CUTOFF, PBC_CUTOFF].

    Returns:
        vector: The fixed vector (the same object that was passed in).

    Note:
        A component that can not be brought into [-cutoff, cutoff] by shifting it by whole box lengths
        keeps recursing until Python raises a RecursionError.
    """
    # Resolve the defaults at call time (no mutable default argument)
    if box_size is None:
        box_size = AVERAGE_SIMULATION_BOX_SIZE
    if cutoff is None:
        cutoff = (PBC_CUTOFF, PBC_CUTOFF, PBC_CUTOFF)

    # Shift every axis that is outside of the cutoff by one box length
    for axis in range(3):
        if vector[axis] > cutoff[axis]:
            vector[axis] -= box_size[axis]
        elif vector[axis] < -cutoff[axis]:
            vector[axis] += box_size[axis]

    # One shift may not be enough. The cutoff of the caller has to be passed on,
    # otherwise the following passes would silently use the default cutoff.
    if any(vector[axis] > cutoff[axis] or vector[axis] < -cutoff[axis] for axis in range(3)):
        return fix_pbc(vector, box_size, cutoff)

    return vector


def is_output_matrix_healthy(output):
    """
        Diagnostic helper that tells if a matrix fulfils the requirements of the network.
        A matrix is healthy if:
            - all values are inside of [-1, 1]
            - it contains neither nan nor inf

        Returns True if the matrix is healthy, otherwise False.
    """
    # Values outside of [-1, 1]
    exceeds_range = bool(np.max(output) > 1 or np.min(output) < -1)

    # nan or inf
    has_non_finite = bool(np.isnan(output).any() or np.isinf(output).any())

    return not (exceeds_range or has_non_finite)


def add_relative_vectors(atom_pos_dict, atom_mapping, output_matrix, batch_index, box):
    """
    This functions calculates the relative vectors between the atom pairs in the atom mapping and writes them,
    divided by BOX_SCALE_FACTOR, to the output matrix. Additionally, it fixes PBC and checks for consistency.

    Args:
        atom_pos_dict (dict): Maps the atom name to the atom object. The atoms have to provide get_vector().
        atom_mapping (list): List of (name1, name2) tuples. The vector of pair j points from name1 to name2.
        output_matrix (array): The matrix where the relative vectors are written to. Shape: (batch_size, i, j, 1).
            The vector of pair j is written to row j + PADDING_X, columns PADDING_Y .. PADDING_Y + 2.
        batch_index (int): The index of the sample in the output matrix.
        box (vector): The box size (x, y, z) that is used to fix the PBC.

    Returns:
        array: The output matrix (modified in place).

    Raises:
        Exception: If an atom of the mapping is missing, if a vector is still longer than PBC_CUTOFF after the
            PBC fix or if a vector contains nan or inf.
    """
    for j, (at1_name, at2_name) in enumerate(atom_mapping):
        first_atom = atom_pos_dict.get(at1_name)
        second_atom = atom_pos_dict.get(at2_name)

        # Compare with None instead of relying on the truthiness of the atom objects
        if first_atom is None or second_atom is None:
            raise Exception(f"Missing {at1_name} or {at2_name} in residue")

        # Calculate the relative vector
        rel_vector = second_atom.get_vector() - first_atom.get_vector()

        # Fix the PBC
        if rel_vector.norm() > PBC_CUTOFF:
            rel_vector = fix_pbc(rel_vector, box)

        # Consistency check
        if rel_vector.norm() > PBC_CUTOFF:
            raise Exception(f"Found a vector that is too large ({rel_vector})!")

        # Check if nan or inf is in the vector
        if not all(np.isfinite(rel_vector[k]) for k in range(3)):
            raise Exception(f"Found nan or inf in vector ({rel_vector})!")

        # Write the relative vector to the output matrix
        for k in range(3):
            output_matrix[batch_index, j + PADDING_X, PADDING_Y + k, 0] = rel_vector[k] / BOX_SCALE_FACTOR

    return output_matrix


def add_absolute_positions(atoms, output_matrix, batch_index, box, target_atom=None, preset_position_origin=None):
    """
    This functions calculates the positions of the atoms relative to a position origin and writes them,
    divided by ABSOLUT_POSITION_SCALE, to the output matrix. Hydrogen atoms are ignored.

    Args:
        atoms (list): The atoms (or beads) of the residue. They have to provide element, get_name() and get_vector().
        output_matrix (array): The matrix where the positions are written to. Shape: (batch_size, i, j, 1).
            The position of atom j is written to row j + PADDING_X, columns PADDING_Y .. PADDING_Y + 2.
        batch_index (int): The index of the sample in the output matrix.
        box (vector): The box size (x, y, z) that is used to fix the PBC.
        target_atom (atom, optional): The atom that should be fitted. If this is None, all atoms will be written.
            If a target is set, only the atom with the same name will be in the output matrix.
        preset_position_origin (vector, optional): The position origin. If this is None, the position of the first
            N atom or NC3 bead is used and these atoms are left out of the output matrix.

    Returns:
        array: The output matrix (modified in place).

    Raises:
        IndexError: If no origin is preset and the residue has neither a N atom nor a NC3 bead.
        Exception: If a target is set and not exactly one atom has its name or if a position contains nan or inf.
    """
    # Filter out hydrogen atoms
    atoms = [atom for atom in atoms if atom.element != "H"]

    if preset_position_origin is not None:
        # The origin is given, so no anchor atom is needed and all atoms are kept
        position_origin = preset_position_origin
    else:
        # Everything is relative to the N atom or NC3 bead
        anchors = [atom for atom in atoms if atom.get_name() in ("N", "NC3")]
        if not anchors:
            raise IndexError("Found no N atom or NC3 bead that can be used as position origin!")
        position_origin = anchors[0].get_vector()

        # Remove the N atom or NC3 bead because their position is the base position (0,0,0)
        atoms = [atom for atom in atoms if atom.get_name() not in ("N", "NC3")]

    if target_atom is not None:
        # If only one atom is selected we need to filter out all other atoms
        atoms = [atom for atom in atoms if atom.get_name() == target_atom.get_name()]

        if len(atoms) != 1:
            raise Exception(f"Found {len(atoms)} atoms with the name {target_atom.get_name()}!")

    # The cutoff only depends on the box, so it is calculated once
    pbc_cutoff = [box[0] / 1.5, box[1] / 1.5, box[2] / 1.5]  # TODO: Maybe find better cutoff approximations

    for j, atom in enumerate(atoms):
        # Calculate the position relative to the origin
        position = atom.get_vector() - position_origin

        # Fix the PBC
        position = fix_pbc(position, box, cutoff=pbc_cutoff)

        # Check if nan or inf is in the vector
        if not all(np.isfinite(position[k]) for k in range(3)):
            raise Exception(f"Found nan or inf in vector ({position})!")

        # Write the position to the output matrix
        for k in range(3):
            output_matrix[batch_index, j + PADDING_X, PADDING_Y + k, 0] = position[k] / ABSOLUT_POSITION_SCALE

    return output_matrix


def get_bounding_box(dataset_idx: int, path_to_csv: str) -> np.ndarray:
    """
    This function returns the bounding box of a residue in a dataset. The bounding box is the size of the simulation box in which the residue is located.

    Args:
        dataset_idx (int): The dataset index. The box is read from line dataset_idx + 1 of the csv file.
        path_to_csv (str): The path to the csv file that contains the bounding box sizes for the residues.

    Returns:
        array: The comma separated values of the line as a numpy array of floats.

    Raises:
        ValueError: If the csv file has no (non-empty) line for the dataset index or if a value is not a number.
    """
    # Get the line by using the cached line access. linecache returns an empty
    # string if the file or the line does not exist.
    line = linecache.getline(path_to_csv, dataset_idx + 1).strip()
    if not line:
        raise ValueError(f"Found no bounding box for dataset {dataset_idx} in {path_to_csv}!")

    # Return the bounding box as a numpy array
    return np.array([float(value) for value in line.split(",")])


def get_neighbour_residues(dataset_idx: int, path_to_csv: str) -> np.ndarray:
    """
    This function returns the neighbour residues of a residue in a dataset.

    Args:
        dataset_idx (int): The dataset index. The neighbours are read from line dataset_idx + 1 of the csv file.
        path_to_csv (str): The path to the csv file that contains the neighbour residues for the residues.

    Returns:
        array: Integer array with the neighbour residue indices. The array is empty (but still of integer type)
            if the line is empty or does not exist.
    """
    # Get the line by using the cached line access. Stripping also removes "\r" of windows line endings.
    line = linecache.getline(path_to_csv, dataset_idx + 1).strip()
    if not line:
        # Use an integer dtype so that the empty result can be used like a non-empty one (e.g. as index)
        return np.array([], dtype=int)

    # Return the neighbour indices as a numpy array
    return np.array([int(value) for value in line.split(",")], dtype=int)


def print_matrix(matrix):
    """
    Print every sample of an input/output matrix as an ascii table.

    Each sample gets a header with the indices of the first matrix axis, a dashed rule, one text line
    per index of the second matrix axis and a closing dashed rule. Values are printed with four decimals.

    Args:
        matrix (list): List with shape (batch_size, i, j, 1)
    """
    for sample in range(matrix.shape[0]):
        column_count = matrix.shape[1]
        rule = column_count * 8 * "-"

        # Header with the column indices
        print(" " + "".join(f"{column:6d}  " for column in range(column_count)))
        print(rule)

        # One text line per entry of the second matrix axis
        for row in range(matrix.shape[2]):
            cells = []
            for column in range(column_count):
                value = matrix[sample, column, row, 0]
                # Non-negative values get a blank so that they line up with negative ones
                sign_padding = " " if value >= 0 else ""
                cells.append(f"{sign_padding}{value:.4f} ")
            print("".join(cells))

        print(rule)


class BackmappingBaseGenerator(tf.keras.utils.Sequence):
    """
        This is the base class for the backmapping data generator.
        It is used to generate batches of data for the CNN. The children classes specify the structure of the data.
        For example, the RelativeVectorsTrainingDataGenerator generates batches of relative bond vectors.

        The base class only stores the configuration and splits the residue indices into a training and a
        validation range. Children have to implement __getitem__.
    """

    def __init__(
        self,
        input_dir_path: str,
        output_dir_path: str,
        input_size: tuple = (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        output_size: tuple = (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        shuffle: bool = False,
        batch_size: int = 1,
        validate_split: float = 0.1,
        validation_mode: bool = False,
        augmentation: bool = False,
    ):
        """
        This is the base class for the backmapping data generator.

        Args:
            input_dir_path (str): The path to the directory where the input data (X) is located
            output_dir_path (str): The path to the directory where the output data (Y) is located
            input_size (tuple, optional): The size/shape of the input data. Defaults to (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            output_size (tuple, optional): The size/shape of the output data. Defaults to (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
                The value is only stored, this class does not shuffle anything.
            batch_size (int, optional): The size of each batch. Defaults to 1.
            validate_split (float, optional): The fraction of data that should be used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data. Defaults to False.
        """
        # Initialize the keras base class first
        super().__init__()

        self.input_dir_path = input_dir_path
        self.output_dir_path = output_dir_path
        self.input_size = input_size
        self.output_size = output_size
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.validation_mode = validation_mode
        self.augmentation = augmentation

        self.parser = PDBParser(QUIET=True)

        # Highest index is the number of directory entries minus one
        # NOTE(review): assumes the input directory only contains files named 0.pdb, 1.pdb, ... - confirm
        max_index = len(os.listdir(input_dir_path)) - 1

        # Training mode uses the first part of the indices
        self.len = int(max_index * (1 - validate_split))
        self.start_index = 0
        self.end_index = self.len

        # Validation mode uses the indices behind the training part
        if self.validation_mode:
            self.start_index = self.len + 1
            self.len = int(max_index * validate_split)
            self.end_index = max_index

        # Debug
        print(f"Found {self.len} residues in ({self.input_dir_path})")

        # Initialize
        self.on_epoch_end()

    def on_epoch_end(self):
        """Hook that is called at the end of each epoch (and once by the constructor). Does nothing here."""
        pass

    def __len__(self):
        """Return the number of full batches."""
        return self.len // self.batch_size

    def __getitem__(self, idx):
        """Return the batch with the index idx. Has to be implemented by the children classes."""
        raise NotImplementedError("This is an abstract class and should not be used directly!")


class RelativeVectorsTrainingDataGenerator(BackmappingBaseGenerator):
    """
        A data generator class that generates batches of data for the CNN.
        The input and output is the relative vectors between the atoms and the beads.
    """

    def __init__(
        self,
        input_dir_path,
        output_dir_path,
        input_size=(11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        output_size=(53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        shuffle=False,
        batch_size=1,
        validate_split=0.1,
        validation_mode=False,
        augmentation=False,
    ):
        """
        Args:
            input_dir_path (str): The path to the directory where the input data (X) is located
            output_dir_path (str): The path to the directory where the output data (Y) is located
            input_size (tuple, optional): The size/shape of the input data. Defaults to (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            output_size (tuple, optional): The size/shape of the output data. Defaults to (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
            batch_size (int, optional): The size of each batch. Defaults to 1.
            validate_split (float, optional): The percentage of data that should be used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data. Defaults to False.
        """
        # Call super constructor
        super().__init__(
            input_dir_path,
            output_dir_path,
            input_size,
            output_size,
            shuffle,
            batch_size,
            validate_split,
            validation_mode,
            augmentation,
        )

    def __getitem__(self, idx):
        """
        Build the batch with the index idx.

        Args:
            idx (int): The index of the batch.

        Returns:
            tuple: (X, Y) as float32 tensors. X holds the scaled relative vectors of the beads (DOPC_CG_MAPPING)
                and Y the scaled relative vectors of the atoms (DOPC_AT_MAPPING).

        Raises:
            Exception: If a residue is outside of the validation range, a box size is nan or a matrix is not healthy.
        """
        # Initialize Batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)
        Y = np.zeros((self.batch_size, *self.output_size), dtype=np.float32)

        # Loop over the batch
        for i in range(self.batch_size):
            # Get the index of the residue. The start index is 0 in training mode. In validation mode
            # it moves the batch behind the training part, otherwise training residues would be validated.
            residue_idx = self.start_index + idx * self.batch_size + i

            # Check if end index is reached
            if self.validation_mode and residue_idx > self.end_index:
                raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

            # Get the path to the files
            cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"
            at_path = f"{self.output_dir_path}/{residue_idx}.pdb"

            # Load the files
            cg_structure = self.parser.get_structure(residue_idx, cg_path)
            at_structure = self.parser.get_structure(residue_idx, at_path)

            # Get the residues (each file is expected to hold the residue of interest first)
            cg_residue = list(cg_structure.get_residues())[0]
            at_residue = list(at_structure.get_residues())[0]

            # Make name -> atom dict
            cg_atoms_dict = {atom.get_name(): atom for atom in cg_residue.get_atoms()}
            at_atoms_dict = {atom.get_name(): atom for atom in at_residue.get_atoms()}

            # Get bounding box sizes
            cg_box_size = get_bounding_box(residue_idx, CG_BOUNDING_BOX_RELATION_PATH)
            at_box_size = get_bounding_box(residue_idx, AT_BOUNDING_BOX_RELATION_PATH)

            # Check if the box sizes are not nan
            if np.isnan(cg_box_size).any() or np.isnan(at_box_size).any():
                raise Exception(f"Found nan in box sizes ({cg_box_size}, {at_box_size}) in residue {residue_idx}!")

            # Make the relative vectors out of a vector mapping
            X = add_relative_vectors(cg_atoms_dict, DOPC_CG_MAPPING, X, i, cg_box_size)
            Y = add_relative_vectors(at_atoms_dict, DOPC_AT_MAPPING, Y, i, at_box_size)

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)
        Y = tf.convert_to_tensor(Y, dtype=tf.float32)

        # Check if values that are not in [-1, 1] (or nan/inf) exist
        if not is_output_matrix_healthy(Y) or not is_output_matrix_healthy(X):
            raise Exception(f"Found values outside of [-1, 1], see print before.")

        # Return tensor as deep copy
        return tf.identity(X), tf.identity(Y)


class AbsolutePositionsGenerator(BackmappingBaseGenerator):
    """
        A data generator class that generates batches of data for the CNN.
        The input and output is the absolute positions of the atoms and beads
        (relative to the N atom or NC3 bead of the residue).
    """

    def __init__(
        self,
        input_dir_path,
        output_dir_path,
        input_size=(12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more bead position than relativ vectors
        output_size=(54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more atom position than relativ vectors
        shuffle=False,
        batch_size=1,
        validate_split=0.1,
        validation_mode=False,
        augmentation=False,
        only_fit_one_atom=False,
        atom_name=None,
    ):
        """
        Args:
            input_dir_path (str): The path to the directory where the input data (X) is located
            output_dir_path (str): The path to the directory where the output data (Y) is located
            input_size (tuple, optional): The size/shape of the input data. Defaults to (12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            output_size (tuple, optional): The size/shape of the output data. Defaults to (54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
                It is not used if only one atom is fitted.
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
            batch_size (int, optional): The size of each batch. Defaults to 1.
            validate_split (float, optional): The percentage of data that should be used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data (random rotation). Defaults to False.
            only_fit_one_atom (bool, optional): If only one atom should be fitted. Defaults to False.
            atom_name (str, optional): The name of the atom that should be fitted. Defaults to None.
        """
        # Call super constructor
        super().__init__(
            input_dir_path,
            output_dir_path,
            input_size,
            output_size,
            shuffle,
            batch_size,
            validate_split,
            validation_mode,
            augmentation,
        )

        self.only_fit_one_atom = only_fit_one_atom
        self.atom_name = atom_name

    def __getitem__(self, idx):
        """Return the batch with the index idx, either for one atom or for all atoms."""
        return self.__getitem_one_atom__(idx) if self.only_fit_one_atom else self.__getitem_all_atoms__(idx)

    def _residue_index(self, idx, i):
        """Return the residue index of sample i in batch idx. The start index moves validation batches behind the training part."""
        return self.start_index + idx * self.batch_size + i

    def _load_sample(self, residue_idx):
        """
        Load the beads, the atoms and both box sizes of one residue.

        Returns:
            tuple: (cg_atoms, at_atoms, cg_box_size, at_box_size)

        Raises:
            Exception: If the residue is outside of the validation range or a box size is nan.
        """
        # Check if end index is reached
        if self.validation_mode and residue_idx > self.end_index:
            raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

        # Get the path to the files
        cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"
        at_path = f"{self.output_dir_path}/{residue_idx}.pdb"

        # Load the files
        cg_structure = self.parser.get_structure(residue_idx, cg_path)
        at_structure = self.parser.get_structure(residue_idx, at_path)

        # Get the atoms of the first residue of each file
        cg_atoms = list(list(cg_structure.get_residues())[0].get_atoms())
        at_atoms = list(list(at_structure.get_residues())[0].get_atoms())

        # Get bounding box sizes
        cg_box_size = get_bounding_box(residue_idx, CG_BOUNDING_BOX_RELATION_PATH)
        at_box_size = get_bounding_box(residue_idx, AT_BOUNDING_BOX_RELATION_PATH)

        # Check if the box sizes are not nan
        if np.isnan(cg_box_size).any() or np.isnan(at_box_size).any():
            raise Exception(f"Found nan in box sizes ({cg_box_size}, {at_box_size}) in residue {residue_idx}!")

        return cg_atoms, at_atoms, cg_box_size, at_box_size

    def _rotate_randomly(self, X, Y):
        """Rotate every sample of X and Y in place. Both get the same random rotation, a new one is drawn per sample."""
        for i in range(self.batch_size):
            angle_x = random.uniform(-np.pi, np.pi)
            angle_y = random.uniform(-np.pi, np.pi)
            angle_z = random.uniform(-np.pi, np.pi)

            rotation_x = np.array([[1, 0, 0], [0, np.cos(angle_x), -np.sin(angle_x)], [0, np.sin(angle_x), np.cos(angle_x)]])
            rotation_y = np.array([[np.cos(angle_y), 0, np.sin(angle_y)], [0, 1, 0], [-np.sin(angle_y), 0, np.cos(angle_y)]])
            rotation_z = np.array([[np.cos(angle_z), -np.sin(angle_z), 0], [np.sin(angle_z), np.cos(angle_z), 0], [0, 0, 1]])

            # Rotate around x first, then y, then z
            rotation = rotation_z @ rotation_y @ rotation_x

            for matrix in (X, Y):
                # The three coordinate columns of all rows (padding rows are zero and stay zero)
                coordinates = matrix[i, :, PADDING_Y:PADDING_Y + 3, 0]
                # Every row is a vector v, so "rotation @ v" for all rows is "coordinates @ rotation.T"
                matrix[i, :, PADDING_Y:PADDING_Y + 3, 0] = coordinates @ rotation.T

    @staticmethod
    def _print_first_unhealthy_sample(matrix):
        """Print the index and the content of the first sample that is not healthy."""
        for i in range(matrix.shape[0]):
            if not is_output_matrix_healthy(matrix[i:i+1, :, :, :]):
                print(i)
                print_matrix(matrix[i:i+1, :, :, :])
                break

    def _finalize_batch(self, X, Y):
        """
        Convert X and Y to tensors and check them.

        Raises:
            Exception: If X or Y contain nan, inf or values outside of [-1, 1]. The first unhealthy sample is printed before.
        """
        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)
        Y = tf.convert_to_tensor(Y, dtype=tf.float32)

        # Check if values that are not in [-1, 1] (or nan/inf) exist
        x_healthy = is_output_matrix_healthy(X)
        y_healthy = is_output_matrix_healthy(Y)
        if not (x_healthy and y_healthy):
            if not y_healthy:
                self._print_first_unhealthy_sample(Y)
            if not x_healthy:
                self._print_first_unhealthy_sample(X)
            raise Exception(f"Found values outside of [-1, 1], see print before.")

        # Return tensor as deep copy
        return tf.identity(X), tf.identity(Y)

    def __getitem_one_atom__(self, idx):
        """
        Build the batch with the index idx. Y only contains the position of the atom with the name self.atom_name.

        Returns:
            tuple: (X, Y) as float32 tensors. Y has the shape (batch_size, 1 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).

        Raises:
            Exception: If no atom name is set or if a residue does not have exactly one atom with this name.
        """
        if not self.atom_name:
            raise Exception("You need to specify an atom name!")

        # Initialize Batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)
        Y = np.zeros((self.batch_size, 1 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1), dtype=np.float32)

        # Loop over the batch
        for i in range(self.batch_size):
            residue_idx = self._residue_index(idx, i)
            cg_atoms, at_atoms, cg_box_size, at_box_size = self._load_sample(residue_idx)

            # Filter out the atom that should be fitted
            at_atoms_to_fit = [atom for atom in at_atoms if atom.get_name() == self.atom_name]

            if len(at_atoms_to_fit) != 1:
                # Report the number of matching atoms, not the number of all atoms
                raise Exception(f"Found {len(at_atoms_to_fit)} atoms with the name {self.atom_name} in residue {residue_idx}!")

            # Make the absolute positions
            X = add_absolute_positions(cg_atoms, X, i, cg_box_size)
            Y = add_absolute_positions(at_atoms, Y, i, at_box_size, target_atom=at_atoms_to_fit[0])

        # Augment the data
        if self.augmentation:
            self._rotate_randomly(X, Y)

        return self._finalize_batch(X, Y)

    def __getitem_all_atoms__(self, idx):
        """
        Build the batch with the index idx. Y contains the positions of all atoms.

        Returns:
            tuple: (X, Y) as float32 tensors with the shapes (batch_size, *input_size) and (batch_size, *output_size).
        """
        # Initialize Batch. Y needs the full output shape, a 1-D array can not take the positions.
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)
        Y = np.zeros((self.batch_size, *self.output_size), dtype=np.float32)

        # Loop over the batch
        for i in range(self.batch_size):
            residue_idx = self._residue_index(idx, i)
            cg_atoms, at_atoms, cg_box_size, at_box_size = self._load_sample(residue_idx)

            # Make the absolute positions
            X = add_absolute_positions(cg_atoms, X, i, cg_box_size)
            Y = add_absolute_positions(at_atoms, Y, i, at_box_size)

        # Augment the data
        if self.augmentation:
            self._rotate_randomly(X, Y)

        return self._finalize_batch(X, Y)


class AbsolutePositionsNeigbourhoodGenerator(BackmappingBaseGenerator):
    """
        A data generator class that generates batches of data for the CNN.
        The input and output are the absolute positions of the atoms and beads, and the
        input additionally contains the beads of the neighbourhood of the molecule that should be fitted.
    """

    def __init__(
        self,
        input_dir_path,
        output_dir_path,
        input_size=(12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more bead position than relativ vectors
        output_size=(54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more atom position than relativ vectors
        shuffle=False,
        batch_size=1,
        validate_split=0.1,
        validation_mode=False,
        augmentation=False,
        only_fit_one_atom=False,
        atom_name=None,
        neighbourhood_size=4,
    ):
        """
        Args:
            input_dir_path (str): The path to the directory where the input data (X) is located
            output_dir_path (str): The path to the directory where the output data (Y) is located
            input_size (tuple, optional): The size/shape of the input data. Defaults to (12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            output_size (tuple, optional): The size/shape of the output data. Defaults to (54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
            batch_size (int, optional): The size of each batch. Defaults to 1.
            validate_split (float, optional): The percentage of data that should be used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data. Defaults to False.
            only_fit_one_atom (bool, optional): If only one atom should be fitted. Defaults to False.
            atom_name (str, optional): The name of the atom that should be fitted. Defaults to None.
            neighbourhood_size (int, optional): The maximum number of neighbours to add to the input. Defaults to 4.
        """
        # NOTE(review): __getitem_one_atom__ writes neighbour j into the rows starting at
        # PADDING_X + 12 * (j + 1). The default input_size only has rows for the 12 beads of the
        # central residue, so input_size presumably has to be enlarged by the caller when
        # neighbours are used - confirm.

        # Call super constructor
        super().__init__(
            input_dir_path,
            output_dir_path,
            input_size,
            output_size,
            shuffle,
            batch_size,
            validate_split,
            validation_mode,
            augmentation,
        )

        self.only_fit_one_atom = only_fit_one_atom
        self.atom_name = atom_name
        self.neighbourhood_size = neighbourhood_size

    def __getitem__(self, idx):
        """Return batch `idx`, dispatching on `only_fit_one_atom`."""
        return self.__getitem_one_atom__(idx) if self.only_fit_one_atom else self.__getitem_all_atoms__(idx)

    @staticmethod
    def _random_rotation_matrix():
        """Draw three random angles (x, y, z order) and return the combined rotation Rz @ Ry @ Rx."""
        angle_x = random.uniform(-np.pi, np.pi)
        angle_y = random.uniform(-np.pi, np.pi)
        angle_z = random.uniform(-np.pi, np.pi)

        rot_x = np.array([[1, 0, 0], [0, np.cos(angle_x), -np.sin(angle_x)], [0, np.sin(angle_x), np.cos(angle_x)]])
        rot_y = np.array([[np.cos(angle_y), 0, np.sin(angle_y)], [0, 1, 0], [-np.sin(angle_y), 0, np.cos(angle_y)]])
        rot_z = np.array([[np.cos(angle_z), -np.sin(angle_z), 0], [np.sin(angle_z), np.cos(angle_z), 0], [0, 0, 1]])

        # Applying x first, then y, then z to a column vector equals one multiplication with Rz @ Ry @ Rx
        return rot_z @ rot_y @ rot_x

    @staticmethod
    def _dump_first_unhealthy_entry(matrix):
        """Print the index and content of the first batch entry that fails the health check."""
        for i in range(matrix.shape[0]):
            entry = matrix[i:i + 1, :, :, :]
            if not is_output_matrix_healthy(entry):
                print(i)
                print_matrix(entry)
                break

    @classmethod
    def _raise_if_unhealthy(cls, X, Y):
        """Raise (after printing the offending batch entries) if X or Y fails is_output_matrix_healthy."""
        x_healthy = is_output_matrix_healthy(X)
        y_healthy = is_output_matrix_healthy(Y)
        if x_healthy and y_healthy:
            return

        if not y_healthy:
            cls._dump_first_unhealthy_entry(Y)
        if not x_healthy:
            cls._dump_first_unhealthy_entry(X)
        raise Exception("Found values outside of [-1, 1], see print before.")

    def __getitem_one_atom__(self, idx):
        """
        Build batch `idx` for fitting the single atom named `self.atom_name`.

        Args:
            idx (int): The batch index. Residue `idx * batch_size + i` is loaded for batch entry i.

        Returns:
            tuple: (X, Y) as float32 tensors. X has shape (batch_size, *input_size) and holds the beads of the
            residue followed by the beads of up to `neighbourhood_size` neighbours. Y has shape
            (batch_size, 1 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1) and holds the atom to fit.

        Raises:
            Exception: If no atom name is set, the residue index exceeds the end index in validation mode,
                the atom name does not match exactly one atom, a box size is nan, or the
                final matrices fail the health check.
        """
        # The atom name does not depend on the batch entry, so fail before any file is read
        if not self.atom_name:
            raise Exception("You need to specify an atom name!")

        beads_per_residue = 12  # Number of bead rows every residue occupies in X

        # Initialize Batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)     # (batch, beads + neighbourhood beads, (x, y, z), 1)
        Y = np.zeros((self.batch_size, 1 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1), dtype=np.float32)

        # Loop over the batch
        for i in range(self.batch_size):
            # Get the index of the residue
            residue_idx = idx * self.batch_size + i

            # Check if end index is reached
            if self.validation_mode and residue_idx > self.end_index:
                raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

            # Get the path to the files
            cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"
            at_path = f"{self.output_dir_path}/{residue_idx}.pdb"

            # Load the files
            cg_structure = self.parser.get_structure(residue_idx, cg_path)
            at_structure = self.parser.get_structure(residue_idx, at_path)

            # Only the first residue of each structure is used
            cg_residue = list(cg_structure.get_residues())[0]
            at_residue = list(at_structure.get_residues())[0]

            # Get the atoms
            cg_atoms = list(cg_residue.get_atoms())
            at_atoms = list(at_residue.get_atoms())

            # Filter out the atom that should be fitted
            at_atoms_to_fit = [atom for atom in at_atoms if atom.get_name() == self.atom_name]

            if len(at_atoms_to_fit) != 1:
                # Report the number of matches, not the total number of atoms in the residue
                raise Exception(f"Found {len(at_atoms_to_fit)} atoms with the name {self.atom_name} in residue {residue_idx}!")

            # Get bounding box sizes
            cg_box_size = get_bounding_box(residue_idx, CG_BOUNDING_BOX_RELATION_PATH)
            at_box_size = get_bounding_box(residue_idx, AT_BOUNDING_BOX_RELATION_PATH)

            # Check if the box sizes are not nan
            if np.isnan(cg_box_size).any() or np.isnan(at_box_size).any():
                raise Exception(f"Found nan in box sizes ({cg_box_size}, {at_box_size}) in residue {residue_idx}!")

            # Get the neighbourhood
            neighbourhood = get_neighbour_residues(residue_idx, NEIGHBOURHOOD_PATH)

            # Make the absolute positions out of a vector mapping
            X = add_absolute_positions(cg_atoms, X, i, cg_box_size)
            Y = add_absolute_positions(at_atoms, Y, i, at_box_size, target_atom=at_atoms_to_fit[0])

            # Add the neighbourhood (at most neighbourhood_size residues)
            neighbour_count = min(self.neighbourhood_size, len(neighbourhood))
            if neighbour_count > 0:
                # All neighbours are expressed relative to the NC3 bead of the central residue
                central_origin = [atom.get_vector() for atom in cg_atoms if atom.get_name() == "NC3"][0]

            for j in range(neighbour_count):
                neighbor_X = self.get_neighbor_X(
                    residue_idx=neighbourhood[j],
                    box_size=cg_box_size,
                    position_origin=central_origin,
                )[0, PADDING_X: beads_per_residue + PADDING_X, PADDING_Y:-PADDING_Y, 0]

                # Neighbour j is stored directly behind the central residue and the previous neighbours
                first_row = PADDING_X + beads_per_residue * (j + 1)
                X[i, first_row: first_row + beads_per_residue, PADDING_Y:-PADDING_Y, 0] = neighbor_X

        # Augment the data
        if self.augmentation:
            # Randomly rotate each dataset, using the same rotation for input and output
            for i in range(self.batch_size):
                rotation = self._random_rotation_matrix()

                for matrix in (X, Y):
                    # Every row is one (x, y, z) position; rows @ R.T rotates all of them at once
                    positions = matrix[i, :, PADDING_Y:-PADDING_Y, 0]
                    matrix[i, :, PADDING_Y:-PADDING_Y, 0] = positions @ rotation.T

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)
        Y = tf.convert_to_tensor(Y, dtype=tf.float32)

        # Check if values that are not in [-1, 1] exist
        self._raise_if_unhealthy(X, Y)

        # Return tensor as deep copy
        return tf.identity(X), tf.identity(Y)

    def get_neighbor_X(self, residue_idx, box_size, position_origin):
        """
        Load the beads of a neighbouring residue, expressed relative to a given origin.

        Args:
            residue_idx (int): The index of the neighbouring residue.
            box_size (array): The box size that is passed on to add_absolute_positions.
            position_origin (vector): The position the bead positions are made relative to.

        Returns:
            tensor: A float32 tensor of shape (batch_size, *input_size). Only batch entry 0 is filled.

        Raises:
            Exception: If the generator is in validation mode and the residue index exceeds the end index.
        """
        # Initialize Batch (same shape as a regular input batch, only entry 0 is written)
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)

        # Check if end index is reached
        if self.validation_mode and residue_idx > self.end_index:
            raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

        # Get the path to the files
        cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"

        # Load the files
        cg_structure = self.parser.get_structure(residue_idx, cg_path)

        # Only the first residue of the structure is used
        cg_residue = list(cg_structure.get_residues())[0]

        # Get the atoms
        cg_atoms = list(cg_residue.get_atoms())

        # Make the absolute positions out of a vector mapping
        X = add_absolute_positions(cg_atoms, X, 0, box_size, preset_position_origin=position_origin)

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)

        # Return tensor as deep copy
        return tf.identity(X)

    def __getitem_all_atoms__(self, idx):
        """Fitting all atoms at once is not supported by this generator yet."""
        # NotImplementedError is a subclass of Exception, so existing handlers keep working
        raise NotImplementedError("Not implemented yet!")

Functions

def add_absolute_positions(atoms, output_matrix, batch_index, box, target_atom=None, preset_position_origin=None)

This function calculates the positions of the atoms relative to an origin (the N atom / NC3 bead, or a preset position) and writes them to the output matrix.

Args

atoms : list
The atoms or beads of one residue. Hydrogen atoms are ignored.
output_matrix : array
The output matrix is the matrix where the positions should be written to. Shape: (batch_size, i, j, 1)
batch_index : int
The batch index is the index of the batch in the output matrix.
box : array
The box size that is used to fix the periodic boundary conditions.
target_atom : atom, optional
The target is the atom that should be fitted. If this is None, all atoms will be fitted. If a target is set, only this atom will be in the output matrix.
preset_position_origin : vector, optional
The position that all positions are made relative to. If this is None, the position of the N atom or NC3 bead will be used.
Expand source code
def add_absolute_positions(atoms, output_matrix, batch_index, box, target_atom=None, preset_position_origin=None):
    """
    This function calculates the positions of the atoms relative to an origin and writes them to the output matrix.

    Hydrogen atoms are skipped. Every remaining position is made relative to the origin, passed through
    fix_pbc and divided by ABSOLUT_POSITION_SCALE before it is stored.

    Args:
        atoms (list): The atoms or beads of one residue. Each has to provide `element`, `get_name()` and `get_vector()`.
        output_matrix (array): The matrix the positions are written to (in place). Shape: (batch_size, i, j, 1)
        batch_index (int): The batch index is the index of the batch in the output matrix.
        box (array): The box size (x, y, z) that is used to fix the periodic boundary conditions.
        target_atom (atom, optional): The atom that should be fitted. If this is None, all atoms are written.
            If a target is set, only the atom with the same name is written (to the first non-padding row).
        preset_position_origin (vector, optional): The origin all positions are made relative to. If this is None,
            the position of the N atom or NC3 bead is used and that atom/bead is left out of the output,
            because its relative position would be (0, 0, 0).

    Returns:
        array: The output matrix (the same object that was passed in).

    Raises:
        IndexError: If no origin is preset and the residue has no N atom or NC3 bead.
        Exception: If a target atom is set but not found exactly once, or if a position contains nan or inf.
    """
    anchor_names = ("N", "NC3")

    # Filter out hydrogen atoms
    atoms = [atom for atom in atoms if atom.element != "H"]

    if preset_position_origin is not None:
        # An external origin is given, so the N atom / NC3 bead is an ordinary position and stays in the list
        position_origin = preset_position_origin
    else:
        # Everything is relative to the N atom or NC3 bead
        position_origin = [atom for atom in atoms if atom.get_name() in anchor_names][0].get_vector()
        # Remove the N atom or NC3 bead because their position is the base position (0,0,0)
        atoms = [atom for atom in atoms if atom.get_name() not in anchor_names]

    if target_atom:
        # If only one atom is selected we need to filter out all other atoms
        atoms = [atom for atom in atoms if atom.get_name() == target_atom.get_name()]

        if len(atoms) != 1:
            raise Exception(f"Found {len(atoms)} atoms with the name {target_atom.get_name()}!")

    # The cutoff depends only on the box, so compute it once
    pbc_cutoff = [box[0] / 1.5, box[1] / 1.5, box[2] / 1.5]  # TODO: Maybe find better cutoff approximations

    for j, atom in enumerate(atoms):
        # Calculate the relative vector
        position = atom.get_vector() - position_origin

        # Fix the PBC
        position = fix_pbc(position, box, cutoff=pbc_cutoff)

        # Check if nan or inf is in the vector
        if any(np.isnan(position[k]) or np.isinf(position[k]) for k in range(3)):
            raise Exception(f"Found nan or inf in vector ({position})!")

        # Write the scaled position to the output matrix
        for k in range(3):
            output_matrix[batch_index, j + PADDING_X, PADDING_Y + k, 0] = position[k] / ABSOLUT_POSITION_SCALE  # / (box[k] * ABSOLUTE_POSITION_EXTRA_SCALE)

    return output_matrix
def add_relative_vectors(atom_pos_dict, atom_mapping, output_matrix, batch_index, box)

This function calculates the relative vectors between the atoms in the atom mapping and writes them to the output matrix. Additionally, it fixes PBC and checks for consistency.

Args

atom_pos_dict : _type_
The atom pos dict is a dict that maps the atom name to the atom position.
atom_mapping : _type_
The atom mapping is a list of tuples that contain the atom names that should be mapped.
output_matrix : _type_
The output matrix is the matrix where the relative vectors should be written to. Shape: (batch_size, i, j, 1)
batch_index : _type_
The batch index is the index of the batch in the output matrix.
box : _type_
The box is the size of the simulation box that is used to fix the periodic boundary conditions.
Expand source code
def add_relative_vectors(atom_pos_dict, atom_mapping, output_matrix, batch_index, box):
    """
    Compute the bond vector for every atom pair of the mapping and store it, scaled, in the output matrix.
    Vectors longer than PBC_CUTOFF are passed through fix_pbc and have to be short afterwards.

    Args:
        atom_pos_dict (dict): Maps the atom name to the atom.
        atom_mapping (list): List of (name_1, name_2) tuples. The vector points from atom 1 to atom 2.
        output_matrix (array): The matrix the vectors are written to (in place). Shape: (batch_size, i, j, 1)
        batch_index (int): The index of the batch entry in the output matrix.
        box (array): The box size that is handed to fix_pbc.

    Returns:
        array: The output matrix (the same object that was passed in).

    Raises:
        Exception: If an atom of the mapping is missing, a vector is still longer than
            PBC_CUTOFF after the PBC fix, or a vector contains nan or inf.
    """
    for row, (first_name, second_name) in enumerate(atom_mapping):
        first_atom = atom_pos_dict.get(first_name)
        second_atom = atom_pos_dict.get(second_name)

        # Guard: both ends of the bond have to be present
        if not (first_atom and second_atom):
            raise Exception(f"Missing {first_name} or {second_name} in residue")

        # Vector pointing from the first to the second atom
        rel_vector = second_atom.get_vector() - first_atom.get_vector()

        # A bond this long crosses the box boundary, so wrap it back
        if rel_vector.norm() > PBC_CUTOFF:
            rel_vector = fix_pbc(rel_vector, box)

        # After wrapping the bond has to be short
        if rel_vector.norm() > PBC_CUTOFF:
            raise Exception(f"Found a vector that is too large ({rel_vector})!")

        # Reject nan and inf components
        if any(np.isnan(rel_vector[axis]) or np.isinf(rel_vector[axis]) for axis in range(3)):
            raise Exception(f"Found nan or inf in vector ({rel_vector})!")

        # Store the scaled components behind the padding
        target_row = row + PADDING_X
        for axis in range(3):
            output_matrix[batch_index, target_row, PADDING_Y + axis, 0] = rel_vector[axis] / BOX_SCALE_FACTOR

    return output_matrix
def fix_pbc(vector, box_size=(191.05054545, 191.05054545, 111.67836364), cutoff=[10.0, 10.0, 10.0])

This function fixes the periodic boundary conditions of a vector. It does this by checking if the vector is outside of the box and if it is, it will move it back into the box.

Args

vector : vector
The vector that should be fixed.
box_size : vector, optional
The box size of the simulation. Defaults to AVERAGE_SIMULATION_BOX_SIZE.
cutoff : list, optional
Cutoff radius to apply the PBC fix to. Defaults to [PBC_CUTOFF, PBC_CUTOFF, PBC_CUTOFF].

Returns

vector
The fixed vector.
Expand source code
def fix_pbc(vector, box_size=AVERAGE_SIMULATION_BOX_SIZE, cutoff=[PBC_CUTOFF, PBC_CUTOFF, PBC_CUTOFF]):
    """
    This function fixes the periodic boundary conditions of a vector. Every component that lies outside of
    [-cutoff, cutoff] is shifted by the box size of its axis until it lies inside of that range.

    A shift is only applied while it brings the component closer to zero. A component that cannot be
    brought into the range (for example because the cutoff is smaller than half of the box size) is
    left at the closest value that was reached instead of being shifted back and forth forever.

    Args:
        vector (vector): The vector that should be fixed. It is modified in place.
        box_size (vector, optional): The box size of the simulation. Defaults to AVERAGE_SIMULATION_BOX_SIZE.
        cutoff (list, optional): Cutoff radius to apply the PBC fix to. Defaults to [PBC_CUTOFF, PBC_CUTOFF, PBC_CUTOFF].

    Returns:
        vector: The fixed vector (the same object that was passed in).
    """
    # The default cutoff list is only read and never mutated, so sharing it between calls is safe.
    for axis in range(3):
        component = vector[axis]
        changed = False

        # Repeat the shift with the caller's cutoff (a single shift is not enough for far away images)
        while component > cutoff[axis] or component < -cutoff[axis]:
            if component > cutoff[axis]:
                shifted = component - box_size[axis]
            else:
                shifted = component + box_size[axis]

            # Stop as soon as shifting no longer helps, otherwise the value would oscillate
            if abs(shifted) >= abs(component):
                break

            component = shifted
            changed = True

        # Only touch the vector when something was fixed
        if changed:
            vector[axis] = component

    return vector
def get_bounding_box(dataset_idx: int, path_to_csv: str) ‑> 

This function returns the bounding box of a residue in a dataset. The bounding box is the size of the simulation box in which the residue is located.

Args

dataset_idx : int
The dataset index is the index of the dataset in the dataset list.
path_to_csv : str
The path to the csv file that contains the bounding box sizes for the residues.

Returns

array
The bounding box as a numpy array.
Expand source code
def get_bounding_box(dataset_idx: int, path_to_csv: str) -> np.ndarray:
    """
    This function returns the bounding box of a residue in a dataset. The bounding box is the size of the simulation box in which the residue is located.

    Line `dataset_idx + 1` of the csv file is read and parsed as comma separated floats.

    Args:
        dataset_idx (int): The dataset index is the index of the dataset in the dataset list.
        path_to_csv (str): The path to the csv file that contains the bounding box sizes for the residues.

    Returns:
        array: The bounding box as a numpy array.

    Raises:
        ValueError: If the file has no (non-empty) line for the dataset or an entry is not a number.
    """
    # Get the bounding box by using the cached line access (returns "" if the file or line is missing)
    line = linecache.getline(path_to_csv, dataset_idx + 1).strip()

    # Fail with a message that names the dataset instead of "could not convert string to float: ''"
    if not line:
        raise ValueError(f"No bounding box found for dataset {dataset_idx} in '{path_to_csv}'!")

    # Return the bounding box as a numpy array
    return np.array([float(x) for x in line.split(",")])
def get_neighbour_residues(dataset_idx: int, path_to_csv: str)

This function returns a list of the neighbour residues of a residue in a dataset.

Args

dataset_idx : int
The dataset index is the index of the dataset in the dataset list.
path_to_csv : str
The path to the csv file that contains the neighbour residues for the residues.

Returns

array
List of neighbour residue indices.
Expand source code
def get_neighbour_residues(dataset_idx: int, path_to_csv: str):
    """
    Look up the neighbour residues of a residue.

    Line `dataset_idx + 1` of the csv file is read and parsed as comma separated integers.

    Args:
        dataset_idx (int): The index of the dataset in the dataset list.
        path_to_csv (str): The path to the csv file that contains the neighbour residues for the residues.

    Returns:
        array: The neighbour residue indices. Empty if the line is empty or does not exist.
    """
    # Cached line access; a missing file or line yields an empty string
    raw_line = linecache.getline(path_to_csv, dataset_idx + 1)
    entries = raw_line.replace("\n", "")

    # No entries means the residue has no neighbours
    if not entries:
        return np.array([])

    # Parse the neighbour indices
    return np.array(list(map(int, entries.split(","))))
def is_output_matrix_healthy(output)

This is a diagnostic function that checks if the output matrix fulfils the requirements. The following is checked: (1) whether values outside of [-1, 1] exist (this should not be the case), and (2) whether nan or inf exist (this should not be the case). TODO: This function can be extended to check for other things as well.

Expand source code
def is_output_matrix_healthy(output):
    """
        Diagnostic check for an input/output matrix.
        A matrix is healthy if:
            - all values are inside of [-1, 1]
            - it contains neither nan nor inf
        TODO: This function can be extended to check for other things as well.
    """
    # Range check: nothing above 1 and nothing below -1
    inside_range = not (np.max(output) > 1 or np.min(output) < -1)

    # Finiteness check: no nan and no inf
    all_finite = not (np.isnan(output).any() or np.isinf(output).any())

    return inside_range and all_finite
def print_matrix(matrix)

This function prints a matrix in ascii art.

Args

matrix : list
List with shape (batch_size, i, j, 1)
Expand source code
def print_matrix(matrix):
    """
    Print every batch entry of a matrix as an ascii table.
    The second axis of the matrix is printed as columns, the third axis as rows.

    Args:
        matrix (list): List with shape (batch_size, i, j, 1)
    """
    for batch in range(matrix.shape[0]):
        column_count = matrix.shape[1]
        separator = "-" * (column_count * 8)

        # Header line with the column indices
        header = " " + "".join(f"{column:6d}  " for column in range(column_count))
        print(header)
        print(separator)

        # One printed line per entry of the third axis
        for row in range(matrix.shape[2]):
            cells = []
            for column in range(column_count):
                value = matrix[batch, column, row, 0]
                # Non-negative values get a leading blank so that they line up with negative ones
                sign_padding = " " if value >= 0 else ""
                cells.append(f"{sign_padding}{value:.4f} ")
            print("".join(cells))

        print(separator)

Classes

class AbsolutePositionsGenerator (input_dir_path, output_dir_path, input_size=(14, 5, 1), output_size=(56, 5, 1), shuffle=False, batch_size=1, validate_split=0.1, validation_mode=False, augmentation=False, only_fit_one_atom=False, atom_name=None)

A data generator class that generates batches of data for the CNN. The input and output are the absolute positions of the atoms and beads.

Args

input_dir_path : str
The path to the directory where the input data (X) is located
output_dir_path : str
The path to the directory where the output data (Y) is located
input_size : tuple, optional
The size/shape of the input data. Defaults to (12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
output_size : tuple, optional
The size/shape of the output data. Defaults to (54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
shuffle : bool, optional
If the data should be shuffled after each epoch. Defaults to False.
batch_size : int, optional
The size of each batch. Defaults to 1.
validate_split : float, optional
The percentage of data that should be used for validation. Defaults to 0.1.
validation_mode : bool, optional
If the generator should be in validation mode. Defaults to False.
augmentation : bool, optional
If the generator should augment the data. Defaults to False.
only_fit_one_atom : bool, optional
If only one atom should be fitted. Defaults to False.
atom_name : str, optional
The name of the atom that should be fitted. Defaults to None.
Expand source code
class AbsolutePositionsGenerator(BackmappingBaseGenerator):
    """
        A data generator class that generates batches of data for the CNN.
        The input and output is the aboslute positions of the atoms and beads.
    """

    def __init__(
        self,
        input_dir_path,
        output_dir_path,
        input_size=(12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more bead position than relativ vectors
        output_size=(54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more atom position than relativ vectors
        shuffle=False,
        batch_size=1,
        validate_split=0.1,
        validation_mode=False,
        augmentation=False,
        only_fit_one_atom=False,
        atom_name=None,
    ):
        """
        Args:
            input_dir_path (str): The path to the directory where the input data (X) is located
            output_dir_path (str): The path to the directory where the output data (Y) is located
            input_size (tuple, optional): The size/shape of the input data. Defaults to (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            output_size (tuple, optional): The size/shape of the output data. Defaults to (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
            batch_size (int, optional): The size of each batch. Defaults to 1.
            validate_split (float, optional): The percentage of data that should be used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data. Defaults to False.
            only_fit_one_atom (bool, optional): If only one atom should be fitted. Defaults to False.
            atom_name (str, optional): The name of the atom that should be fitted. Defaults to None.
        """
        # Call super constructor
        super().__init__(
            input_dir_path,
            output_dir_path,
            input_size,
            output_size,
            shuffle,
            batch_size,
            validate_split,
            validation_mode,
            augmentation,
        )

        self.only_fit_one_atom = only_fit_one_atom
        self.atom_name = atom_name

    def __getitem__(self, idx):
        return self.__getitem_one_atom__(idx) if self.only_fit_one_atom else self.__getitem_all_atoms__(idx)

    def __getitem_one_atom__(self, idx):
        # Initialize Batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)
        Y = np.zeros((self.batch_size, 1 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1), dtype=np.float32)

        # Loop over the batch
        for i in range(self.batch_size):
            # Get the index of the residue
            residue_idx = idx * self.batch_size + i

            # Check if end index is reached
            if self.validation_mode and residue_idx > self.end_index:
                raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

            # Get the path to the files
            cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"
            at_path = f"{self.output_dir_path}/{residue_idx}.pdb"

            # Load the files
            cg_structure = self.parser.get_structure(residue_idx, cg_path)
            at_structure = self.parser.get_structure(residue_idx, at_path)

            # Get the residues
            cg_residue = list(cg_structure.get_residues())[0]
            at_residue = list(at_structure.get_residues())[0]

            # Get the atoms
            cg_atoms = list(cg_residue.get_atoms())
            at_atoms = list(at_residue.get_atoms())

            # Filter out the atom that should be fitted
            if not self.atom_name:
                raise Exception("You need to specify an atom name!")
            at_atoms_to_fit = [atom for atom in at_atoms if atom.get_name() == self.atom_name]

            if at_atoms_to_fit.__len__() != 1:
                raise Exception(f"Found {at_atoms.__len__()} atoms with the name {self.atom_name} in residue {residue_idx}!")

            # Get bounding box sizes
            cg_box_size = get_bounding_box(residue_idx, CG_BOUNDING_BOX_RELATION_PATH)
            at_box_size = get_bounding_box(residue_idx, AT_BOUNDING_BOX_RELATION_PATH)

            # Check if the box sizes are not nan
            if np.isnan(cg_box_size).any() or np.isnan(at_box_size).any():
                raise Exception(f"Found nan in box sizes ({cg_box_size}, {at_box_size}) in residue {residue_idx}!")

            # Make the absolute positions out of a vector mapping
            X = add_absolute_positions(cg_atoms, X, i, cg_box_size)
            Y = add_absolute_positions(at_atoms, Y, i, at_box_size, target_atom=at_atoms_to_fit[0])

        # Augment the data
        if self.augmentation:
            # Randomly rotate each dataset
            for i in range(self.batch_size):
                vectors_X = X[i, :, :, 0]
                vectors_Y = Y[i, :, :, 0]

                # Randomly rotate the dataset
                angle_x = random.uniform(-np.pi, np.pi)
                angle_y = random.uniform(-np.pi, np.pi)
                angle_z = random.uniform(-np.pi, np.pi)

                # Loop over beads
                for j in range(vectors_X.shape[0]):
                    vec = vectors_X[j, PADDING_Y:-PADDING_Y]
                    # Rotate
                    vec = np.matmul(np.array([[1, 0, 0], [0, np.cos(angle_x), -np.sin(angle_x)], [0, np.sin(angle_x), np.cos(angle_x)]]), vec)
                    vec = np.matmul(np.array([[np.cos(angle_y), 0, np.sin(angle_y)], [0, 1, 0], [-np.sin(angle_y), 0, np.cos(angle_y)]]), vec)
                    vec = np.matmul(np.array([[np.cos(angle_z), -np.sin(angle_z), 0], [np.sin(angle_z), np.cos(angle_z), 0], [0, 0, 1]]), vec)

                    # Write back
                    vectors_X[j, PADDING_Y:-PADDING_Y] = vec

                # Loop over atoms
                for j in range(vectors_Y.shape[0]):
                    vec = vectors_Y[j, PADDING_Y:-PADDING_Y]
                    # Rotate
                    vec = np.matmul(np.array([[1, 0, 0], [0, np.cos(angle_x), -np.sin(angle_x)], [0, np.sin(angle_x), np.cos(angle_x)]]), vec)
                    vec = np.matmul(np.array([[np.cos(angle_y), 0, np.sin(angle_y)], [0, 1, 0], [-np.sin(angle_y), 0, np.cos(angle_y)]]), vec)
                    vec = np.matmul(np.array([[np.cos(angle_z), -np.sin(angle_z), 0], [np.sin(angle_z), np.cos(angle_z), 0], [0, 0, 1]]), vec)

                    # Write back
                    vectors_Y[j, PADDING_Y:-PADDING_Y] = vec

                # Write the rotated vectors back to the matrix
                X[i, :, :, 0] = vectors_X
                Y[i, :, :, 0] = vectors_Y

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)
        Y = tf.convert_to_tensor(Y, dtype=tf.float32)

        # Check if values that are not in [-1, 1] exist
        if not is_output_matrix_healthy(Y) or not is_output_matrix_healthy(X):
            if not is_output_matrix_healthy(Y):
                print_matrix(Y[0:1, :, :, :])
            if not is_output_matrix_healthy(X):
                # Find batch that is not healthy
                for i in range(X.shape[0]):
                    if not is_output_matrix_healthy(X[i:i+1, :, :, :]):
                        print(i)
                        print_matrix(X[i:i+1, :, :, :])
                        break
            raise Exception(f"Found values outside of [-1, 1], see print before.")

        # Return tensor as deep copy
        return tf.identity(X), tf.identity(Y)

    def __getitem_all_atoms__(self, idx):
        """
        Build the batch with the index idx from the absolute positions of all beads (X) and all atoms (Y).

        Args:
            idx (int): The index of the batch.

        Returns:
            tuple: The float32 tensors (X, Y) with the shapes (batch_size, *input_size) and (batch_size, *output_size).

        Raises:
            Exception: If a residue behind the end index is requested in validation mode, if a bounding box
                contains nan or if the health check of the finished batch fails.
        """
        # Initialize Batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)
        # Y needs the same 4-D layout as X because it is indexed as Y[i, :, :, 0] further down
        Y = np.zeros((self.batch_size, *self.output_size), dtype=np.float32)

        # Loop over the batch
        for i in range(self.batch_size):
            # Get the index of the residue
            # NOTE(review): the index always starts at 0, no start index is added in validation mode - confirm that this is intended
            residue_idx = idx * self.batch_size + i

            # Check if end index is reached
            if self.validation_mode and residue_idx > self.end_index:
                raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

            # Get the path to the files
            cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"
            at_path = f"{self.output_dir_path}/{residue_idx}.pdb"

            # Load the files
            cg_structure = self.parser.get_structure(residue_idx, cg_path)
            at_structure = self.parser.get_structure(residue_idx, at_path)

            # Get the residues (only the first residue of each file is used)
            cg_residue = list(cg_structure.get_residues())[0]
            at_residue = list(at_structure.get_residues())[0]

            # Get the atoms
            cg_atoms = list(cg_residue.get_atoms())
            at_atoms = list(at_residue.get_atoms())

            # Get bounding box sizes
            cg_box_size = get_bounding_box(residue_idx, CG_BOUNDING_BOX_RELATION_PATH)
            at_box_size = get_bounding_box(residue_idx, AT_BOUNDING_BOX_RELATION_PATH)

            # Check if the box sizes are not nan
            if np.isnan(cg_box_size).any() or np.isnan(at_box_size).any():
                raise Exception(f"Found nan in box sizes ({cg_box_size}, {at_box_size}) in residue {residue_idx}!")

            # Make the absolute positions out of a vector mapping
            X = add_absolute_positions(cg_atoms, X, i, cg_box_size)
            Y = add_absolute_positions(at_atoms, Y, i, at_box_size)

        # Augment the data
        if self.augmentation:
            # Randomly rotate each dataset, the beads and the atoms of one sample get the same rotation
            for i in range(self.batch_size):
                angle_x = random.uniform(-np.pi, np.pi)
                angle_y = random.uniform(-np.pi, np.pi)
                angle_z = random.uniform(-np.pi, np.pi)

                cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
                cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
                cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

                rot_x = np.array([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])
                rot_y = np.array([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])
                rot_z = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])

                # Rotate around x first, then y, then z
                rotation = rot_z @ rot_y @ rot_x

                for batch in (X, Y):
                    # Every row holds one position, the padding columns are left untouched
                    coordinates = batch[i, :, PADDING_Y:-PADDING_Y, 0]
                    batch[i, :, PADDING_Y:-PADDING_Y, 0] = coordinates @ rotation.T

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)
        Y = tf.convert_to_tensor(Y, dtype=tf.float32)

        # Check if values that are not in [-1, 1] exist
        if not is_output_matrix_healthy(Y) or not is_output_matrix_healthy(X):
            if not is_output_matrix_healthy(Y):
                print_matrix(Y[0:1, :, :, :])
            if not is_output_matrix_healthy(X):
                # Find batch that is not healthy
                for i in range(X.shape[0]):
                    if not is_output_matrix_healthy(X[i:i+1, :, :, :]):
                        print(i)
                        print_matrix(X[i:i+1, :, :, :])
                        break
            raise Exception("Found values outside of [-1, 1], see print before.")

        # Return tensor as deep copy
        return tf.identity(X), tf.identity(Y)

Ancestors

Inherited members

class AbsolutePositionsNeigbourhoodGenerator (input_dir_path, output_dir_path, input_size=(14, 5, 1), output_size=(56, 5, 1), shuffle=False, batch_size=1, validate_split=0.1, validation_mode=False, augmentation=False, only_fit_one_atom=False, atom_name=None, neighbourhood_size=4)

A data generator class that generates batches of data for the CNN. The input and output are the absolute positions of the atoms and beads, and the input includes the neighbourhood of the molecule that should be fitted.

Args

input_dir_path : str
The path to the directory where the input data (X) is located
output_dir_path : str
The path to the directory where the output data (Y) is located
input_size : tuple, optional
The size/shape of the input data. Defaults to (12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
output_size : tuple, optional
The size/shape of the output data. Defaults to (54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
shuffle : bool, optional
If the data should be shuffled after each epoch. Defaults to False.
batch_size : int, optional
The size of each batch. Defaults to 1.
validate_split : float, optional
The percentage of data that should be used for validation. Defaults to 0.1.
validation_mode : bool, optional
If the generator should be in validation mode. Defaults to False.
augmentation : bool, optional
If the generator should augment the data. Defaults to False.
only_fit_one_atom : bool, optional
If only one atom should be fitted. Defaults to False.
atom_name : str, optional
The name of the atom that should be fitted. Defaults to None.
neighbourhood_size : int, optional
The number of neighbours to take into account. Defaults to 4.
Expand source code
class AbsolutePositionsNeigbourhoodGenerator(BackmappingBaseGenerator):
    """
        A data generator class that generates batches of data for the CNN.
        The input and output are the absolute positions of the beads and atoms. The input
        additionally includes the beads of the neighbourhood of the molecule that should be fitted.
    """

    # Number of bead rows that one residue occupies in the input matrix
    _BEADS_PER_RESIDUE = 12

    def __init__(
        self,
        input_dir_path,
        output_dir_path,
        input_size=(12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more bead position than relativ vectors
        output_size=(54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),  # We have one more atom position than relativ vectors
        shuffle=False,
        batch_size=1,
        validate_split=0.1,
        validation_mode=False,
        augmentation=False,
        only_fit_one_atom=False,
        atom_name=None,
        neighbourhood_size=4,
    ):
        """
        Args:
            input_dir_path (str): The path to the directory where the input data (X) is located
            output_dir_path (str): The path to the directory where the output data (Y) is located
            input_size (tuple, optional): The size/shape of the input data. Defaults to (12 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
                NOTE(review): the default only has rows for the 12 beads of the fitted residue, while the neighbours
                are written behind them - presumably the caller has to pass a larger input_size, confirm.
            output_size (tuple, optional): The size/shape of the output data. Defaults to (54 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
            batch_size (int, optional): The size of each batch. Defaults to 1.
            validate_split (float, optional): The percentage of data that should be used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data. Defaults to False.
            only_fit_one_atom (bool, optional): If only one atom should be fitted. Defaults to False.
            atom_name (str, optional): The name of the atom that should be fitted. Defaults to None.
            neighbourhood_size (int, optional): The maximum number of neighbours to take into account. Defaults to 4.
        """
        # Call super constructor
        super().__init__(
            input_dir_path,
            output_dir_path,
            input_size,
            output_size,
            shuffle,
            batch_size,
            validate_split,
            validation_mode,
            augmentation,
        )

        self.only_fit_one_atom = only_fit_one_atom
        self.atom_name = atom_name
        self.neighbourhood_size = neighbourhood_size

    def __getitem__(self, idx):
        """Return the batch with the index idx, built by the one-atom or the all-atoms variant."""
        return self.__getitem_one_atom__(idx) if self.only_fit_one_atom else self.__getitem_all_atoms__(idx)

    @staticmethod
    def _random_rotation():
        """Draw three random angles and return the 3x3 matrix that rotates around x, then y, then z."""
        angle_x = random.uniform(-np.pi, np.pi)
        angle_y = random.uniform(-np.pi, np.pi)
        angle_z = random.uniform(-np.pi, np.pi)

        cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
        cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
        cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

        rot_x = np.array([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])
        rot_y = np.array([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])
        rot_z = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])

        return rot_z @ rot_y @ rot_x

    @staticmethod
    def _rotate_sample(batch, sample_idx, rotation):
        """Rotate every position row of one sample in place, the padding columns are left untouched."""
        coordinates = batch[sample_idx, :, PADDING_Y:-PADDING_Y, 0]
        batch[sample_idx, :, PADDING_Y:-PADDING_Y, 0] = coordinates @ rotation.T

    def __getitem_one_atom__(self, idx):
        """
        Build the batch with the index idx for fitting the single atom with the name self.atom_name.

        Args:
            idx (int): The index of the batch.

        Returns:
            tuple: The float32 tensors (X, Y). X has the shape (batch_size, *input_size),
                Y has the shape (batch_size, 1 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).

        Raises:
            Exception: If no atom name is set, if a residue behind the end index is requested in validation
                mode, if the atom name does not match exactly one atom, if a bounding box contains nan or
                if the health check of the finished batch fails.
        """
        # The atom name is a setting of the generator, so check it before any file is read
        if not self.atom_name:
            raise Exception("You need to specify an atom name!")

        beads = self._BEADS_PER_RESIDUE

        # Initialize Batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)     # Now is (batch, beads + neighbourhood beads, (x, y, z), 1)
        Y = np.zeros((self.batch_size, 1 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1), dtype=np.float32)

        # Loop over the batch
        for i in range(self.batch_size):
            # Get the index of the residue
            residue_idx = idx * self.batch_size + i

            # Check if end index is reached
            if self.validation_mode and residue_idx > self.end_index:
                raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

            # Get the path to the files
            cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"
            at_path = f"{self.output_dir_path}/{residue_idx}.pdb"

            # Load the files
            cg_structure = self.parser.get_structure(residue_idx, cg_path)
            at_structure = self.parser.get_structure(residue_idx, at_path)

            # Get the residues (only the first residue of each file is used)
            cg_residue = list(cg_structure.get_residues())[0]
            at_residue = list(at_structure.get_residues())[0]

            # Get the atoms
            cg_atoms = list(cg_residue.get_atoms())
            at_atoms = list(at_residue.get_atoms())

            # Filter out the atom that should be fitted
            at_atoms_to_fit = [atom for atom in at_atoms if atom.get_name() == self.atom_name]

            if len(at_atoms_to_fit) != 1:
                # Report the number of matches, not the number of all atoms in the residue
                raise Exception(f"Found {len(at_atoms_to_fit)} atoms with the name {self.atom_name} in residue {residue_idx}!")

            # Get bounding box sizes
            cg_box_size = get_bounding_box(residue_idx, CG_BOUNDING_BOX_RELATION_PATH)
            at_box_size = get_bounding_box(residue_idx, AT_BOUNDING_BOX_RELATION_PATH)

            # Check if the box sizes are not nan
            if np.isnan(cg_box_size).any() or np.isnan(at_box_size).any():
                raise Exception(f"Found nan in box sizes ({cg_box_size}, {at_box_size}) in residue {residue_idx}!")

            # Get the neighbourhood
            neighbourhood = get_neighbour_residues(residue_idx, NEIGHBOURHOOD_PATH)

            # Make the absolute positions out of a vector mapping
            X = add_absolute_positions(cg_atoms, X, i, cg_box_size)
            Y = add_absolute_positions(at_atoms, Y, i, at_box_size, target_atom=at_atoms_to_fit[0])

            # Add the neighbourhood
            neighbour_count = min(self.neighbourhood_size, len(neighbourhood))

            if neighbour_count > 0:
                # All neighbours are positioned relative to the NC3 bead of the fitted residue
                position_origin = [atom.get_vector() for atom in cg_atoms if atom.get_name() == "NC3"][0]

            for j in range(neighbour_count):
                neighbor_X = self.get_neighbor_X(
                    residue_idx=neighbourhood[j],
                    box_size=cg_box_size,
                    position_origin=position_origin,
                )[0, PADDING_X: beads + PADDING_X, PADDING_Y:-PADDING_Y, 0]

                # Add the neighbour to the input, each neighbour gets its own block of rows behind the fitted residue
                row_start = PADDING_X + beads * (j + 1)
                X[i, row_start: row_start + beads, PADDING_Y:-PADDING_Y, 0] = neighbor_X

        # Augment the data
        if self.augmentation:
            # Randomly rotate each dataset, the beads and the atom of one sample get the same rotation
            for i in range(self.batch_size):
                rotation = self._random_rotation()
                self._rotate_sample(X, i, rotation)
                self._rotate_sample(Y, i, rotation)

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)
        Y = tf.convert_to_tensor(Y, dtype=tf.float32)

        # Check if values that are not in [-1, 1] exist
        if not is_output_matrix_healthy(Y) or not is_output_matrix_healthy(X):
            if not is_output_matrix_healthy(Y):
                print_matrix(Y[0:1, :, :, :])
            if not is_output_matrix_healthy(X):
                # Find batch that is not healthy
                for i in range(X.shape[0]):
                    if not is_output_matrix_healthy(X[i:i+1, :, :, :]):
                        print(i)
                        print_matrix(X[i:i+1, :, :, :])
                        break
            raise Exception("Found values outside of [-1, 1], see print before.")

        # Return tensor as deep copy
        return tf.identity(X), tf.identity(Y)

    def get_neighbor_X(self, residue_idx, box_size, position_origin):
        """
        Load the beads of a neighbouring residue and return their absolute positions as input tensor.

        Args:
            residue_idx (int): The index of the neighbouring residue.
            box_size: The bounding box size that is passed on to add_absolute_positions.
            position_origin: The origin that is passed on to add_absolute_positions as preset_position_origin.

        Returns:
            tf.Tensor: A float32 tensor with the shape (batch_size, *input_size), only the sample 0 is filled.

        Raises:
            Exception: If a residue behind the end index is requested in validation mode.
        """
        # Initialize Batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)     # Now is (batch, beads, (x, y, z), 1)

        # Check if end index is reached
        if self.validation_mode and residue_idx > self.end_index:
            raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

        # Get the path to the files
        cg_path = f"{self.input_dir_path}/{residue_idx}.pdb"

        # Load the files
        cg_structure = self.parser.get_structure(residue_idx, cg_path)

        # Get the residues (only the first residue of the file is used)
        cg_residue = list(cg_structure.get_residues())[0]

        # Get the atoms
        cg_atoms = list(cg_residue.get_atoms())

        # Make the absolute positions out of a vector mapping
        X = add_absolute_positions(cg_atoms, X, 0, box_size, preset_position_origin=position_origin)

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)

        # Return tensor as deep copy
        return tf.identity(X)

    def __getitem_all_atoms__(self, idx):
        """Fitting all atoms is not supported by this generator yet."""
        # NotImplementedError is a subclass of Exception, so existing handlers keep working
        raise NotImplementedError("Not implemented yet!")

Ancestors

Methods

def get_neighbor_X(self, residue_idx, box_size, position_origin)
Expand source code
def get_neighbor_X(self, residue_idx, box_size, position_origin):
    """
    Load the beads of a neighbouring residue and return their absolute positions as input tensor.

    Args:
        residue_idx (int): The index of the neighbouring residue.
        box_size: The bounding box size that is passed on to add_absolute_positions.
        position_origin: The origin that is passed on to add_absolute_positions as preset_position_origin.

    Returns:
        tf.Tensor: A float32 tensor with the shape (batch_size, *input_size), only the sample 0 is filled.

    Raises:
        Exception: If a residue behind the end index is requested in validation mode.
    """
    # Empty input matrix in the same layout as a regular batch
    batch_shape = (self.batch_size, *self.input_size)
    neighbour_batch = np.zeros(batch_shape, dtype=np.float32)

    # Residues behind the end index must not be requested in validation mode
    behind_end = self.validation_mode and residue_idx > self.end_index
    if behind_end:
        raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

    # Read the coarse grained file of the neighbour and take the beads of its first residue
    structure = self.parser.get_structure(residue_idx, f"{self.input_dir_path}/{residue_idx}.pdb")
    first_residue = list(structure.get_residues())[0]
    beads = list(first_residue.get_atoms())

    # Write the absolute positions into the sample 0, relative to the given origin
    neighbour_batch = add_absolute_positions(
        beads,
        neighbour_batch,
        0,
        box_size,
        preset_position_origin=position_origin,
    )

    # Hand out a copy of the tensor
    as_tensor = tf.convert_to_tensor(neighbour_batch, dtype=tf.float32)
    return tf.identity(as_tensor)

Inherited members

class BackmappingBaseGenerator (input_dir_path: str, output_dir_path: str, input_size: tuple = (13, 5, 1), output_size: tuple = (55, 5, 1), shuffle: bool = False, batch_size: int = 1, validate_split: float = 0.1, validation_mode: bool = False, augmentation: bool = False)

This is the base class for the backmapping data generator. It is used to generate batches of data for the CNN. The children classes specify the structure of the data. For example, the RelativeVectorsTrainingDataGenerator generates batches of relative bond vectors.

This is the base class for the backmapping data generator.

Args

input_dir_path : str
The path to the directory where the input data (X) is located
output_dir_path : str
The path to the directory where the output data (Y) is located
input_size : tuple, optional
The size/shape of the input data. Defaults to (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
output_size : tuple, optional
The size/shape of the output data. Defaults to (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
shuffle : bool, optional
If the data should be shuffled after each epoch. Defaults to False.
batch_size : int, optional
The size of each batch. Defaults to 1.
validate_split : float, optional
The percentage of data that should be used for validation. Defaults to 0.1.
validation_mode : bool, optional
If the generator should be in validation mode. Defaults to False.
augmentation : bool, optional
If the generator should augment the data. Defaults to False.
Expand source code
class BackmappingBaseGenerator(tf.keras.utils.Sequence):
    """
        This is the base class for the backmapping data generator.
        It is used to generate batches of data for the CNN. The children classes specify the structure of the data.
        For example, the RelativeVectorsTrainingDataGenerator generates batches of relative bond vectors.
    """

    def __init__(
        self,
        input_dir_path: str,
        output_dir_path: str,
        input_size: tuple = (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        output_size: tuple = (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        shuffle: bool = False,
        batch_size: int = 1,
        validate_split: float = 0.1,
        validation_mode: bool = False,
        augmentation: bool = False,
    ):
        """
        This is the base class for the backmapping data generator.

        Args:
            input_dir_path (str): The path to the directory where the input data (X) is located
            output_dir_path (str): The path to the directory where the output data (Y) is located
            input_size (tuple, optional): The size/shape of the input data. Defaults to (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            output_size (tuple, optional): The size/shape of the output data. Defaults to (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
            batch_size (int, optional): The size of each batch. Defaults to 1.
            validate_split (float, optional): The percentage of data that should be used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data. Defaults to False.
        """
        # Initialize the Keras base class, newer Keras versions expect this call from subclasses
        super().__init__()

        self.input_dir_path = input_dir_path
        self.output_dir_path = output_dir_path
        self.input_size = input_size
        self.output_size = output_size
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.validation_mode = validation_mode
        self.augmentation = augmentation

        self.parser = PDBParser(QUIET=True)

        # The highest index is the number of entries in the input directory minus one
        max_index = len(os.listdir(input_dir_path)) - 1

        # Training mode uses the leading part of the data
        self.len = int(max_index * (1 - validate_split))
        self.start_index = 0
        self.end_index = self.len

        # Set validation mode
        if self.validation_mode:
            # The validation range starts right behind the training range, so use the training length before it is overwritten
            self.start_index = self.len + 1
            self.len = int(max_index * validate_split)
            self.end_index = max_index

        # Debug
        print(f"Found {self.len} residues in ({self.input_dir_path})")

        # Initialize
        self.on_epoch_end()

    def on_epoch_end(self):
        """Hook that is called at the end of every epoch and once by the constructor, it does nothing here."""
        pass

    def __len__(self):
        """Return the number of full batches."""
        return self.len // self.batch_size

    def __getitem__(self, idx):
        """Has to be implemented by the children classes."""
        # NotImplementedError is a subclass of Exception, so existing handlers keep working
        raise NotImplementedError("This is an abstract class and should not be used directly!")

Ancestors

  • keras.utils.data_utils.Sequence

Subclasses

Methods

def on_epoch_end(self)

Method called at the end of every epoch.

Expand source code
def on_epoch_end(self):
    """Hook that is called at the end of every epoch, it does nothing."""
    return None
class RelativeVectorsTrainingDataGenerator (input_dir_path, output_dir_path, input_size=(13, 5, 1), output_size=(55, 5, 1), shuffle=False, batch_size=1, validate_split=0.1, validation_mode=False, augmentation=False)

A data generator class that generates batches of data for the CNN. The input and output are the relative vectors between the atoms and the beads.

Args

input_dir_path : str
The path to the directory where the input data (X) is located
output_dir_path : str
The path to the directory where the output data (Y) is located
input_size : tuple, optional
The size/shape of the input data. Defaults to (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
output_size : tuple, optional
The size/shape of the output data. Defaults to (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
shuffle : bool, optional
If the data should be shuffled after each epoch. Defaults to False.
batch_size : int, optional
The size of each batch. Defaults to 1.
validate_split : float, optional
The percentage of data that should be used for validation. Defaults to 0.1.
validation_mode : bool, optional
If the generator should be in validation mode. Defaults to False.
augmentation : bool, optional
If the generator should augment the data. Defaults to False.
Expand source code
class RelativeVectorsTrainingDataGenerator(BackmappingBaseGenerator):
    """
        A data generator class that generates batches of data for the CNN.
        The input and output are the relative vectors between the beads and between the atoms.
    """

    def __init__(
        self,
        input_dir_path,
        output_dir_path,
        input_size=(11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        output_size=(53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1),
        shuffle=False,
        batch_size=1,
        validate_split=0.1,
        validation_mode=False,
        augmentation=False,
    ):
        """
        Args:
            input_dir_path (str): Directory that holds the input data (X)
            output_dir_path (str): Directory that holds the output data (Y)
            input_size (tuple, optional): Shape of one input sample. Defaults to (11 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            output_size (tuple, optional): Shape of one output sample. Defaults to (53 + 2 * PADDING_X, 3 + 2 * PADDING_Y, 1).
            shuffle (bool, optional): If the data should be shuffled after each epoch. Defaults to False.
            batch_size (int, optional): Number of samples per batch. Defaults to 1.
            validate_split (float, optional): Share of the data that is used for validation. Defaults to 0.1.
            validation_mode (bool, optional): If the generator should be in validation mode. Defaults to False.
            augmentation (bool, optional): If the generator should augment the data. Defaults to False.
        """
        # Everything is handled by the base class
        base_arguments = (
            input_dir_path,
            output_dir_path,
            input_size,
            output_size,
            shuffle,
            batch_size,
            validate_split,
            validation_mode,
            augmentation,
        )
        super().__init__(*base_arguments)

    def __getitem__(self, idx):
        """
        Build the batch with the index idx out of the relative vectors of the beads (X) and the atoms (Y).

        Args:
            idx (int): The index of the batch.

        Returns:
            tuple: The float32 tensors (X, Y) with the shapes (batch_size, *input_size) and (batch_size, *output_size).

        Raises:
            Exception: If a residue behind the end index is requested in validation mode, if a bounding box
                contains nan or if the health check of the finished batch fails.
        """
        # Empty matrices for the whole batch
        X = np.zeros((self.batch_size, *self.input_size), dtype=np.float32)
        Y = np.zeros((self.batch_size, *self.output_size), dtype=np.float32)

        first_residue_idx = idx * self.batch_size

        for slot in range(self.batch_size):
            residue_idx = first_residue_idx + slot

            # Residues behind the end index must not be requested in validation mode
            if self.validation_mode and residue_idx > self.end_index:
                raise Exception(f"You are trying to access a residue that does not exist ({residue_idx})!")

            # Read the coarse grained and the atomistic file of the residue
            cg_structure = self.parser.get_structure(residue_idx, f"{self.input_dir_path}/{residue_idx}.pdb")
            at_structure = self.parser.get_structure(residue_idx, f"{self.output_dir_path}/{residue_idx}.pdb")

            # Only the first residue of each file is used
            cg_residue = list(cg_structure.get_residues())[0]
            at_residue = list(at_structure.get_residues())[0]

            # Look up the beads and atoms by their name
            cg_atoms_dict = {bead.get_name(): bead for bead in cg_residue.get_atoms()}
            at_atoms_dict = {atom.get_name(): atom for atom in at_residue.get_atoms()}

            # Bounding boxes of the residue in both resolutions
            cg_box_size = get_bounding_box(residue_idx, CG_BOUNDING_BOX_RELATION_PATH)
            at_box_size = get_bounding_box(residue_idx, AT_BOUNDING_BOX_RELATION_PATH)

            # A bounding box with nan cannot be used
            if any(np.isnan(box).any() for box in (cg_box_size, at_box_size)):
                raise Exception(f"Found nan in box sizes ({cg_box_size}, {at_box_size}) in residue {residue_idx}!")

            # Fill the sample with the relative vectors given by the vector mappings
            X = add_relative_vectors(cg_atoms_dict, DOPC_CG_MAPPING, X, slot, cg_box_size)
            Y = add_relative_vectors(at_atoms_dict, DOPC_AT_MAPPING, Y, slot, at_box_size)

        # Convert to tensor
        X = tf.convert_to_tensor(X, dtype=tf.float32)
        Y = tf.convert_to_tensor(Y, dtype=tf.float32)

        # Both matrices have to pass the health check, the output is checked first
        if not (is_output_matrix_healthy(Y) and is_output_matrix_healthy(X)):
            raise Exception("Found values outside of [-1, 1], see print before.")

        # Hand out copies of the tensors
        return tf.identity(X), tf.identity(Y)

Ancestors

Inherited members