Module library.viz

Expand source code
from Bio.PDB.Atom import Atom
from Bio.PDB.Model import Model
from Bio.PDB.Chain import Chain
from Bio.PDB.Residue import Residue
from Bio.PDB.Structure import Structure
from Bio.PDB.NeighborSearch import NeighborSearch

from library.classes import dataset
from library.parser import get_cg_at_datasets
from library.static.utils import DEFAULT_ELEMENT_COLOR_MAP
from library.static.topologies import cg_name_to_type_dict, cg_bond_map_dict

import os
import matplotlib.pyplot as plt
import numpy as np


def __get_name(atom: Atom):
    return atom.get_name()


def __get_element(atom: Atom):
    return atom.element


def plot_residue(
    residue: Residue,
    residue_map=None,
    bond_map=None,
    show_labels=False,
    group_by_element=False,
    show_neighrbor_bonds=False,
    neighbor_distance=0.5,
    save_path=None,
    dont_show_plot=False,
    ax=None,
    fig=None
) -> plt.Figure:
    """
        Plots the residue using matplotlib

        Parameters:
            residue (Residue): The residue to plot
            map (dict): A dictionary mapping atom names to type
            bond_map (dict): A set mapping atom i to atom j if they are bonded
            show_labels (bool): If True, the atoms will be labeled
            group_by_element (bool): If True, the atoms will be colored by their element
            show_neighrbor_bonds (bool): If True, the bonds (estimated) will be shown
            neighbor_distance (float): The distance to search for neighbors
            save_path (str): The path to save the plot to, will not show the plot if given
            dont_show_plot (bool): If True, the plot will not be shown
            ax (matplotlib.axes.Axes): The axes to plot on (must be given if fig is given)
            fig (matplotlib.figure.Figure): The figure to plot on (must be given if ax is given)

        Returns:
            fig (matplotlib.figure.Figure): The figure object
    """

    if not isinstance(residue, Residue):
        raise TypeError('residue must be of type Residue')

    atom_id = __get_element if group_by_element else __get_name

    # Get the coordinates of all atoms in the residue
    coordinates = np.array([atom.get_coord() for atom in residue.get_atoms()])
    x, y, z = coordinates.T

    # Create the figure
    if ax is None and fig is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

    # Set the labels
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # Dict for random colors for each atom type
    color_map = {}

    if residue_map is None or group_by_element:
        # Create a random color for each atom type
        for atom in residue.get_atoms():
            if atom_id(atom) not in color_map.keys():
                if atom.element in DEFAULT_ELEMENT_COLOR_MAP.keys():
                    color_map[atom_id(
                        atom)] = DEFAULT_ELEMENT_COLOR_MAP[atom.element]
                else:
                    color_map[atom_id(atom)] = np.random.rand(3,)
    else:
        # Use the given map
        for type in residue_map.keys():
            if residue_map[type] not in color_map.keys():
                color_map[residue_map[type]] = np.random.rand(3,)

    # Color the atoms by their element
    colors = []

    if residue_map is None:
        colors = [color_map[atom_id(atom)] for atom in residue.get_atoms()]
    elif group_by_element:
        if residue_map is not None:
            for atom in residue.get_atoms():
                if atom.element in DEFAULT_ELEMENT_COLOR_MAP.keys():
                    colors.append(DEFAULT_ELEMENT_COLOR_MAP[atom.element])
                else:
                    colors.append(color_map[atom_id(atom)])
        else:
            colors = [color_map[atom_id(atom)] for atom in residue.get_atoms()]
    else:
        colors = [color_map[residue_map[atom.get_name()]]
                  for atom in residue.get_atoms()]

    # Add legend
    ax.legend(handles=[plt.Line2D([0], [0], color=color, linewidth=3, linestyle='-')
              for color in color_map.values()
                       ], labels=list(color_map.keys()))

    # Plot the coordinates
    ax.scatter(x, y, z, c=colors, edgecolors='black')

    # Add lines between the atoms that are bonded using the bond_map dict i -> j
    if bond_map is not None:
        for i in bond_map.keys():
            if isinstance(bond_map[i], list):
                for j in bond_map[i]:
                    at_from = list(residue.get_atoms())[i - 1]
                    at_to = list(residue.get_atoms())[j - 1]
                    ax.plot([at_from.get_coord()[0], at_to.get_coord()[0]], [
                            at_from.get_coord()[1], at_to.get_coord()[1]], [at_from.get_coord()[2], at_to.get_coord()[2]],
                            color='black', linestyle='--', linewidth=1)
            else:
                at_from = list(residue.get_atoms())[i - 1]
                at_to = list(residue.get_atoms())[bond_map[i] - 1]
                ax.plot([at_from.get_coord()[0], at_to.get_coord()[0]], [
                        at_from.get_coord()[1], at_to.get_coord()[1]], [at_from.get_coord()[2], at_to.get_coord()[2]],
                        color='black', linestyle='--', linewidth=1)

    elif show_neighrbor_bonds:
        ns = NeighborSearch(list(residue.get_atoms()))
        for atom in residue.get_atoms():
            neighbors = ns.search(atom.get_coord(), neighbor_distance)
            for neighbor in neighbors:
                ax.plot([atom.get_coord()[0], neighbor.get_coord()[0]], [
                        atom.get_coord()[1], neighbor.get_coord()[1]], [atom.get_coord()[2], neighbor.get_coord()[2]],
                        color='black', linestyle='--', linewidth=1)

    # Add labels to the atoms if show_labels is True
    if show_labels:
        if residue_map is None or group_by_element:
            for i, atom in enumerate(residue.get_atoms()):
                ax.text(atom.get_coord()[0], atom.get_coord()[
                        1], atom.get_coord()[2], atom_id(atom))
        else:
            for i, atom in enumerate(residue.get_atoms()):
                ax.text(atom.get_coord()[0], atom.get_coord()[
                        1], atom.get_coord()[2], residue_map[atom_id(atom)])

    # Save the plot
    if save_path:
        plt.savefig(save_path)

    # Show the plot
    if not dont_show_plot:
        plt.show()

    return ax


