RINDTI: Simplifying Drug-Target Interaction Prediction with Protein Residue Interaction Networks


This repository aims to simplify drug-target interaction (DTI) prediction based on protein residue interaction networks (RINs).

Overview

The workflow takes a simple collection of inputs - protein structures and drug interaction data - and turns them into a fully functional GNN model.

Installation

  1. Clone the repository with git clone https://github.com/ilsenatorov/rindti

  2. Change into the root directory with cd rindti

  3. (Optional) Install mamba with conda install -n base -c conda-forge mamba

  4. Create the conda environment with mamba env create -f workflow/envs/main.yaml (this might take some time)

  5. Activate the environment with conda activate rindti

  6. Test the installation with pytest
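
Once the environment is active, the workflow can be launched from the repository root with Snakemake, e.g. snakemake --cores 4 --use-conda. The exact targets and configuration file depend on your setup, so treat this as a starting point rather than a prescribed command.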

Documentation

Check out the documentation for more information.

Contributing

If you would like to contribute to the repository, please check out the contributing guide.

Code Snippets

script:
    "../scripts/parse_dataset.py"
script:
    "../scripts/split_data.py"
Snakemake, from line 22 of rules/data.smk
script:
    "../scripts/prepare_all.py"
Snakemake, from line 33 of rules/data.smk
script:
    "../scripts/prepare_drugs.py"
shell:
    """
    rinerator {input.pdb} {params.dir}/{wildcards.prot} > {log} 2>&1
    """
script:
    "../scripts/parse_rinerator.py"
script:
    "../scripts/distance_based.py"
script:
    "../scripts/prot_esm.py"
script:
    "../scripts/pretrain_prot_data.py"
script:
    "../scripts/create_pymol_scripts.py"
shell:
    "pymol -k -y -c {input.script} > {log} 2>&1"

scripts/create_pymol_scripts.py:
import os
import resource


def create_script(protein: str, inp: str, params: dict):
    """
    Create pymol parsing script for a protein according to the params.
    """
    resources = params.resources
    results = params.results
    fmt_keywords = {"protein": protein, "resources": resources, "results": results}
    script = [
        "import psico.fullinit",
        "from glob import glob",
        'cmd.load("{inp}")',
    ]

    if params.method == "plddt":
        script.append('cmd.select("result", "b > {threshold}")')
        fmt_keywords["threshold"] = params.other_params[params.method]["threshold"]
    else:
        # template-based
        script += [
            'lst = glob("{resources}/templates/*.pdb")',
            'templates = [x.split("/")[-1].split(".")[0] for x in lst]',
            "for i in lst:cmd.load(i)",
            'scores = {{x : cmd.tmalign("{protein}", x) for x in templates}}',
            "max_score = max(scores, key=scores.get)",
            'cmd.extra_fit("name CA", max_score, "tmalign")',
        ]

        fmt_keywords["radius"] = params.other_params[params.method]["radius"]
        if params.method == "bsite":
            script.append('cmd.select("result", "br. {protein} within {radius} of organic")')
        elif params.method == "template":
            script.append('cmd.select("result", "br. {protein} within {radius} of not {protein} and name CA")')
    script.append('cmd.save("{parsed_structs_dir}/{protein}.pdb", "result")')
    fmt_keywords["parsed_structs_dir"] = params.parsed_structs_dir
    fmt_keywords["structs"] = params.method
    fmt_keywords["inp"] = inp
    return "\n".join(script).format(**fmt_keywords)


if __name__ == "__main__":
    for inp, out in zip(snakemake.input, snakemake.output):
        protein = os.path.basename(out).split(".")[0]

        with open(out, "w") as file:
            file.write(create_script(protein, inp, snakemake.params))
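
For illustration, here is what a script rendered by create_script for the "plddt" method might look like; it is the file later executed by the pymol shell rule above. The protein name, paths and threshold are made-up placeholders, not values taken from the workflow config.

# Hypothetical rendering of create_script("ABC1", "resources/structures/ABC1.pdb", params)
# with method="plddt", threshold=70 and parsed_structs_dir="results/parsed_structs"
import psico.fullinit
from glob import glob
cmd.load("resources/structures/ABC1.pdb")
cmd.select("result", "b > 70")
cmd.save("results/parsed_structs/ABC1.pdb", "result")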

scripts/distance_based.py:
import torch
from encd import encd
from utils import onehot_encode

node_encoding = encd["prot"]["node"]


def encode_residue(residue: str, node_feats: str):
    """Encode a residue"""
    residue = residue.lower()
    if node_feats == "label":
        if residue not in node_encoding:
            return node_encoding["unk"]
        return node_encoding[residue] + 1
    elif node_feats == "onehot":
        return onehot_encode(node_encoding[residue], len(node_encoding))
    else:
        raise ValueError("Unknown node_feats type!")


class Residue:
    """Residue class"""

    def __init__(self, line: str) -> None:
        self.name = line[17:20].strip()
        self.num = int(line[22:26].strip())
        self.chainID = line[21].strip()
        self.x = float(line[30:38].strip())
        self.y = float(line[38:46].strip())
        self.z = float(line[46:54].strip())


class Structure:
    """Structure class"""

    def __init__(self, filename: str, node_feats: str) -> None:
        self.residues = {}
        self.parse_file(filename)
        self.node_feats = node_feats

    def parse_file(self, filename: str) -> None:
        """Parse PDB file"""
        for line in open(filename, "r"):
            if line.startswith("ATOM") and line[12:16].strip() == "CA":
                res = Residue(line)
                self.residues[res.num] = res

    def get_coords(self) -> torch.Tensor:
        """Get coordinates of all atoms"""
        coords = [[res.x, res.y, res.z] for res in self.residues.values()]
        return torch.tensor(coords)

    def get_nodes(self) -> torch.Tensor:
        """Get features of all nodes of a graph"""
        return torch.tensor([encode_residue(res.name, self.node_feats) for res in self.residues.values()])

    def get_edges(self, threshold: float) -> torch.Tensor:
        """Get edges of a graph using threshold as a cutoff"""
        coords = self.get_coords()
        dist = torch.cdist(coords, coords)
        edges = torch.where(dist < threshold)
        edges = torch.cat([arr.view(-1, 1) for arr in edges], axis=1)
        edges = edges[edges[:, 0] != edges[:, 1]]
        return edges.t()

    def get_graph(self, threshold: float) -> dict:
        """Get a graph using threshold as a cutoff"""
        nodes = self.get_nodes()
        edges = self.get_edges(threshold)
        return dict(x=nodes, edge_index=edges)


if __name__ == "__main__":
    import pickle

    import pandas as pd
    from joblib import Parallel, delayed
    from tqdm import tqdm

    if "snakemake" in globals():
        all_structures = snakemake.input.pdbs
        threshold = snakemake.params.threshold

        def get_graph(filename: str) -> dict:
            """Single function to be run in parallel."""
            return Structure(filename, snakemake.params.node_feats).get_graph(threshold)

        data = Parallel(n_jobs=snakemake.threads)(delayed(get_graph)(i) for i in tqdm(all_structures))
        df = pd.DataFrame(pd.Series(data, name="data"))
        df["filename"] = all_structures
        df["ID"] = df["filename"].apply(lambda x: x.split("/")[-1].split(".")[0])
        df.set_index("ID", inplace=True)
        df.drop("filename", axis=1, inplace=True)
        df.to_pickle(snakemake.output.pickle)
    else:
        import os
        import os.path as osp

        from jsonargparse import CLI

        def run(pdb_dir: str, output: str, threads: int = 1, threshold: float = 5, node_feats: str = "label"):
            """Run the pipeline"""

            def get_graph(filename: str) -> dict:
                """Calculate a single graph from a file"""
                return Structure(filename, node_feats).get_graph(threshold)

            pdbs = [osp.join(pdb_dir, x) for x in os.listdir(pdb_dir)]
            data = Parallel(n_jobs=threads)(delayed(get_graph)(i) for i in tqdm(pdbs))
            df = pd.DataFrame(pd.Series(data, name="data"))
            df["filename"] = pdbs
            df["ID"] = df["filename"].apply(lambda x: x.split("/")[-1].split(".")[0])
            df.set_index("ID", inplace=True)
            df.drop("filename", axis=1, inplace=True)
            # Save the DataFrame of graphs directly; converting to a dict first would break to_pickle
            df.to_pickle(output)

        cli = CLI(run)
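
As a minimal usage sketch outside of Snakemake (assuming the repo's encd and utils modules are importable, since encode_residue looks residues up in encd), the Structure class can be pointed at a single PDB file; the path below is hypothetical:

# Hypothetical input file; a 5 Angstrom cutoff matches the CLI default above
struct = Structure("resources/structures/ABC1.pdb", node_feats="label")
graph = struct.get_graph(threshold=5.0)
print(graph["x"].shape, graph["edge_index"].shape)  # [n_residues], [2, n_edges]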

scripts/parse_dataset.py:
import numpy as np
import pandas as pd


def posneg_filter(inter: pd.DataFrame) -> pd.DataFrame:
    """Only keep drugs that have at least 1 positive and negative interaction"""
    pos = inter[inter["Y"] == 1]["Drug_ID"].unique()
    neg = inter[inter["Y"] == 0]["Drug_ID"].unique()
    both = set(pos).intersection(set(neg))
    inter = inter[inter["Drug_ID"].isin(both)]
    return inter


def sample(inter: pd.DataFrame, how: str = "under") -> pd.DataFrame:
    """Sample the interactions dataset

    Args:
        inter (pd.DataFrame): whole data, has to be binary class
        how (str, optional): over or undersample.
        Oversample adds fake negatives, undersample removes extra positives. Defaults to "under".
    """
    if how == "none":
        return inter
    total = []
    pos = inter[inter["Y"] == 1]
    neg = inter[inter["Y"] == 0]
    for prot in inter["Target_ID"].unique():
        possample = pos[pos["Target_ID"] == prot]
        negsample = neg[neg["Target_ID"] == prot]
        poscount = possample.shape[0]
        negcount = negsample.shape[0]
        if poscount == 0:
            continue
        if poscount >= negcount:
            if how == "under":
                total.append(possample.sample(negcount))
                total.append(negsample)
            elif how == "over":
                total.append(possample)
                total.append(negsample)
                subsample = inter[inter["Target_ID"] != prot].sample(poscount - negcount)
                subsample["Target_ID"] = prot
                subsample["Y"] = 0
                total.append(subsample)
            else:
                raise ValueError("Unknown sampling method!")
        else:
            total.append(possample)
            total.append(negsample.sample(poscount))
    return pd.concat(total)


if __name__ == "__main__":

    from pytorch_lightning import seed_everything

    seed_everything(snakemake.config["seed"])

    inter = pd.read_csv(snakemake.input.inter, sep="\t")

    config = snakemake.config["parse_dataset"]
    # If duplicates, take median of entries
    inter = inter.groupby(["Drug_ID", "Target_ID"]).agg("median").reset_index()
    if config["task"] == "class":
        inter["Y"] = inter["Y"].apply(lambda x: int(x < config["threshold"]))
    elif config["task"] == "reg":
        if config["log"]:
            inter["Y"] = inter["Y"].apply(np.log10)
    else:
        raise ValueError("Unknown task!")

    if config["filtering"] != "all" and config["sampling"] != "none" and config["task"] == "reg":
        raise ValueError(
            "Can't use filtering {filter} with task {task}!".format(filter=config["filtering"], task=config["task"])
        )

    if config["filtering"] == "posneg":
        inter = posneg_filter(inter)
    elif config["filtering"] != "all":
        raise ValueError("No such type of filtering!")

    inter = sample(inter, how=config["sampling"])

    inter.to_csv(snakemake.output.inter, index=False, sep="\t")
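
As a small illustration of the two helpers above on made-up data: posneg_filter keeps only drugs that have both a positive and a negative interaction, and sample(..., how="under") balances the classes per target.

import pandas as pd

toy = pd.DataFrame(
    {
        "Drug_ID": ["d1", "d1", "d2", "d2", "d3", "d3"],
        "Target_ID": ["p1", "p2", "p1", "p2", "p1", "p2"],
        "Y": [1, 0, 1, 0, 0, 1],
    }
)
filtered = posneg_filter(toy)             # every drug here has both outcomes, so nothing is dropped
balanced = sample(filtered, how="under")  # each target keeps equal numbers of positives and negatives
print(balanced.sort_values("Target_ID"))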

scripts/parse_rinerator.py:
import os
from typing import Tuple

import numpy as np
import pandas as pd
import torch
from encd import encd
from utils import onehot_encode


class ProteinEncoder:
    """Performs all the encoding steps for a single sif file."""

    def __init__(self, node_feats: str, edge_feats: str):
        self.node_feats = node_feats
        self.edge_feats = edge_feats

    def encode_residue(self, residue: str) -> np.array:
        """Fully encode residue - one-hot and node_feats
        Args:
            residue (str): One-letter residue name
        Returns:
            np.array: Concatenated node_feats and one-hot encoding of residue name
        """
        residue = residue.lower()
        if self.node_feats == "label":
            return encd["prot"]["node"][residue] + 1
        elif self.node_feats == "onehot":
            return onehot_encode(encd["prot"]["node"][residue], len(encd["prot"]["node"]))
        else:
            raise ValueError("Unknown node_feats type!")

    def parse_sif(self, filename: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Parse a single sif file
        Args:
            filename (str): SIF file location
        Returns:
            Tuple[DataFrame, DataFrame]: nodes, edges DataFrames
        """
        nodes = []
        edges = []
        if not os.path.exists(filename):
            return None, None
        with open(filename, "r") as file:
            for line in file:
                line = line.strip()
                splitline = line.split()
                if len(splitline) != 3:
                    continue
                node1, edgetype, node2 = splitline
                node1split = node1.split(":")
                node2split = node2.split(":")
                if len(node1split) != 4:
                    continue
                if len(node2split) != 4:
                    continue
                chain1, resn1, x1, resaa1 = node1split
                chain2, resn2, x2, resaa2 = node2split
                if x1 != "_" or x2 != "_":
                    continue
                if resaa1.lower() not in encd["prot"]["node"] or resaa2.lower() not in encd["prot"]["node"]:
                    continue
                resn1 = int(resn1)
                resn2 = int(resn2)
                if resn1 == resn2:
                    continue
                edgesplit = edgetype.split(":")
                if len(edgesplit) != 2:
                    continue
                node1 = {"chain": chain1, "resn": resn1, "resaa": resaa1}
                node2 = {"chain": chain2, "resn": resn2, "resaa": resaa2}
                edgetype, _ = edgesplit
                edge1 = {
                    "resn1": resn1,
                    "resn2": resn2,
                    "type": edgetype,
                }
                edge2 = {
                    "resn1": resn2,
                    "resn2": resn1,
                    "type": edgetype,
                }
                nodes.append(node1)
                nodes.append(node2)
                edges.append(edge1)
                edges.append(edge2)
        nodes = pd.DataFrame(nodes).drop_duplicates()
        try:
            nodes = nodes.sort_values("resn").reset_index(drop=True).reset_index().set_index("resn")
        except Exception as e:
            print(nodes)
            print(filename)
            print(e)
            return None, None
        for node in nodes.index:
            if (node - 1) in nodes.index:
                edges.append({"resn1": node, "resn2": node - 1, "type": "pept"})
                edges.append({"resn2": node, "resn1": node - 1, "type": "pept"})
        edges = pd.DataFrame(edges).drop_duplicates()
        node_idx = nodes["index"].to_dict()
        edges["node1"] = edges["resn1"].apply(lambda x: node_idx[x])
        edges["node2"] = edges["resn2"].apply(lambda x: node_idx[x])
        return nodes, edges

    def encode_nodes(self, nodes: pd.DataFrame) -> torch.Tensor:
        """Given dataframe of nodes create node node_feats
        Args:
            nodes (pd.DataFrame): nodes dataframe from parse_sif
        Returns:
            torch.Tensor: Tensor of node node_feats [n_nodes, *]
        """
        nodes.drop_duplicates(inplace=True)
        node_attr = [self.encode_residue(x) for x in nodes["resaa"]]
        node_attr = np.asarray(node_attr)
        if self.node_feats == "label":
            node_attr = torch.tensor(node_attr, dtype=torch.long)
        else:
            node_attr = torch.tensor(node_attr, dtype=torch.float32)
        return node_attr

    def encode_edges(self, edges: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
        """Given dataframe of edges, create edge index and edge node_feats
        Args:
            edges (pd.DataFrame): edges dataframe from parse_sif
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: edge index [2,n_edges], edge attributes [n_edges, *]
        """
        if self.edge_feats == "none":
            edges.drop("type", axis=1, inplace=True)
        edges.drop_duplicates(inplace=True)
        edge_index = edges[["node1", "node2"]].astype(int).values
        edge_index = torch.tensor(edge_index, dtype=torch.long)
        edge_index = edge_index.t().contiguous()
        if self.edge_feats == "none":
            return edge_index, None
        edge_feats = edges["type"].apply(lambda x: encd["prot"]["edge"][x])
        if self.edge_feats == "label":
            edge_feats = torch.tensor(edge_feats, dtype=torch.long)
            return edge_index, edge_feats
        elif self.edge_feats == "onehot":
            edge_feats = edge_feats.apply(onehot_encode, count=len(encd["prot"]["edge"]))
            edge_feats = torch.tensor(edge_feats, dtype=torch.float)
            return edge_index, edge_feats

    def __call__(self, protein_sif: str) -> dict:
        """Fully process the protein
        Args:
            protein_sif (str): File location for sif file
        Returns:
            dict: standard format with x for node node_feats, edge_index for edges etc
        """
        try:
            nodes, edges = self.parse_sif(protein_sif)
            if nodes is None:
                return np.nan
            node_attr = self.encode_nodes(nodes)
            edge_index, edge_feats = self.encode_edges(edges)
            return dict(
                x=node_attr,
                edge_index=edge_index,
                edge_feats=edge_feats,
                # index_mapping=nodes["index"].to_dict(),
            )
        except Exception as e:
            print(protein_sif)
            print(e)
            return np.nan


def extract_name(protein_sif: str) -> str:
    """Extract the protein name from the sif filename"""
    return protein_sif.split("/")[-1].split("_")[0]


if __name__ == "__main__":
    if "snakemake" in globals():
        prots = pd.Series(list(snakemake.input.rins), name="sif")
        prots = pd.DataFrame(prots)
        prots["ID"] = prots["sif"].apply(extract_name)
        prots.set_index("ID", inplace=True)
        prot_encoder = ProteinEncoder(snakemake.params.node_feats, snakemake.params.edge_feats)
        prots["data"] = prots["sif"].apply(prot_encoder)
        prots.to_pickle(snakemake.output.pickle)
    else:
        import argparse

        from joblib import Parallel, delayed
        from tqdm import tqdm

        parser = argparse.ArgumentParser(description="Prepare protein data from rinerator")
        parser.add_argument("--sifs", nargs="+", required=True, help="Rinerator output folders")
        parser.add_argument("--output", required=True, help="Output pickle file")
        parser.add_argument("--node_feats", type=str, default="label")
        parser.add_argument("--edge_feats", type=str, default="none")
        parser.add_argument("--threads", type=int, default=1, help="Number of threads to use")
        args = parser.parse_args()

        prots = pd.DataFrame(pd.Series(args.sifs, name="sif"))
        prots["ID"] = prots["sif"].apply(extract_name)
        prots.set_index("ID", inplace=True)
        prot_encoder = ProteinEncoder(args.node_feats, args.edge_feats)
        data = Parallel(n_jobs=args.threads)(delayed(prot_encoder)(i) for i in tqdm(prots["sif"]))
        prots["data"] = data
        prots.to_pickle(args.output)
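
Based on the parsing logic in parse_sif, each SIF line is expected to hold three whitespace-separated fields: a node as chain:residue_number:_:residue_name, an edge type as type:subtype, and a second node. A minimal sketch with two hand-written lines (assuming ProteinEncoder and the repo's encd and utils modules are importable, and that the standard three-letter residue codes are present in encd):

# Toy SIF content in the format accepted by parse_sif above
sif_text = "A:10:_:ALA cnt:mc_mc A:11:_:GLY\nA:11:_:GLY hbond:mc_sc A:14:_:LEU\n"
with open("toy.sif", "w") as fh:
    fh.write(sif_text)

encoder = ProteinEncoder(node_feats="label", edge_feats="none")
graph = encoder("toy.sif")  # dict with x, edge_index and edge_feats (None here)
print(graph["x"].shape, graph["edge_index"].shape)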

scripts/prepare_all.py:
import pickle
from typing import Iterable

import pandas as pd
from pandas.core.frame import DataFrame
from utils import get_config


def process(row: pd.Series) -> dict:
    """Process each interaction."""
    split = row["split"]
    return {
        "label": row["Y"],
        "split": split,
        "prot_id": row["Target_ID"],
        "drug_id": row["Drug_ID"],
    }


def process_df(df: DataFrame) -> Iterable[dict]:
    """Apply process() function to each row of the DataFrame"""
    return [process(row) for (_, row) in df.iterrows()]


def del_index_mapping(x: dict) -> dict:
    """Delete 'index_mapping' entry from the dict"""
    if "index_mapping" in x:
        del x["index_mapping"]
    return x


if __name__ == "__main__":

    interactions = pd.read_csv(snakemake.input.inter, sep="\t")

    with open(snakemake.input.drugs, "rb") as file:
        drugs = pickle.load(file)

    with open(snakemake.input.prots, "rb") as file:
        prots = pickle.load(file)

    interactions = interactions[interactions["Target_ID"].isin(prots.index)]
    interactions = interactions[interactions["Drug_ID"].isin(drugs.index)]

    prots = prots[prots.index.isin(interactions["Target_ID"].unique())]
    drugs = drugs[drugs.index.isin(interactions["Drug_ID"].unique())]

    prot_count = interactions["Target_ID"].value_counts()
    drug_count = interactions["Drug_ID"].value_counts()

    prots["data"] = prots.apply(lambda x: {**x["data"], "count": prot_count[x.name]}, axis=1)
    drugs["data"] = drugs.apply(lambda x: {**x["data"], "count": drug_count[x.name]}, axis=1)

    full_data = process_df(interactions)
    snakemake.config["data"] = {
        "prot": get_config(prots, "prot"),
        "drug": get_config(drugs, "drug"),
    }

    final_data = {
        "data": full_data,
        "config": snakemake.config,
        "prots": prots,
        "drugs": drugs,
    }

    with open(snakemake.output.combined_pickle, "wb") as file:
        pickle.dump(final_data, file, protocol=-1)
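
The combined pickle bundles the filtered interactions with the encoded proteins and drugs plus the workflow config. A short sketch of inspecting it; the path is a placeholder for whatever snakemake.output.combined_pickle resolves to:

import pickle

# Placeholder path; use the actual combined pickle produced by the rule
with open("results/prepare_all/combined.pkl", "rb") as fh:
    bundle = pickle.load(fh)

print(bundle.keys())      # dict_keys(['data', 'config', 'prots', 'drugs'])
print(bundle["data"][0])  # {'label': ..., 'split': ..., 'prot_id': ..., 'drug_id': ...}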

scripts/prepare_drugs.py:
import numpy as np
import pandas as pd
import torch
from encd import encd
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
from torch_geometric.utils import to_undirected
from utils import onehot_encode


class DrugEncoder:
    """Drug encoder, goes from SMILES to dictionary of torch data

    Args:
        node_feats (str): 'label', 'onehot', 'glycan', 'glycanone' or 'IUPAC'
        edge_feats (str): 'label', 'onehot' or 'none'
        max_num_atoms (int, optional): filter out molecules that are too big. Defaults to 150.
    """

    def __init__(self, node_feats: str, edge_feats: str, max_num_atoms: int = 150):
        assert node_feats in {"label", "onehot", "glycan", "glycanone", "IUPAC"}
        assert edge_feats in {"label", "onehot", "none"}
        self.node_feats = node_feats
        self.edge_feats = edge_feats
        self.max_num_atoms = max_num_atoms

    def encode_node(self, atom_num, atom):
        """Encode single atom"""
        if atom_num not in encd["drug"]["node"].keys():
            atom_num = "other"

        if self.node_feats == "glycan":
            if atom_num in encd["glycan"]:
                return encd["glycan"][atom_num] + encd["chirality"][atom.GetChiralTag()]
            else:
                return encd["glycan"]["other"] + encd["chirality"][atom.GetChiralTag()]

        label = encd["drug"]["node"][atom_num]
        if self.node_feats == "onehot":
            return onehot_encode(label, len(encd["drug"]["node"]))
        return label + 1

    def encode_edge(self, edge):
        """Encode single edge"""
        label = encd["drug"]["edge"][edge]
        if self.edge_feats == "onehot":
            return onehot_encode(label, len(encd["drug"]["edge"]))
        elif self.edge_feats == "label":
            return label
        else:
            raise ValueError("This shouldn't be called for edge type none")

    def __call__(self, smiles: str) -> dict:
        """Generate drug Data from smiles

        Args:
            smiles (str): SMILES

        Returns:
            dict: dict with x, edge_index etc or np.nan for bad entries
        """
        if smiles != smiles:  # check for nans, i.e. missing smiles strings in dataset
            return np.nan
        mol = Chem.MolFromSmiles(smiles)
        if not mol:  # when rdkit fails to read a molecule it returns None
            return np.nan
        new_order = rdmolfiles.CanonicalRankAtoms(mol)
        mol = rdmolops.RenumberAtoms(mol, new_order)
        edges = []
        edge_feats = [] if self.edge_feats != "none" else None
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edges.append([start, end])
            btype = str(bond.GetBondType())
            # If bond type is unknown, remove molecule
            if btype not in encd["drug"]["edge"].keys():
                return np.nan
            if self.edge_feats != "none":
                edge_feats.append(self.encode_edge(btype))
        if not edges:  # If no edges (bonds) were found, remove molecule
            return np.nan
        atom_features = []
        for atom in mol.GetAtoms():
            atom_num = atom.GetAtomicNum()
            atom_features.append(self.encode_node(atom_num, atom))
        if len(atom_features) > self.max_num_atoms:
            return np.nan
        if self.node_feats == "label":
            x = torch.tensor(atom_features, dtype=torch.long)
        else:
            x = torch.tensor(atom_features, dtype=torch.float32)
        edge_index = torch.tensor(edges).t().contiguous()
        if self.edge_feats == "onehot":
            edge_feats = torch.tensor(edge_feats, dtype=torch.float32)
        elif self.edge_feats == "label":
            edge_feats = torch.tensor(edge_feats, dtype=torch.long)
        elif self.edge_feats == "none":
            edge_feats = None
        else:
            raise ValueError("Unknown edge encoding!")
        if self.edge_feats != "none":
            edge_index, edge_feats = to_undirected(edge_index, edge_feats)
        else:
            edge_index = to_undirected(edge_index)
        return dict(x=x, edge_index=edge_index, edge_feats=edge_feats)


if __name__ == "__main__":
    drug_enc = DrugEncoder(snakemake.params.node_feats, snakemake.params.edge_feats, snakemake.params.max_num_atoms)
    ligs = pd.read_csv(snakemake.input.lig, sep="\t").set_index("Drug_ID")
    ligs["data"] = ligs["Drug"].apply(drug_enc)
    ligs = ligs[ligs["data"].notna()]
    ligs.to_pickle(snakemake.output.pickle)
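
A minimal usage sketch for the encoder above, outside of Snakemake. It assumes the repo's encd and utils modules are importable and that the RDKit bond types of the molecule (here only single bonds) are present in the edge encoding:

encoder = DrugEncoder(node_feats="label", edge_feats="none")
graph = encoder("CCO")      # ethanol
print(graph["x"])           # three atom labels (two carbons, one oxygen)
print(graph["edge_index"])  # undirected C-C and C-O bonds: shape [2, 4]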

scripts/pretrain_prot_data.py:
import pickle

import pandas as pd
from utils import get_config

prot_table = pd.read_csv(snakemake.input.prot_table, sep="\t")
prot_data = pd.read_pickle(snakemake.input.prot_data)
prot_y = prot_table.set_index("Target_ID")["Y"].to_dict()

dims_config = get_config(prot_data, "prot")
dims_config["num_classes"] = len(set(prot_y.values()))  # one class per distinct Y value, matching y_encoder below
snakemake.config["prots"]["data"] = dims_config

y_encoder = {v: k for k, v in enumerate(sorted(set(prot_y.values())))}

result = []
for k, v in prot_data["data"].items():
    v["y"] = y_encoder[prot_y[k]]
    v["id"] = k
    result.append(v)

with open(snakemake.output.pretrain_prot_data, "wb") as file:
    pickle.dump(
        {
            "data": result,
            "config": snakemake.config["prots"],
            "decoder": {v: k for k, v in y_encoder.items()},
        },
        file,
    )

scripts/prot_esm.py:
import os

import esm
import pandas as pd
import torch
from extract_esm import create_parser
from extract_esm import main as extract_main


def generate_esm_python(prot: pd.DataFrame) -> pd.DataFrame:
    """Return esms."""

    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    batch_converter = alphabet.get_batch_converter()
    model.eval()  # disables dropout for deterministic results
    prot.set_index("Target_ID", inplace=True)
    data = [(k, v) for k, v in prot["Target"].to_dict().items()]

    _, _, batch_tokens = batch_converter(data)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    token_representations = results["representations"][33]

    sequence_representations = []
    for i, (_, seq) in enumerate(data):
        sequence_representations.append(token_representations[i, 1 : len(seq) + 1].mean(0))
    data = [{"x": x} for x in sequence_representations]
    prot["data"] = data
    prot = prot.to_dict("index")
    return prot


def generate_esm_script(prot: pd.DataFrame) -> pd.DataFrame:
    """Create an ESM script for batch processing."""
    prot_ids, seqs = list(zip(*[(k, v) for k, v in prot["Target"].to_dict().items()]))
    os.makedirs("./esms", exist_ok=True)
    with open("./esms/prots.fasta", "w") as fasta:
        for prot_id, seq in zip(prot_ids, seqs):
            fasta.write(f">{prot_id}\n{seq[:1022]}\n")

    esm_parser = create_parser()
    esm_args = esm_parser.parse_args(
        ["esm1b_t33_650M_UR50S", "esms/prots.fasta", "esms/", "--repr_layers", "33", "--include", "mean"]
    )
    extract_main(esm_args)
    data = []
    for prot_id in prot_ids:
        data.append({"x": torch.load(f"./esms/{prot_id}.pt")["mean_representations"][33].unsqueeze(0)})
    # os.rmdir("./esms")
    prot["data"] = data
    # prot = prot.to_dict("index")
    return prot


if __name__ == "__main__":
    import pickle

    prots = pd.read_csv(snakemake.input.seqs, sep="\t").set_index("Target_ID")
    prots = generate_esm_script(prots)
    prots.to_pickle(snakemake.output.pickle)
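
A minimal sketch of the in-memory route above (generate_esm_python); the protein ID and sequence are made up, and the first call downloads the ESM-1b weights:

import pandas as pd

toy = pd.DataFrame({"Target_ID": ["toy_prot"], "Target": ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"]})
embedded = generate_esm_python(toy)             # dict keyed by Target_ID
print(embedded["toy_prot"]["data"]["x"].shape)  # 1280-dimensional mean embedding for ESM-1b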

scripts/split_data.py:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def split_groups(
    inter: pd.DataFrame,
    col_name: str = "Target_ID",
    bin_size: int = 10,
    train_frac: float = 0.7,
    val_frac: float = 0.2,
) -> pd.DataFrame:
    """Split data by protein (cold-target)
    Tries to ensure good size of all sets by sorting the prots by number of interactions
    and performing splits within bins of 10

    Args:
        inter (pd.DataFrame): interaction DataFrame
        col_name (str): Which column to split on (col_name or 'Drug_ID' usually)
        bin_size (int, optional): Size of the bins to perform individual splits in. Defaults to 10.
        train_frac (float, optional): value from 0 to 1, how much of the data goes into train
        val_frac (float, optional): value from 0 to 1, how much of the data goes into validation

    Returns:
        pd.DataFrame: DataFrame with a new 'split' column
    """
    sorted_index = [x for x in inter[col_name].value_counts().index]
    train_prop = int(bin_size * train_frac)
    val_prop = int(bin_size * val_frac)
    train = []
    val = []
    test = []
    for i in range(0, len(sorted_index), bin_size):
        subset = sorted_index[i : i + bin_size]
        train_bin = list(np.random.choice(subset, min(len(subset), train_prop), replace=False))
        train += train_bin
        subset = [x for x in subset if x not in train_bin]
        val_bin = list(np.random.choice(subset, min(len(subset), val_prop), replace=False))
        val += val_bin
        subset = [x for x in subset if x not in val_bin]
        test += subset
    train_idx = inter[inter[col_name].isin(train)].index
    val_idx = inter[inter[col_name].isin(val)].index
    test_idx = inter[inter[col_name].isin(test)].index
    inter.loc[train_idx, "split"] = "train"
    inter.loc[val_idx, "split"] = "val"
    inter.loc[test_idx, "split"] = "test"
    return inter


def split_random(inter: pd.DataFrame, train_frac: float = 0.7, val_frac: float = 0.2) -> pd.DataFrame:
    """Split the dataset in a completely random fashion

    Args:
        inter (pd.DataFrame): interaction DataFrame
        train_frac (float, optional): value from 0 to 1, how much of the data goes into train
        val_frac (float, optional): value from 0 to 1, how much of the data goes into validation

    Returns:
        pd.DataFrame: DataFrame with a new 'split' column
    """
    train, valtest = train_test_split(inter, train_size=train_frac)
    val, test = train_test_split(valtest, train_size=val_frac)
    train.loc[:, "split"] = "train"
    val.loc[:, "split"] = "val"
    test.loc[:, "split"] = "test"
    inter = pd.concat([train, val, test])
    return inter


if __name__ == "__main__":
    from pytorch_lightning import seed_everything

    seed_everything(snakemake.config["seed"])
    inter = pd.read_csv(snakemake.input.inter, sep="\t")
    fracs = {"train_frac": snakemake.params.train, "val_frac": snakemake.params.val}

    if snakemake.params.method == "target":
        inter = split_groups(inter, col_name="Target_ID", **fracs)
    elif snakemake.params.method == "drug":
        inter = split_groups(inter, col_name="Drug_ID", **fracs)
    elif snakemake.params.method == "random":
        inter = split_random(inter, **fracs)
    else:
        raise NotImplementedError("Unknown split type!")
    inter.to_csv(snakemake.output.split_data, sep="\t")
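
A small sketch of the cold-target split above on made-up data; with ten proteins they all fall into a single bin, so roughly 7/2/1 of the proteins (not rows) end up in train/val/test:

import numpy as np
import pandas as pd

toy = pd.DataFrame(
    {
        "Drug_ID": [f"d{i}" for i in range(100)],
        "Target_ID": [f"p{i % 10}" for i in range(100)],
        "Y": np.random.randint(0, 2, 100),
    }
)
split = split_groups(toy, col_name="Target_ID", bin_size=10, train_frac=0.7, val_frac=0.2)
print(split.groupby("split")["Target_ID"].nunique())  # proteins per split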