def show_dataset(
    name: str,
    residue_index: int,
    dont_show_plot=False,
):
    # Get all CG and AT datasets
    cg_datasets, at_datasets = get_cg_at_datasets(os.path.join("data", "raw"))
    cg_dataset, at_dataset = None, None

    # Find the dataset with the given name
    for dataset in cg_datasets:
        if dataset.parent == name:
            cg_dataset = dataset
            break
    else:
        raise ValueError(f"Could not find dataset with name {name}")
    
    for dataset in at_datasets:
        if dataset.parent == name:
            at_dataset = dataset
            break
    else:
        raise ValueError(f"Could not find dataset with name {name}")

    # Bond map for cg system
    cg_residue_map = cg_name_to_type_dict(os.path.join(
        "data", "topologies", "martini_v2.0_DOPC_02.itp"))
    cg_bond_map = cg_bond_map_dict(os.path.join(
        "data", "topologies", "martini_v2.0_DOPC_02.itp"))

    # Bond map for at system
    at_residue_map = cg_name_to_type_dict(os.path.join(
        "data", "raw", "CG2AT_2023-02-13_20-20-52", "FINAL", "DOPC.itp"))
    at_bond_map = cg_bond_map_dict(os.path.join(
        "data", "raw", "CG2AT_2023-02-13_20-20-52", "FINAL", "DOPC.itp"))

    # Create Folder if it does not exist
    if not os.path.exists(os.path.join("data", "figures", "raw_data", cg_dataset.parent)):
        os.makedirs(os.path.join("data", "figures",
                                    "raw_data", cg_dataset.parent))

    for i, (cg_residue, at_residue) in enumerate(zip(cg_dataset.get_residues(), at_dataset.get_residues())):
        if i == residue_index:
            # Plot the figures side by side
            fig = plt.figure(figsize=(10, 5))
            ax1 = fig.add_subplot(121, projection='3d')
            ax2 = fig.add_subplot(122, projection='3d')

            plot_residue(cg_residue,
                            residue_map=cg_residue_map, bond_map=cg_bond_map, show_labels=True, group_by_element=True, dont_show_plot=True, ax=ax1, fig=fig)

            plot_residue(at_residue, residue_map=at_residue_map,
                            bond_map=at_bond_map, show_labels=False, group_by_element=True, dont_show_plot=True, ax=ax2, fig=fig)

            # Add title
            fig.suptitle(f'CG vs AT ({cg_dataset.parent})', fontsize=16)

            # Save the figure
            if not dont_show_plot:
                plt.show()
            else:
                plt.savefig(os.path.join(f"{cg_dataset.parent}-{i}.png"))
                print(f"Saved figure to {os.path.join(f'{cg_dataset.parent}-{i}.png')}")

Functions

def plot_residue(residue: Bio.PDB.Residue.Residue, residue_map=None, bond_map=None, show_labels=False, group_by_element=False, show_neighrbor_bonds=False, neighbor_distance=0.5, save_path=None, dont_show_plot=False, ax=None, fig=None) ‑> matplotlib.figure.Figure

Plots the residue using matplotlib

Parameters

residue (Residue): The residue to plot map (dict): A dictionary mapping atom names to type bond_map (dict): A set mapping atom i to atom j if they are bonded show_labels (bool): If True, the atoms will be labeled group_by_element (bool): If True, the atoms will be colored by their element show_neighrbor_bonds (bool): If True, the bonds (estimated) will be shown neighbor_distance (float): The distance to search for neighbors save_path (str): The path to save the plot to, will not show the plot if given dont_show_plot (bool): If True, the plot will not be shown ax (matplotlib.axes.Axes): The axes to plot on (must be given if fig is given) fig (matplotlib.figure.Figure): The figure to plot on (must be given if ax is given)

Returns

fig (matplotlib.figure.Figure): The figure object

Expand source code
def plot_residue(
    residue: Residue,
    residue_map=None,
    bond_map=None,
    show_labels=False,
    group_by_element=False,
    show_neighrbor_bonds=False,
    neighbor_distance=0.5,
    save_path=None,
    dont_show_plot=False,
    ax=None,
    fig=None
) -> plt.Figure:
    """
        Plots the residue using matplotlib

        Parameters:
            residue (Residue): The residue to plot
            map (dict): A dictionary mapping atom names to type
            bond_map (dict): A set mapping atom i to atom j if they are bonded
            show_labels (bool): If True, the atoms will be labeled
            group_by_element (bool): If True, the atoms will be colored by their element
            show_neighrbor_bonds (bool): If True, the bonds (estimated) will be shown
            neighbor_distance (float): The distance to search for neighbors
            save_path (str): The path to save the plot to, will not show the plot if given
            dont_show_plot (bool): If True, the plot will not be shown
            ax (matplotlib.axes.Axes): The axes to plot on (must be given if fig is given)
            fig (matplotlib.figure.Figure): The figure to plot on (must be given if ax is given)

        Returns:
            fig (matplotlib.figure.Figure): The figure object
    """

    if not isinstance(residue, Residue):
        raise TypeError('residue must be of type Residue')

    atom_id = __get_element if group_by_element else __get_name

    # Get the coordinates of all atoms in the residue
    coordinates = np.array([atom.get_coord() for atom in residue.get_atoms()])
    x, y, z = coordinates.T

    # Create the figure
    if ax is None and fig is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

    # Set the labels
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # Dict for random colors for each atom type
    color_map = {}

    if residue_map is None or group_by_element:
        # Create a random color for each atom type
        for atom in residue.get_atoms():
            if atom_id(atom) not in color_map.keys():
                if atom.element in DEFAULT_ELEMENT_COLOR_MAP.keys():
                    color_map[atom_id(
                        atom)] = DEFAULT_ELEMENT_COLOR_MAP[atom.element]
                else:
                    color_map[atom_id(atom)] = np.random.rand(3,)
    else:
        # Use the given map
        for type in residue_map.keys():
            if residue_map[type] not in color_map.keys():
                color_map[residue_map[type]] = np.random.rand(3,)

    # Color the atoms by their element
    colors = []

    if residue_map is None:
        colors = [color_map[atom_id(atom)] for atom in residue.get_atoms()]
    elif group_by_element:
        if residue_map is not None:
            for atom in residue.get_atoms():
                if atom.element in DEFAULT_ELEMENT_COLOR_MAP.keys():
                    colors.append(DEFAULT_ELEMENT_COLOR_MAP[atom.element])
                else:
                    colors.append(color_map[atom_id(atom)])
        else:
            colors = [color_map[atom_id(atom)] for atom in residue.get_atoms()]
    else:
        colors = [color_map[residue_map[atom.get_name()]]
                  for atom in residue.get_atoms()]

    # Add legend
    ax.legend(handles=[plt.Line2D([0], [0], color=color, linewidth=3, linestyle='-')
              for color in color_map.values()
                       ], labels=list(color_map.keys()))

    # Plot the coordinates
    ax.scatter(x, y, z, c=colors, edgecolors='black')

    # Add lines between the atoms that are bonded using the bond_map dict i -> j
    if bond_map is not None:
        for i in bond_map.keys():
            if isinstance(bond_map[i], list):
                for j in bond_map[i]:
                    at_from = list(residue.get_atoms())[i - 1]
                    at_to = list(residue.get_atoms())[j - 1]
                    ax.plot([at_from.get_coord()[0], at_to.get_coord()[0]], [
                            at_from.get_coord()[1], at_to.get_coord()[1]], [at_from.get_coord()[2], at_to.get_coord()[2]],
                            color='black', linestyle='--', linewidth=1)
            else:
                at_from = list(residue.get_atoms())[i - 1]
                at_to = list(residue.get_atoms())[bond_map[i] - 1]
                ax.plot([at_from.get_coord()[0], at_to.get_coord()[0]], [
                        at_from.get_coord()[1], at_to.get_coord()[1]], [at_from.get_coord()[2], at_to.get_coord()[2]],
                        color='black', linestyle='--', linewidth=1)

    elif show_neighrbor_bonds:
        ns = NeighborSearch(list(residue.get_atoms()))
        for atom in residue.get_atoms():
            neighbors = ns.search(atom.get_coord(), neighbor_distance)
            for neighbor in neighbors:
                ax.plot([atom.get_coord()[0], neighbor.get_coord()[0]], [
                        atom.get_coord()[1], neighbor.get_coord()[1]], [atom.get_coord()[2], neighbor.get_coord()[2]],
                        color='black', linestyle='--', linewidth=1)

    # Add labels to the atoms if show_labels is True
    if show_labels:
        if residue_map is None or group_by_element:
            for i, atom in enumerate(residue.get_atoms()):
                ax.text(atom.get_coord()[0], atom.get_coord()[
                        1], atom.get_coord()[2], atom_id(atom))
        else:
            for i, atom in enumerate(residue.get_atoms()):
                ax.text(atom.get_coord()[0], atom.get_coord()[
                        1], atom.get_coord()[2], residue_map[atom_id(atom)])

    # Save the plot
    if save_path:
        plt.savefig(save_path)

    # Show the plot
    if not dont_show_plot:
        plt.show()

    return ax
def show_dataset(name: str, residue_index: int, dont_show_plot=False)
Expand source code
def show_dataset(
    name: str,
    residue_index: int,
    dont_show_plot=False,
):
    # Get all CG and AT datasets
    cg_datasets, at_datasets = get_cg_at_datasets(os.path.join("data", "raw"))
    cg_dataset, at_dataset = None, None

    # Find the dataset with the given name
    for dataset in cg_datasets:
        if dataset.parent == name:
            cg_dataset = dataset
            break
    else:
        raise ValueError(f"Could not find dataset with name {name}")
    
    for dataset in at_datasets:
        if dataset.parent == name:
            at_dataset = dataset
            break
    else:
        raise ValueError(f"Could not find dataset with name {name}")

    # Bond map for cg system
    cg_residue_map = cg_name_to_type_dict(os.path.join(
        "data", "topologies", "martini_v2.0_DOPC_02.itp"))
    cg_bond_map = cg_bond_map_dict(os.path.join(
        "data", "topologies", "martini_v2.0_DOPC_02.itp"))

    # Bond map for at system
    at_residue_map = cg_name_to_type_dict(os.path.join(
        "data", "raw", "CG2AT_2023-02-13_20-20-52", "FINAL", "DOPC.itp"))
    at_bond_map = cg_bond_map_dict(os.path.join(
        "data", "raw", "CG2AT_2023-02-13_20-20-52", "FINAL", "DOPC.itp"))

    # Create Folder if it does not exist
    if not os.path.exists(os.path.join("data", "figures", "raw_data", cg_dataset.parent)):
        os.makedirs(os.path.join("data", "figures",
                                    "raw_data", cg_dataset.parent))

    for i, (cg_residue, at_residue) in enumerate(zip(cg_dataset.get_residues(), at_dataset.get_residues())):
        if i == residue_index:
            # Plot the figures side by side
            fig = plt.figure(figsize=(10, 5))
            ax1 = fig.add_subplot(121, projection='3d')
            ax2 = fig.add_subplot(122, projection='3d')

            plot_residue(cg_residue,
                            residue_map=cg_residue_map, bond_map=cg_bond_map, show_labels=True, group_by_element=True, dont_show_plot=True, ax=ax1, fig=fig)

            plot_residue(at_residue, residue_map=at_residue_map,
                            bond_map=at_bond_map, show_labels=False, group_by_element=True, dont_show_plot=True, ax=ax2, fig=fig)

            # Add title
            fig.suptitle(f'CG vs AT ({cg_dataset.parent})', fontsize=16)

            # Save the figure
            if not dont_show_plot:
                plt.show()
            else:
                plt.savefig(os.path.join(f"{cg_dataset.parent}-{i}.png"))
                print(f"Saved figure to {os.path.join(f'{cg_dataset.parent}-{i}.png')}")