ATLAS - Three commands to start analyzing your metagenome data

public public 7mo ago Version: v2.18.0 0 bookmarks

Metagenome-Atlas

Metagenome-atlas is an easy-to-use metagenomic pipeline based on Snakemake. It handles all steps from QC and assembly through binning to annotation.

scheme of workflow

You can start using atlas with three commands:

 mamba install -y -c bioconda -c conda-forge metagenome-atlas={latest_version}
 atlas init --db-dir databases path/to/fastq/files
 atlas run all

where {latest_version} should be replaced by the number of the current release (this page lists v2.18.0, i.e. 2.18.0).

Webpage

metagenome-atlas.github.io

Documentation

https://metagenome-atlas.readthedocs.io/

Tutorial

Citation

ATLAS: a Snakemake workflow for assembly, annotation, and genomic binning of metagenome sequence data.
Kieser, S., Brown, J., Zdobnov, E. M., Trajkovski, M. & McCue, L. A.
BMC Bioinformatics 21, 257 (2020).
doi: 10.1186/s12859-020-03585-4

Development/Extensions

Here are some ideas I am working on, or want to work on when I have time. If you want to contribute or have ideas of your own, let me know via a feature request issue.

  • Optimized MAG recovery (e.g. Spacegraphcats )

  • Integration of virus/plasmid analyses, which for now live as extensions

  • Add statistics and visualisations as in atlas_analyze

  • Implementation of most rules as Snakemake wrappers

  • Cloud execution

  • Update to new Snakemake version and use cool reports.

Code Snippets

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os, sys
import logging, traceback

# Write every log record of this script to the first log file declared on
# ``snakemake.log``, prefixed with a "YYYY-MM-DD HH:MM:SS" timestamp.
# ``snakemake`` is not defined in this file — presumably it is injected by
# Snakemake when the file is run through a ``script:`` directive.
logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Redirect messages issued through the ``warnings`` module to logging too.
logging.captureWarnings(True)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Log an uncaught exception, traceback included, as one ERROR record.

    The signature matches ``sys.excepthook``. ``KeyboardInterrupt`` is passed
    on to the interpreter's default hook and is not logged; every other
    exception is formatted and sent to the root logger.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        formatted = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(formatted))
        return

    # Ctrl-C: keep the interpreter's standard behaviour.
    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of script

from common_report import *

import os, sys
import pandas as pd
import plotly.express as px


# Display titles for the table columns plotted in make_plots: maps the column
# name to the text shown on axes and in hover boxes.
# NOTE(review): "N50"/"N90" are labelled as a *number* and "L50"/"L90" as a
# *length* — presumably this follows the convention of the tool that wrote
# the stats table; confirm against the producer of the input file.
labels = {
    "Percent_Assembled_Reads": "Percent of Assembled Reads",
    "contig_bp": "Total BP",
    "n_contigs": "Contigs (count)",
    "N_Predicted_Genes": "Predicted Genes (count)",
    "N50": "N50-number",
    "L50": "N50-length (bp)",
    "N90": "N90-number",
    "L90": "N90-length (bp)",
}


# Keyword arguments forwarded to every plotly.express call in make_plots.
PLOT_PARAMS = dict(labels=labels)


def make_plots(combined_stats):
    """Build the Plotly figures of the assembly report.

    Parameters
    ----------
    combined_stats
        Path to a tab-separated table whose first column holds the sample
        names. The columns ``Percent_Assembled_Reads``, ``N_Predicted_Genes``,
        ``N50``, ``L50``, ``N90``, ``L90``, ``n_contigs`` and ``contig_bp``
        are plotted.

    Returns
    -------
    dict
        One HTML string per figure (``fig.to_html(**HTML_PARAMS)``), stored
        under the keys ``Percent_Assembled_Reads``, ``N_Predicted_Genes``,
        ``N50``, ``N90`` and ``Total``.
    """
    # Load the table sorted by sample; the sample name is needed both as the
    # index and as an ordinary column so it can serve as ``hover_name``.
    stats = pd.read_csv(combined_stats, sep="\t", index_col=0).sort_index(
        ascending=True
    )
    stats.index.name = "Sample"
    stats["Sample"] = stats.index

    def render(figure):
        return figure.to_html(**HTML_PARAMS)

    div = {}

    # Strip plots: one point per sample.
    assembled = px.strip(
        stats, y="Percent_Assembled_Reads", hover_name="Sample", **PLOT_PARAMS
    )
    assembled.update_yaxes(range=[0, 100])  # percentages: fix the axis to 0-100
    div["Percent_Assembled_Reads"] = render(assembled)

    div["N_Predicted_Genes"] = render(
        px.strip(stats, y="N_Predicted_Genes", hover_name="Sample", **PLOT_PARAMS)
    )

    # Scatter plots: (key in ``div``, x column, y column).
    scatter_specs = (
        ("N50", "N50", "L50"),
        ("N90", "N90", "L90"),
        ("Total", "n_contigs", "contig_bp"),
    )
    for key, x_column, y_column in scatter_specs:
        div[key] = render(
            px.scatter(
                stats, y=y_column, x=x_column, hover_name="Sample", **PLOT_PARAMS
            )
        )

    return div


# main


# Build all figures first, then fill them into the assembly report template.
div = make_plots(combined_stats=snakemake.input.combined_contig_stats)

_report_kwargs = dict(
    div=div,
    report_out=snakemake.output.report,
    html_template_file=os.path.join(reports_dir, "template_assembly_report.html"),
)
make_html(**_report_kwargs)
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os, sys
import logging, traceback

# All log records of this script go to the first file listed in
# ``snakemake.log``; each line starts with a "YYYY-MM-DD HH:MM:SS" timestamp.
# ``snakemake`` is not defined in this file — presumably provided by Snakemake
# when the script is executed from a rule.
logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Also capture ``warnings.warn`` messages in the log file.
logging.captureWarnings(True)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Exception hook that writes uncaught errors to the log.

    ``KeyboardInterrupt`` is delegated to ``sys.__excepthook__`` (the
    interpreter default) and skipped; any other exception is logged at ERROR
    level together with its formatted traceback.
    """
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    trace_text = "".join(
        traceback.format_exception(exc_type, exc_value, exc_traceback)
    )
    logging.error(f"Uncaught exception: {trace_text}")


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of script


from common_report import *

import pandas as pd
import plotly.express as px


from utils.taxonomy import tax2table


def make_plots(bin_info):
    """Create the HTML fragments of the bin (genome) quality report.

    Parameters
    ----------
    bin_info
        Path to a tab-separated table with one row per bin; the first column
        becomes the index ("Bin Id"). ``Completeness``, ``Contamination``,
        ``Sample``, ``Genome_Size`` and the hover columns listed below are
        read from the joined table.

    The function additionally reads ``snakemake.input.bins2species`` (a
    tab-separated table indexed by bin) and joins it onto the bin table.
    NOTE(review): presumably that table supplies the ``Representative`` and
    ``Species`` columns used below — confirm against the producing rule.

    Returns
    -------
    dict
        Keys ``input_file``, ``QualityScore``, ``table``, ``2D``, ``2Dsp``
        and ``bySample``; the values are HTML/text snippets.
    """
    species_file = snakemake.input.bins2species
    div = {"input_file": f"{bin_info} and {species_file}"}

    # Prepare data: the bin id is needed as index *and* as a column.
    bins = pd.read_table(bin_info, index_col=0)
    bins["Bin Id"] = bins.index

    # Add species info.
    bin2species = pd.read_table(species_file, index_col=0)
    bins = bins.join(bin2species)

    logging.info(bins.head())
    logging.info(bin2species.head())

    # Summary table: number of bins and of distinct representatives per filter.
    summary = pd.DataFrame(columns=["Bins", "Species"])

    def record(label, subset):
        summary.loc[label, "Bins"] = len(subset)
        summary.loc[label, "Species"] = len(subset.Representative.unique())

    record("All", bins)

    bins.eval("Quality_score = Completeness - 5* Contamination", inplace=True)
    div["QualityScore"] = (
        "<p>Quality score is calculated as: Completeness - 5 x Contamination.</p>"
    )

    quality_filters = (
        ("Quality score >50 ", "Quality_score>50"),
        ("Good quality", "Completeness>90 & Contamination <5"),
        ("Quality score >90 ", "Quality_score>90"),
    )
    for label, condition in quality_filters:
        record(label, bins.query(condition))

    div["table"] = summary.to_html()

    logging.info(bins.describe())

    # Columns used by the figures.
    # NOTE(review): the names are expected exactly as written (capitalised);
    # verify them against the header of the bin_info table.
    hover_data = [
        "Completeness_Model_Used",
        "Coding_Density",
        "N50",
        "GC_Content",
    ]
    size_name = "Genome_Size"
    lineage_name = "Species"

    def quality_scatter(frame):
        """Completeness vs. contamination scatter, returned as HTML."""
        figure = px.scatter(
            data_frame=frame,
            y="Completeness",
            x="Contamination",
            color=lineage_name,
            size=size_name,
            hover_data=hover_data,
            hover_name="Bin Id",
        )
        figure.update_yaxes(range=(50, 102))
        figure.update_xaxes(range=(-0.2, 10.1))
        return figure.to_html(**HTML_PARAMS)

    # 2D plot of all bins.
    logging.info("make 2d plot")
    div["2D"] = quality_scatter(bins)

    # 2D plot restricted to the rows named in the ``Representative`` column.
    logging.info("make 2d plot species")
    div["2Dsp"] = quality_scatter(bins.loc[bins.Representative.unique()])

    # Quality score by sample.
    logging.info("plot  by sample")
    by_sample = px.strip(
        data_frame=bins,
        y="Quality_score",
        x="Sample",
        color=lineage_name,
        hover_data=hover_data,
        hover_name="Bin Id",
    )
    by_sample.update_yaxes(range=(50, 102))
    div["bySample"] = by_sample.to_html(**HTML_PARAMS)

    # A per-species strip plot (key "byPhylum") was present here only as
    # commented-out code; it is not produced.

    return div


# main


# Build the figures, then render them into the bin report template.
div = make_plots(bin_info=snakemake.input.bin_info)

_report_kwargs = dict(
    div=div,
    report_out=snakemake.output.report,
    html_template_file=os.path.join(reports_dir, "template_bin_report.html"),
    wildcards=snakemake.wildcards,
)
make_html(**_report_kwargs)
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import os, sys
import logging, traceback

# Log to the first file listed in ``snakemake.log`` at INFO level, each line
# prefixed with a "YYYY-MM-DD HH:MM:SS" timestamp.
# ``snakemake`` is not defined in this file — presumably provided by Snakemake
# when the script is executed from a rule.
logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Record an uncaught exception in the log instead of printing it.

    Takes the three arguments of ``sys.excepthook``. A ``KeyboardInterrupt``
    is forwarded to the default hook unchanged; everything else is logged as
    a single ERROR message that contains the formatted traceback.
    """
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    pieces = ["Uncaught exception: "]
    pieces.extend(traceback.format_exception(exc_type, exc_value, exc_traceback))
    logging.error("".join(pieces))


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of script

from common_report import *


import pandas as pd
import plotly.express as px
from plotly import subplots
import plotly.graph_objs as go
import numpy as np


# Display titles for the read-count columns plotted in make_plots.
labels = {"Total_Reads": "Total Reads", "Total_Bases": "Total Bases"}


# Keyword arguments forwarded to the px.strip calls in make_plots.
PLOT_PARAMS = dict(labels=labels)


import zipfile


def get_stats_from_zips(zips, samples):
    """Collect per-position mean read quality from zipped QC statistics.

    Each archive is searched for ``boxplot_quality.txt`` tables:

    * ``boxplot_quality.txt`` at the top level: a single-end sample; only
      this table is read.
    * otherwise ``se/boxplot_quality.txt`` and/or ``pe/boxplot_quality.txt``.

    Parameters
    ----------
    zips
        Iterable of paths to the zip archives, one per sample.
    samples
        Sample names, in the same order as ``zips``.

    Returns
    -------
    (quality_pe, quality_se)
        ``quality_se`` has one column per sample holding ``mean_1``.
        ``quality_pe`` holds ``mean_1`` and ``mean_2`` of every paired-end
        sample under two-level columns ``(statistic, sample)``. Either frame
        is empty when no matching table was found.
    """
    quality_pe = pd.DataFrame()
    quality_se = pd.DataFrame()

    def read_quality_table(archive, member):
        # Tables are tab-separated; the first column is used as the index.
        with archive.open(member) as handle:
            return pd.read_csv(handle, index_col=0, sep="\t")

    for zfile, sample in zip(zips, samples):
        # The context manager closes the archive again, so no file handle is
        # left open per sample.
        with zipfile.ZipFile(zfile) as archive:
            members = set(archive.namelist())

            # Single end only.
            if "boxplot_quality.txt" in members:
                table = read_quality_table(archive, "boxplot_quality.txt")
                quality_se[sample] = table.mean_1
                continue

            if "se/boxplot_quality.txt" in members:
                table = read_quality_table(archive, "se/boxplot_quality.txt")
                quality_se[sample] = table.mean_1

            if "pe/boxplot_quality.txt" in members:
                table = read_quality_table(archive, "pe/boxplot_quality.txt")
                # Tag every column with the sample name (second column level)
                # so that several samples can sit side by side.
                table.columns = [table.columns, [sample] * table.shape[1]]

                quality_pe = pd.concat(
                    (quality_pe, table[["mean_1", "mean_2"]]), axis=1
                )

    return quality_pe, quality_se


def get_pe_read_quality_plot(df, quality_range, color_range):
    """Plot mean quality along forward and reverse reads, one line per sample.

    Parameters
    ----------
    df
        Frame with two-level columns: ``mean_1`` / ``mean_2`` at the first
        level and the sample name at the second; the index is used as x axis.
    quality_range
        Value passed as ``range`` of the first y axis.
        NOTE(review): ``autorange=True`` is set on the same axis — confirm
        that the given range actually takes effect.
    color_range
        Sequence of colours, indexed by the position of the sample.

    Returns
    -------
    The figure with two columns: forward reads left, reverse reads right
    (x axis reversed).
    """
    fig = subplots.make_subplots(cols=2)
    forward, reverse = df["mean_1"], df["mean_2"]

    for position, sample in enumerate(forward.columns):
        colour = color_range[position]

        # Forward read: carries the legend entry of the sample.
        forward_trace = go.Scatter(
            x=df.index,
            y=forward[sample].values,
            type="scatter",
            name=sample,
            legendgroup=sample,
            marker=dict(color=colour),
        )
        fig.append_trace(forward_trace, 1, 1)

        # Reverse read: same legend group, but without a second legend entry.
        reverse_trace = {
            "x": df.index,
            "y": reverse[sample].values,
            "type": "scatter",
            "name": sample,
            "legendgroup": sample,
            "showlegend": False,
            "marker": {"color": colour},
        }
        fig.append_trace(reverse_trace, 1, 2)

    fig.update_layout(
        yaxis={
            "range": quality_range,
            "autorange": True,
            "title": "Average quality score",
        },
        xaxis1={"title": "Position forward read"},
        xaxis2={"autorange": "reversed", "title": "Position reverse read"},
    )

    return fig


def draw_se_read_quality(df, quality_range, color_range):
    """Plot mean quality along the read for single-end data.

    Parameters
    ----------
    df
        Frame with one column per sample; the index is used as x axis.
    quality_range
        Value passed as ``range`` of the y axis.
        NOTE(review): ``autorange=True`` is set on the same axis — confirm
        that the given range actually takes effect.
    color_range
        Sequence of colours, indexed by the position of the sample.

    Returns
    -------
    The figure with one line per sample.
    """
    fig = subplots.make_subplots(cols=1)

    for position, sample in enumerate(df.columns):
        trace = go.Scatter(
            x=df.index,
            y=df[sample].values,
            type="scatter",
            name=sample,
            legendgroup=sample,
            marker=dict(color=color_range[position]),
        )
        fig.append_trace(trace, 1, 1)

    fig.update_layout(
        yaxis={
            "range": quality_range,
            "autorange": True,
            "title": "Average quality score",
        },
        xaxis={"title": "Position read"},
    )
    return fig


def make_plots(
    samples, zipfiles_QC, read_counts, read_length, min_quality, insert_size_stats
):
    """Create the HTML fragments of the read quality-control report.

    Parameters
    ----------
    samples
        Sample names, in the same order as ``zipfiles_QC``.
    zipfiles_QC
        Paths of the zipped QC statistics, one per sample.
    read_counts
        Path to a tab-separated table indexed by its first two columns; it
        is queried by ``Step`` and plotted by ``Sample``, and provides
        ``Total_Reads`` and ``Total_Bases``.
    read_length
        Path to a tab-separated read-length table; it is transposed so that
        ``Median``, ``Max``, ``Min``, ``Avg``, ``Std_Dev`` and ``Mode``
        become columns.
    min_quality
        Lower bound of the quality axis.
    insert_size_stats
        Path to the insert-size table, or ``None`` when not available.

    Returns
    -------
    dict
        HTML snippets under the keys ``quality_QC``, ``Total_Reads``,
        ``Total_Bases``, ``Reads``, ``Length`` and ``Insert``.
    """
    div = {}

    ## Quality along read

    # One hue per sample, spread evenly over the colour wheel.
    n_samples = len(samples)
    color_range = [
        f"hsl({hue!s},50%,50%)" for hue in np.linspace(0, 360, n_samples + 1)
    ]

    # Load the quality profiles of the QC'ed reads.
    quality_pe, quality_se = get_stats_from_zips(zipfiles_QC, samples)

    # Quality axis: from the configured minimum to one above the observed max.
    highest = np.nanmax((quality_pe.max().max(), quality_se.max().max()))
    quality_range = [min_quality, 1 + highest]

    # Paired-end data is recognised by a non-empty paired-end table.
    if quality_pe.shape[0] > 0:
        quality_figure = get_pe_read_quality_plot(
            quality_pe, quality_range, color_range
        )
    else:
        quality_figure = draw_se_read_quality(quality_se, quality_range, color_range)
    div["quality_QC"] = quality_figure.to_html(**HTML_PARAMS)

    ## Total reads plot

    counts = pd.read_csv(read_counts, index_col=[0, 1], sep="\t")

    # The "clean" step is not shown; it may be absent from the table.
    try:
        counts = counts.drop("clean", axis=0, level=1)
    except KeyError:
        pass

    data_qc = counts.query('Step=="QC"').reset_index()

    for column in ("Total_Reads", "Total_Bases"):
        strip = px.strip(
            data_qc, y=column, hover_data=["Sample", column], **PLOT_PARAMS
        )
        strip.update_yaxes(range=(0, data_qc[column].max() * 1.1))
        div[column] = strip.to_html(**HTML_PARAMS)

    ## Reads plot across different steps

    reads_per_step = px.bar(
        data_frame=counts.Total_Reads.unstack(),
        barmode="group",
        labels={"value": "Reads"},
        category_orders={"Step": ["raw", "deduplicated", "filtered", "QC"]},
    )
    reads_per_step.update_yaxes(title="Number of reads")
    reads_per_step.update_xaxes(tickangle=45)
    div["Reads"] = reads_per_step.to_html(**HTML_PARAMS)

    ## Read length plot

    data_length = pd.read_table(read_length, index_col=0).T
    data_length.index.name = "Sample"

    length_figure = px.bar(
        data_frame=data_length,
        x="Median",
        error_x="Max",
        error_x_minus="Min",
        hover_data=["Median", "Max", "Min", "Avg", "Std_Dev", "Mode"],
    )
    length_figure.update_xaxes(title="Read length")
    div["Length"] = length_figure.to_html(**HTML_PARAMS)

    ## Insert size plot
    if insert_size_stats is not None:
        data_insert = pd.read_table(insert_size_stats, index_col=0)
        data_insert.index.name = "Sample"

        insert_figure = px.bar(
            data_frame=data_insert,
            x="Mean",
            error_x="STDev",
            hover_data=["Mean", "Median", "Mode", "PercentOfPairs"],
            labels={"PercentOfPairs": "Percent of pairs"},
        )
        insert_figure.update_xaxes(title="Insert size")
        div["Insert"] = insert_figure.to_html(**HTML_PARAMS)
    else:
        div["Insert"] = (
            "<p>Insert size information is not available for single end reads.</p>"
        )

    return div


# If paired we have information about insert size.
# ``read_length_stats`` is either one path (no insert-size table) or a pair
# ``(read_length, insert_size)``. ``isinstance`` is used instead of an exact
# ``type(...) == str`` comparison so that a ``str`` subclass is still treated
# as a single path rather than being unpacked character by character.
if isinstance(snakemake.input.read_length_stats, str):
    read_length_path = snakemake.input.read_length_stats
    insert_size_stats = None
else:
    read_length_path, insert_size_stats = snakemake.input.read_length_stats

# Build all figures of the QC report ...
div = make_plots(
    samples=snakemake.params.samples,
    zipfiles_QC=snakemake.input.zipfiles_QC,
    read_counts=snakemake.input.read_counts,
    read_length=read_length_path,
    min_quality=snakemake.params.min_quality,
    insert_size_stats=insert_size_stats,
)

# ... and fill them into the HTML template.
make_html(
    div=div,
    report_out=snakemake.output.report,
    html_template_file=os.path.join(reports_dir, "template_QC_report.html"),
)
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
shell:
    """
    reformat.sh {params.inputs} \
        interleaved={params.interleaved} \
        {params.outputs} \
        iupacToN=t \
        touppercase=t \
        qout=33 \
        overwrite=true \
        verifypaired={params.verifypaired} \
        addslash=t \
        trimreaddescription=t \
        threads={threads} \
        pigz=t unpigz=t \
        -Xmx{resources.java_mem}G 2> {log}
    """
90
91
92
93
94
95
96
run:
    # make symlink
    assert len(input) == len(
        output
    ), "Input and ouput files have not same number, can not create symlinks for all."
    for i in range(len(input)):
        os.symlink(os.path.abspath(input[i]), output[i])
135
136
137
138
139
140
141
142
143
144
145
146
147
shell:
    " bbnorm.sh {params.inputs} "
    " {params.outputs} "
    " tmpdir={resources.tmpdir} "
    " tossbadreads=t "
    " hist={output.histin} "
    " histout={output.histout} "
    " mindepth={params.mindepth} "
    " k={params.k} "
    " target={params.target} "
    " prefilter=t "
    " threads={threads} "
    " -Xmx{resources.java_mem}G &> {log} "
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
shell:
    "tadpole.sh -Xmx{resources.java_mem}G "
    " prefilter={params.prefilter} "
    " prealloc=1 "
    " {params.inputs} "
    " {params.outputs} "
    " mode=correct "
    " aggressive={params.aggressive} "
    " tossjunk={params.tossjunk} "
    " lowdepthfraction={params.lowdepthfraction}"
    " tossdepth={params.tossdepth} "
    " merge=t "
    " shave={params.shave} rinse={params.shave} "
    " threads={threads} "
    " pigz=t unpigz=t "
    " ecc=t ecco=t "
    "&> {log} "
231
232
233
234
235
236
237
238
239
240
shell:
    """
    bbmerge.sh -Xmx{resources.java_mem}G threads={threads} \
        in1={input[0]} in2={input[1]} \
        outmerged={output[2]} \
        outu={output[0]} outu2={output[1]} \
        {params.flags} k={params.kmer} \
        pigz=t unpigz=t \
        extend2={params.extend2} 2> {log}
    """
273
274
shell:
    "cat {input} > {output}"
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
shell:
    """
    rm -r {params.outdir} 2> {log}

    megahit \
    {params.inputs} \
    --tmp-dir {resources.tmpdir} \
    --num-cpu-threads {threads} \
    --k-min {params.k_min} \
    --k-max {params.k_max} \
    --k-step {params.k_step} \
    --out-dir {params.outdir} \
    --out-prefix {wildcards.sample}_prefilter \
    --min-contig-len {params.min_contig_len} \
    --min-count {params.min_count} \
    --merge-level {params.merge_level} \
    --prune-level {params.prune_level} \
    --low-local-ratio {params.low_local_ratio} \
    --memory {resources.mem}000000000  \
    {params.preset} >> {log} 2>&1
    """
350
351
shell:
    "cp {input} {output}"
466
467
shell:
    "cp {input} {output}"
488
489
script:
    "../scripts/rename_assembly.py"
511
512
wrapper:
    "v1.19.0/bio/minimap2/aligner"
531
532
533
534
535
536
537
538
539
shell:
    "pileup.sh ref={input.fasta} in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " concise=t "
    " minmapq={params.minmapq} "
    " secondary={params.pileup_secondary} "
    " 2> {log}"
562
563
564
565
566
567
568
569
570
571
572
shell:
    """filterbycoverage.sh in={input.fasta} \
    cov={input.covstats} \
    out={output.fasta} \
    outd={output.removed_names} \
    minc={params.minc} \
    minp={params.minp} \
    minr={params.minr} \
    minl={params.minl} \
    trim={params.trim} \
    -Xmx{resources.java_mem}G 2> {log}"""
588
589
shell:
    "cp {input} {output}"
602
603
run:
    os.symlink(os.path.relpath(input[0], os.path.dirname(output[0])), output[0])
619
620
shell:
    "stats.sh in={input} format=3 out={output} &> {log}"
640
641
wrapper:
    "v1.19.0/bio/minimap2/aligner"
669
670
671
672
673
674
675
676
677
678
679
680
681
shell:
    "pileup.sh "
    " ref={input.fasta} "
    " in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " hist={output.covhist} "
    " concise=t "
    " minmapq={params.minmapq} "
    " secondary={params.pileup_secondary} "
    " bincov={output.bincov} "
    " 2> {log} "
694
695
shell:
    "samtools index {input}"
715
716
717
718
719
shell:
    """
    prodigal -i {input} -o {output.gff} -d {output.fna} \
        -a {output.faa} -p meta -f gff 2> {log}
    """
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
run:
    header = [
        "gene_id",
        "Contig",
        "Gene_nr",
        "Start",
        "Stop",
        "Strand",
        "Annotation",
    ]
    with open(output.tsv, "w") as tsv:
        tsv.write("\t".join(header) + "\n")
        with open(input.faa) as fin:
            gene_idx = 0
            for line in fin:
                if line[0] == ">":
                    text = line[1:].strip().split(" # ")
                    old_gene_name = text[0]
                    text.remove(old_gene_name)
                    old_gene_name_split = old_gene_name.split("_")
                    gene_nr = old_gene_name_split[-1]
                    contig_nr = old_gene_name_split[-2]
                    sample = "_".join(
                        old_gene_name_split[: len(old_gene_name_split) - 2]
                    )
                    tsv.write(
                        "{gene_id}\t{sample}_{contig_nr}\t{gene_nr}\t{text}\n".format(
                            text="\t".join(text),
                            gene_id=old_gene_name,
                            i=gene_idx,
                            sample=sample,
                            gene_nr=gene_nr,
                            contig_nr=contig_nr,
                        )
                    )
                    gene_idx += 1
794
795
script:
    "../scripts/combine_contig_stats.py"
807
808
script:
    "../report/assembly_report.py"
27
28
29
30
31
32
33
34
35
shell:
    "pileup.sh "
    " ref={input.fasta} "
    " in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " secondary={params.pileup_secondary} "
    " 2> {log} "
48
49
50
51
52
53
54
run:
    with open(input[0]) as fi, open(output[0], "w") as fo:
        # header
        next(fi)
        for line in fi:
            toks = line.strip().split("\t")
            print(toks[0], toks[1], sep="\t", file=fo)
66
67
68
69
70
71
72
73
run:
    from utils.parsers_bbmap import combine_coverages

    combined_cov, _ = combine_coverages(
        input.covstats, get_alls_samples_of_group(wildcards), "Avg_fold"
    )

    combined_cov.T.to_csv(output[0], sep="\t")
 98
 99
100
101
102
103
104
105
106
107
108
shell:
    """
    concoct -c {params.Nexpected_clusters} \
        --coverage_file {input.coverage} \
        --composition_file {input.fasta} \
        --basename {params.basename} \
        --read_length {params.read_length} \
        --length_threshold {params.min_length} \
        --converge_out \
        --iterations {params.niterations}
    """
120
121
122
123
run:
    with open(input[0]) as fin, open(output[0], "w") as fout:
        for line in fin:
            fout.write(line.replace(",", "\t"))
145
146
147
148
149
shell:
    "jgi_summarize_bam_contig_depths "
    " --percentIdentity {params.minid} "
    " --outputDepth {output} "
    " {input.bams} &> {log} "
178
179
180
181
182
183
184
185
186
187
188
shell:
    """
    metabat2 -i {input.contigs} \
        --abdFile {input.depth_file} \
        --minContig {params.min_contig_len} \
        --numThreads {threads} \
        --maxEdges {params.sensitivity} \
        --saveCls --noBinOut \
        -o {output} \
        &> {log}
    """
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
shell:
    """
    mkdir {output[0]} 2> {log}
    run_MaxBin.pl -contig {input.fasta} \
        -abund {input.abund} \
        -out {params.output_prefix} \
        -min_contig_length {params.mcl} \
        -thread {threads} \
        -prob_threshold {params.pt} \
        -max_iteration {params.mi} >> {log}

    mv {params.output_prefix}.summary {output[0]}/.. 2>> {log}
    mv {params.output_prefix}.marker {output[0]}/..  2>> {log}
    mv {params.output_prefix}.marker_of_each_bin.tar.gz {output[0]}/..  2>> {log}
    mv {params.output_prefix}.log {output[0]}/..  2>> {log}

    """
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
run:
    import pandas as pd
    import numpy as np


    d = pd.read_csv(input[0], index_col=0, header=None, sep="\t").squeeze()

    assert (
        type(d) == pd.Series
    ), "expect the input to be a two column file: {}".format(input[0])

    old_cluster_ids = list(d.unique())
    if 0 in old_cluster_ids:
        old_cluster_ids.remove(0)

    map_cluster_ids = dict(
        zip(
            old_cluster_ids,
            utils.gen_names_for_range(
                len(old_cluster_ids),
                prefix="{sample}_{binner}_".format(**wildcards),
            ),
        )
    )

    new_d = d.map(map_cluster_ids)
    new_d.dropna(inplace=True)
    if new_d.shape[0] == 0:
        logger.warning(
            f"No bins detected with binner {wildcards.binner} in sample {wildcards.sample}.\n"
            "I add longest contig to make the pipline continue"
        )

        new_d[f"{wildcards.sample}_0"] = "{sample}_{binner}_1".format(**wildcards)

    new_d.to_csv(output[0], sep="\t", header=False)
294
295
296
297
298
299
300
301
302
303
304
run:
    (bin_ids,) = glob_wildcards(params.file_name)
    print("found {} bins".format(len(bin_ids)))
    with open(output[0], "w") as out_file:
        for binid in bin_ids:
            with open(params.file_name.format(binid=binid)) as bin_file:
                for line in bin_file:
                    if line.startswith(">"):
                        fasta_header = line[1:].strip().split()[0]
                        out_file.write(f"{fasta_header}\t{binid}\n")
            os.remove(params.file_name.format(binid=binid))
317
318
script:
    "../scripts/get_fasta_of_bins.py"
330
331
shell:
    "cp {input} {output}"
358
359
360
361
362
363
364
365
366
367
368
369
370
371
shell:
    " DAS_Tool --outputbasename {params.output_prefix} "
    " --bins {params.scaffolds2bin} "
    " --labels {params.binner_names} "
    " --contigs {input.contigs} "
    " --search_engine diamond "
    " --proteins {input.proteins} "
    " --write_bin_evals "
    " --megabin_penalty {params.megabin_penalty}"
    " --duplicate_penalty {params.duplicate_penalty} "
    " --threads {threads} "
    " --debug "
    " --score_threshold {params.score_threshold} &> {log} "
    " ; mv {params.output_prefix}_DASTool_contig2bin.tsv {output.cluster_attribution} &>> {log}"
12
13
14
15
16
17
run:
    from utils.genome_stats import get_many_genome_stats

    filenames = list(Path(input[0]).glob("*" + params.extension))

    get_many_genome_stats(filenames, output[0], threads)
32
33
34
35
36
37
38
39
40
41
42
43
44
run:
    try:
        from utils.io import pandas_concat

        pandas_concat(input, output[0])

    except Exception as e:
        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
71
72
73
74
75
76
77
78
79
80
81
82
83
84
shell:
    " checkm2 predict "
    " --threads {threads} "
    " {params.lowmem} "
    " --force "
    " --allmodels "
    " -x .fasta "
    " --tmpdir {resources.tmpdir} "
    " --input {input.fasta_dir} "
    " --output-directory {params.dir} "
    " &> {log[0]} "
    ";\n"
    " cp {params.dir}/quality_report.tsv {output.table} 2>> {log[0]} ; "
    " mv {params.dir}/protein_files {output.faa} 2>> {log[0]} ; "
107
108
109
110
111
112
113
114
115
116
117
118
119
shell:
    " mkdir {output.folder} 2> {log}"
    " ;\n"
    " gunc run "
    " --threads {threads} "
    " --gene_calls "
    " --db_file {input.db} "
    " --input_dir {input.fasta_dir} "
    " --temp_dir {resources.tmpdir} "
    " --file_suffix {params.extension} "
    " --out_dir {output.folder} &>> {log} "
    " ;\n "
    " cp {output.folder}/*.tsv {output.table} 2>> {log}"
145
146
147
148
149
150
151
152
153
154
155
156
157
    shell:
        " busco -i {input.fasta_dir} "
        " --auto-lineage-prok "
        " -m genome "
        " --out_path {params.tmpdir} "
        " -o output "
        " --download_path {input.db} "
        " -c {threads} "
        " --offline &> {log} "
        " ; "
        " mv {params.tmpdir}/output/batch_summary.txt {output} 2>> {log}"

"""
180
181
182
183
184
185
186
187
188
189
190
191
192
run:
    try:
        from utils.io import pandas_concat

        pandas_concat(input, output[0])

    except Exception as e:
        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
207
208
script:
    "../scripts/combine_checkm2.py"
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
run:
    import pandas as pd
    from pathlib import Path
    from utils import io


    def get_list_of_files(dirs, pattern):
        """Collect all files matching `pattern` inside the folders `dirs`.

        Returns a DataFrame with a single column "Filename", indexed by
        `io.simplify_path(filename)` (index name "Bin") and sorted by index.
        """
        matched_files = []

        # search for files matching the pattern (e.g. "*.f*") in all bin folders
        for folder in dirs:
            matched_files.extend(Path(folder).glob(pattern))

        filenames = pd.DataFrame(matched_files, columns=["Filename"])
        filenames.index = filenames.Filename.apply(io.simplify_path)
        filenames.index.name = "Bin"

        filenames.sort_index(inplace=True)

        return filenames


    fasta_filenames = get_list_of_files(input.dirs, "*.f*")
    faa_filenames = get_list_of_files(input.protein_dirs, "*.faa")

    # Index.equals() is False for indexes of different length, whereas an
    # element-wise `==` raises ValueError before the assertion can report it.
    assert faa_filenames.index.equals(
        fasta_filenames.index
    ), "fasta index and faa index are not the same"

    faa_filenames.columns = ["Proteins"]

    # both tables share the same sorted index, so they align row by row
    filenames = pd.concat((fasta_filenames, faa_filenames), axis=1)

    filenames.to_csv(output.filenames, sep="\t")
276
277
278
279
run:
    from utils.io import cat_files

    cat_files(input, output[0], gzip=True)
318
319
script:
    "../scripts/filter_genomes.py"
74
75
76
77
78
79
80
81
82
shell:
    """
    cd-hit-est -i {input} -T {threads} \
    -M {resources.mem}000 -o {params.prefix} \
    -c {params.identity} -n 9  -d 0 {params.extra} \
    -aS {params.coverage} -aL {params.coverage} &> {log}

    mv {params.prefix} {output[0]} 2>> {log}
    """
90
91
92
93
94
run:
    with open(output[0], "w") as fout:
        fout.write(f"ORF\tLength\tIdentity\tRepresentative\n")
        Clusters = parse_cd_hit_file(input[0])
        write_cd_hit_clusters(Clusters, fout)
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
run:
    import pandas as pd
    import numpy as np

    from utils import gene_scripts

    # cd hit format ORF\tLength\tIdentity\tRepresentative\n
    orf2gene = pd.read_csv(input.orf2gene, sep="\t")

    # rename gene repr to Gene0000XX

    # split orf names in sample, contig_nr, and orf_nr
    orf_info = gene_scripts.split_orf_to_index(orf2gene.ORF)

    # rename representative

    representative_names = orf2gene.Representative.unique()

    map_names = pd.Series(
        index=representative_names,
        data=np.arange(1, len(representative_names) + 1, dtype=np.uint),
    )


    orf_info["GeneNr"] = orf2gene.Representative.map(map_names)


    orf_info.to_parquet(output.cluster_attribution)


    # Save name of representatives
    map_names.index.name = "Representative"
    map_names.name = "GeneNr"
    map_names.to_csv(output.rep2genenr, sep="\t")
24
25
26
27
28
29
shell:
    " reformat.sh in={input} "
    " fastaminlen={params.min_length} "
    " out={output} "
    " overwrite=true "
    " -Xmx{resources.java_mem}G 2> {log} "
75
76
77
78
79
80
81
82
83
84
85
86
run:
    import gzip as gz

    with gz.open(output[0], "wb") as fout:
        for sample, input_fasta in zip(params.samples, input.fasta):
            with gz.open(input_fasta, "rb") as fin:
                for line in fin:
                    # if line is a header add sample name
                    if line[0] == ord('>'):
                        line = f">{sample}{params.seperator}".encode() + line[1:]
                    # write each line to the combined file
                    fout.write(line)
105
106
shell:
    "minimap2 -I {params.index_size} -t {threads} -d {output} {input} 2> {log}"
122
123
shell:
    "samtools dict {input} | cut -f1-3 > {output} 2> {log}"
142
143
shell:
    """minimap2 -t {threads} -ax sr {input.mmi} {input.fq} | grep -v "^@" | cat {input.dict} - | samtools view -F 3584 -b - > {output.bam} 2>{log}"""
165
166
shell:
    "samtools sort {input} -T {params.prefix} --threads {threads} -m 3G -o {output} 2>{log}"
186
187
188
189
190
shell:
    "jgi_summarize_bam_contig_depths "
    " --percentIdentity {params.minid} "
    " --outputDepth {output} "
    " {input.bams} &> {log} "
207
208
script:
    "../scripts/convert_jgi2vamb_coverage.py"
231
232
233
234
235
236
237
238
shell:
    "vamb --outdir {output} "
    " -m {params.mincontig} "
    " --minfasta {params.minfasta} "
    " -o '{params.separator}' "
    " --jgi {input.coverage} "
    " --fasta {input.fasta} "
    "2> {log}"
264
265
script:
    "../scripts/parse_vamb.py"
21
22
23
24
25
26
27
28
29
shell:
    "skani triangle "
    " {params.extra} "
    " -l {input.paths} "
    " -o {output} "
    " -t {threads} "
    " --sparse --ci "
    " --min-af {params.min_af} "
    " &> {log} "
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
run:
    try:
        import pandas as pd
        from utils.io import simplify_path

        # Columns loaded from the skani table and the dtype each is parsed as.
        # The Ref_name / Query_name columns are intentionally not loaded.
        skani_column_dtypes = {
            "Ref_file": "category",
            "Query_file": "category",
            "ANI": float,
            "Align_fraction_ref": float,
            "Align_fraction_query": float,
            "ANI_5_percentile": float,
            "ANI_95_percentile": float,
        }

        # Read the table once, restricted to the columns above; file paths are
        # parsed as categoricals.
        df = pd.read_table(
            input[0],
            usecols=list(skani_column_dtypes.keys()),
            dtype=skani_column_dtypes,
        )

        # rename_categories applies simplify_path once per distinct file path
        df["Ref"] = df.Ref_file.cat.rename_categories(simplify_path)
        df["Query"] = df.Query_file.cat.rename_categories(simplify_path)

        df.to_parquet(output[0])

    except Exception as e:
        import traceback

        # write the traceback to the rule's log file before failing the job
        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
 99
100
script:
    "../scripts/cluster_species.py"
113
114
script:
    "../report/bin_report.py"
SnakeMake From line 113 of rules/derep.smk
134
135
136
137
138
139
run:
    shell(
        "wget -O {output} 'https://zenodo.org/record/{ZENODO_ARCHIVE}/files/{wildcards.filename}' "
    )
    if not FILES[wildcards.filename] == md5(output[0]):
        raise OSError(2, "Invalid checksum", output[0])
148
149
150
151
152
153
154
155
run:
    shell(
        "wget -O {output.tar} 'https://zenodo.org/record/{ZENODO_ARCHIVE}/files/{CHECKM_ARCHIVE}' "
    )
    if not FILES[CHECKM_ARCHIVE] == md5(output.tar):
        raise OSError(2, "Invalid checksum", CHECKM_ARCHIVE)

    shell("tar -zxf {output.tar} --directory {params.path}")
173
174
shell:
    "checkm data setRoot {params.database_dir} &> {log} "
191
192
shell:
    " wget --no-check-certificate {GTDB_DATA_URL} -O {output} &> {log} "
207
208
shell:
    'tar -xzvf {input} -C "{GTDBTK_DATA_PATH}" --strip 1 2> {log}; '
221
222
223
shell:
    " checkm2 database --download --path {output} "
    " &>> {log}"
238
239
240
shell:
    "gunc download_db {resources.tmpdir} -db {wildcards.gunc_database} &> {log} ;"
    "mv {resources.tmpdir}/gunc_db_{wildcards.gunc_database}*.dmnd {output} 2>> {log}"
254
255
shell:
    "busco -q --download_path {output} --download prokaryota &> {log}"
33
34
35
36
37
38
39
40
41
shell:
    " DRAM-setup.py prepare_databases "
    " --output_dir {output.dbdir} "
    " --threads {threads} "
    " --verbose "
    " --skip_uniref "
    " &> {log} "
    " ; "
    " DRAM-setup.py export_config --output_file {output.config}"
SnakeMake From line 33 of rules/dram.smk
65
66
67
68
69
70
71
72
73
shell:
    " DRAM.py annotate "
    " --config_loc {input.config} "
    " --input_fasta {input.fasta}"
    " --output_dir {output.outdir} "
    " --threads {threads} "
    " --min_contig_size {params.min_contig_size} "
    " {params.extra} "
    " --verbose &> {log}"
SnakeMake From line 65 of rules/dram.smk
 94
 95
 96
 97
 98
 99
100
101
102
103
104
run:
    from utils import io

    for i, annotation_file in enumerate(DRAM_ANNOTATON_FILES):
        input_files = [
            os.path.join(dram_folder, annotation_file) for dram_folder in input
        ]

        io.pandas_concat(
            input_files, output[i], sep="\t", index_col=0, axis=0, disk_based=True
        )
121
122
123
124
125
126
shell:
    " DRAM.py distill "
    " --config_loc {input.config} "
    " --input_file {input[0]}"
    " --output_dir {output} "
    "  &> {log}"
SnakeMake From line 121 of rules/dram.smk
143
144
script:
    "../scripts/DRAM_get_all_modules.py"
SnakeMake From line 143 of rules/dram.smk
19
20
script:
    "../scripts/filter_genes.py"
46
47
48
49
50
51
run:
    from utils.io import cat_files

    cat_files(input.faa, output.faa)
    cat_files(input.fna, output.fna)
    cat_files(input.short, output.short)
69
70
71
72
73
74
run:
    from utils.io import cat_files

    cat_files(input.faa, output.faa)
    cat_files(input.fna, output.fna)
    cat_files(input.short, output.short)
104
105
106
107
108
109
110
111
112
113
114
shell:
    # The first command creates/truncates the log; every later command appends,
    # so no earlier error output is overwritten.
    """
    mkdir -p {params.tmpdir} {output} 2> {log}
    mmseqs createdb {input.faa} {params.db} &>> {log}

    mmseqs {params.clustermethod} -c {params.coverage} \
    --min-seq-id {params.minid} {params.extra} \
    --threads {threads} {params.db} {params.clusterdb} {params.tmpdir}  &>>  {log}

    rm -fr  {params.tmpdir} 2>> {log}
    """
132
133
134
135
136
137
138
139
140
141
142
shell:
    """
    mmseqs createtsv {params.db} {params.db} {params.clusterdb} {output.cluster_attribution}  &> {log}

    mkdir {output.rep_seqs_db} 2>> {log}

    mmseqs result2repseq {params.db} {params.clusterdb} {output.rep_seqs_db}/db  &>> {log}

    mmseqs result2flat {params.db} {params.db} {output.rep_seqs_db}/db {output.rep_seqs}  &>> {log}

    """
158
159
160
161
162
163
164
165
shell:
    " filterbyname.sh "
    " in={input.all}"
    " names={input.names}"
    " include=t"
    " out={output} "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
176
177
script:
    "../scripts/generate_orf_info.py"
207
208
script:
    "../scripts/rename_genecatalog.py"
224
225
shell:
    "stats.sh gcformat=4 gc={output} in={input} &> {log}"
237
238
wrapper:
    "v1.19.0/bio/minimap2/index"
251
252
shell:
    "cat {input} > {output} 2> {log}"
269
270
wrapper:
    "v1.19.0/bio/minimap2/aligner"
289
290
291
292
293
294
295
296
297
shell:
    " pileup.sh "
    " in={input.bam}"
    " covstats={output.covstats} "
    " rpkm={output.rpkm} "
    " secondary=t "
    " minmapq={params.minmapq} "
    " -Xmx{resources.java_mem}G "
    " 2> {log} "
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
run:
    try:
        import pandas as pd
        from utils.parsers_bbmap import read_pileup_coverage

        data = read_pileup_coverage(
            input[0],
            coverage_measure="Median_fold",
            other_columns=["Avg_fold", "Covered_percent", "Read_GC", "Std_Dev"],
        )
        data.index.name = "GeneName"
        data.sort_index(inplace=True)

        # rpkm = pd.read_csv(input[1],sep='\t',skiprows=4,usecols=["#Name","RPKM"],index_col=0).sort_index()

        data.reset_index().to_parquet(output[0])

    except Exception as e:
        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
374
375
script:
    "../scripts/combine_gene_coverages.py"
414
415
script:
    "../scripts/split_genecatalog.py"
455
456
457
458
459
460
shell:
    """
    emapper.py -m diamond --no_annot --no_file_comments \
        --data_dir {params.data_dir} --cpu {threads} -i {input.faa} \
        -o {params.prefix} --override 2> {log}
    """
490
491
492
493
494
495
496
497
498
499
500
501
shell:
    """

    if [ {params.copyto_shm} == "t" ] ;
    then
        cp {EGGNOG_DIR}/eggnog.db {params.data_dir}/eggnog.db 2> {log}
        cp {EGGNOG_DIR}/eggnog_proteins.dmnd {params.data_dir}/eggnog_proteins.dmnd 2>> {log}
    fi

    emapper.py --annotate_hits_table {input.seed} --no_file_comments \
      --override -o {params.prefix} --cpu {threads} --data_dir {params.data_dir} 2>> {log}
    """
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
run:
    try:

        import pandas as pd

        Tables = [
            pd.read_csv(file, index_col=None, header=None, sep="\t")
            for file in input
        ]

        combined = pd.concat(Tables, axis=0)

        del Tables

        combined.columns = EGGNOG_HEADER
        combined["Seed_evalue"] = combined["Seed_evalue"].astype("bytes")
        combined["Seed_Score"] = combined["Seed_Score"].astype("bytes")

        #           combined.sort_values("Gene",inplace=True)

        combined.to_parquet(output[0], index=False)
    except Exception as e:

        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
run:
    try:
        import pandas as pd

        df = pd.read_table(input[0], index_col=None)

        df.to_parquet(output[0], index=False)

    except Exception as e:

        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
607
608
609
610
611
612
613
614
615
616
617
shell:
    " rm -rf {params.outdir} &> {log[0]};"
    "\n"
    " DRAM.py annotate_genes "
    " --input_faa {input.faa}"
    " --config_loc {input.config} "
    " --output_dir {params.outdir} "
    " --threads {threads} "
    " {params.extra} "
    " --log_file_path {log[1]} "
    " --verbose &>> {log[0]}"
638
639
script:
    "../scripts/combine_dram_gene_annotations.py"
657
658
script:
    "../scripts/gene2genome.py"
718
719
720
721
722
723
724
725
726
727
728
729
shell:
    # Locate the DAS_Tool installation folder, run the single-copy-gene lookup
    # script against its bundled databases, then move the result to the output.
    # Plain "/" is used after $DIR: the former "\/" is an invalid Python escape
    # sequence and means the same thing to the shell.
    " DIR=$(dirname $(readlink -f $(which DAS_Tool))) "
    ";"
    " ruby {params.script_dir}/rules/scg_blank_diamond.rb diamond"
    " {input} "
    " $DIR/db/{params.key}.all.faa "
    " $DIR/db/{params.key}.scg.faa "
    " $DIR/db/{params.key}.scg.lookup "
    " {threads} "
    " 2> {log} "
    " ; "
    " mv {input[0]}.scg {output}"
28
29
script:
    "../scripts/rename_genomes.py"
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
run:
    import gzip
    from glob import glob

    # all fasta files (*.f*, optionally gzipped) in the input folder
    fasta_files = glob(input[0] + "/*.f*")

    with open(output[0], "w") as out_contigs:
        for fasta in fasta_files:
            bin_name, ext = os.path.splitext(os.path.split(fasta)[-1])
            # if gz remove also fasta extension
            is_gzipped = ext == ".gz"
            if is_gzipped:
                bin_name = os.path.splitext(bin_name)[0]

            # a gzipped fasta must be decompressed while reading; plain open()
            # would try to decode the compressed bytes as text
            open_fasta = gzip.open if is_gzipped else open

            # write names of contigs in mapping file
            with open_fasta(fasta, "rt") as f:
                for line in f:
                    if line.startswith(">"):
                        # contig name = first whitespace-delimited token of the header
                        header = line[1:].strip().split()[0]
                        out_contigs.write(f"{header}\t{bin_name}\n")
139
140
141
142
143
shell:
    """
    prodigal -i {input} -o {output.gff} -d {output.fna} \
        -a {output.faa} -p meta -f gff 2> {log}
    """
178
179
shell:
    "cat {input} > {output}"
189
190
shell:
    "cat {input}/*{params.ext} > {output}"
208
209
wrapper:
    "v1.19.0/bio/minimap2/index"
226
227
wrapper:
    "v1.19.0/bio/minimap2/aligner"
243
244
wrapper:
    "v1.19.0/bio/bwa-mem2/index"
262
263
wrapper:
    "v1.19.0/bio/bwa-mem2/mem"
289
290
shell:
    # mv reports failures on stderr, so capture both streams in the log
    "mv {input} {output} &> {log}"
303
304
wrapper:
    "v1.19.0/bio/samtools/stats"
314
315
wrapper:
    "v1.19.1/bio/multiqc"
336
337
338
339
340
341
342
343
344
345
346
shell:
    "pileup.sh in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " fastaorf={input.orf} outorf={output.orf} "
    " concise=t "
    " physical=t "
    " minmapq={params.minmapq} "
    " bincov={output.bincov} "
    " 2> {log}"
371
372
script:
    "../scripts/combine_coverage_MAGs.py"
20
21
22
23
24
25
26
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk identify "
    "--genes --genome_dir {params.gene_dir} "
    " --out_dir {params.outdir} "
    "--extension {params.extension} "
    "--cpus {threads} &> {log[0]}"
42
43
44
45
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk align --identify_dir {params.outdir} --out_dir {params.outdir} "
    "--cpus {threads} &> {log[0]}"
67
68
69
70
71
72
73
74
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk classify --genome_dir {input.genome_dir} --align_dir {params.outdir} "
    " --mash_db {params.mashdir} "
    "--out_dir {params.outdir} "
    " --tmpdir {resources.tmpdir} "
    "--extension {params.extension} "
    "--cpus {threads} &> {log[0]}"
85
86
script:
    "../scripts/combine_taxonomy.py"
102
103
104
105
106
107
108
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk infer --msa_file {input} "
    " --out_dir {params.outdir} "
    " --prefix {wildcards.msa} "
    " --cpus {threads} "
    "--tmpdir {resources.tmpdir} > {log[0]} 2> {log[1]}"
130
131
script:
    "../scripts/root_tree.py"
SnakeMake From line 130 of rules/gtdbtk.smk
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
shell:
    "reformat.sh "
    " {params.inputs} "
    " interleaved={params.interleaved} "
    " {params.outputs} "
    " {params.extra} "
    " overwrite=true "
    " verifypaired={params.verifypaired} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
SnakeMake From line 92 of rules/qc.smk
130
131
script:
    "../scripts/get_read_stats.py"
SnakeMake From line 130 of rules/qc.smk
174
175
176
177
178
179
180
181
182
183
184
185
shell:
    "clumpify.sh "
    " {params.inputs} "
    " {params.outputs} "
    " overwrite=true"
    " dedupe=t "
    " dupesubs={params.dupesubs} "
    " optical={params.only_optical}"
    " threads={threads} "
    " pigz=t unpigz=t "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
SnakeMake From line 174 of rules/qc.smk
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
shell:
    " bbduk.sh {params.inputs} "
    " {params.ref} "
    " interleaved={params.interleaved} "
    " {params.outputs} "
    " stats={output.stats} "
    " overwrite=true "
    " qout=33 "
    " trd=t "
    " {params.hdist} "
    " {params.k} "
    " {params.ktrim} "
    " {params.mink} "
    " trimq={params.trimq} "
    " qtrim={params.qtrim} "
    " threads={threads} "
    " minlength={params.minlength} "
    " maxns={params.maxns} "
    " minbasefrequency={params.minbasefrequency} "
    " ecco={params.error_correction_pe} "
    " prealloc={params.prealloc} "
    " pigz=t unpigz=t "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
331
332
333
334
335
336
337
338
shell:
    "bbsplit.sh"
    " -Xmx{resources.java_mem}G "
    " {params.refs_in} "
    " threads={threads}"
    " k={params.k}"
    " local=t "
    " 2> {log}"
SnakeMake From line 331 of rules/qc.smk
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
shell:
    """
    if [ "{params.paired}" = true ] ; then
        bbsplit.sh in1={input[0]} in2={input[1]} \
            outu1={output[0]} outu2={output[1]} \
            basename="{params.contaminant_folder}/%_R#.fastq.gz" \
            maxindel={params.maxindel} minratio={params.minratio} \
            minhits={params.minhits} ambiguous={params.ambiguous} refstats={output.stats} \
            threads={threads} k={params.k} local=t \
            pigz=t unpigz=t ziplevel=9 \
            -Xmx{resources.java_mem}G 2> {log}
    fi

    bbsplit.sh in={params.input_single}  \
        outu={params.output_single} \
        basename="{params.contaminant_folder}/%_se.fastq.gz" \
        maxindel={params.maxindel} minratio={params.minratio} \
        minhits={params.minhits} ambiguous={params.ambiguous} refstats={output.stats} append=t \
        interleaved=f threads={threads} k={params.k} local=t \
        pigz=t unpigz=t ziplevel=9 \
        -Xmx{resources.java_mem}G 2>> {log}
    """
SnakeMake From line 386 of rules/qc.smk
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
run:
    import shutil
    import pandas as pd

    for i in range(len(MULTIFILE_FRACTIONS)):
        with open(output[i], "wb") as outFile:
            with open(input.clean_reads[i], "rb") as infile1:
                shutil.copyfileobj(infile1, outFile)
                if hasattr(input, "rrna_reads"):
                    with open(input.rrna_reads[i], "rb") as infile2:
                        shutil.copyfileobj(infile2, outFile)

    # append to sample table
    sample_table = load_sample_table(params.sample_table)
    qc_header = [f"Reads_QC_{fraction}" for fraction in MULTIFILE_FRACTIONS]
    sample_table.loc[wildcards.sample, qc_header] = output
    sample_table.to_csv(params.sample_table, sep="\t")
477
478
479
480
481
482
483
484
485
486
487
488
489
490
shell:
    " bbmerge.sh "
    " -Xmx{resources.java_mem}G "
    " threads={threads} "
    " {params.inputs} "
    " {params.flags} k={params.kmer} "
    " extend2={params.extend2} "
    " ihist={output.ihist} merge=f "
    " mininsert0=35 minoverlap0=8 "
    " prealloc=t prefilter=t "
    " minprob={params.minprob} 2> {log} \n  "
    """
    readlength.sh {params.inputs} out={output.read_length} 2>> {log}
    """
SnakeMake From line 477 of rules/qc.smk
511
512
513
514
shell:
    """
    readlength.sh in={input[0]} out={output.read_length} 2> {log}
    """
SnakeMake From line 511 of rules/qc.smk
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
run:
    import pandas as pd
    import os
    from utils.parsers_bbmap import parse_comments

    stats = pd.DataFrame()

    for length_file in input:
        sample = length_file.split(os.path.sep)[0]
        data = parse_comments(length_file)
        data = pd.Series(data)[
            ["Reads", "Bases", "Max", "Min", "Avg", "Median", "Mode", "Std_Dev"]
        ]
        stats[sample] = data

    stats.to_csv(output[0], sep="\t")
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
run:
    import pandas as pd
    import os
    from utils.parsers_bbmap import parse_comments

    stats = pd.DataFrame()

    for insert_file in input:
        sample = insert_file.split(os.path.sep)[0]
        data = parse_comments(insert_file)
        data = pd.Series(data)[
            ["Mean", "Median", "Mode", "STDev", "PercentOfPairs"]
        ]
        stats[sample] = data

    stats.T.to_csv(output[0], sep="\t")
612
613
614
615
616
617
618
619
620
621
run:
    from utils.io import pandas_concat

    pandas_concat(
        list(input.read_count_files),
        output.read_stats,
        sep="\t",
        index_col=[0, 1],
        axis=0,
    )
632
633
634
635
run:
    from utils.io import pandas_concat

    pandas_concat(list(input), output[0], sep="\t", index_col=[0, 1], axis=0)
673
674
script:
    "../report/qc_report.py"
SnakeMake From line 673 of rules/qc.smk
15
16
17
18
19
20
21
22
23
shell:
    "bbsketch.sh "
    "in={input[0]}" # take only one read
    " samplerate=0.5"
    " minkeycount=2 "
    " out={output} "
    " blacklist=nt ssu=f name0={wildcards.sample} depth=t overwrite=t "
    " -Xmx{resources.java_mem}g "
    " &> {log}"
40
41
42
43
44
45
46
shell:
    "comparesketch.sh alltoall "
    " format=3 out={output} "
    " records=5000 "
    " {input} "
    " -Xmx{resources.java_mem}g "
    " &> {log}"
26
27
28
29
30
31
32
33
shell:
    "SemiBin generate_sequence_features_multi"
    " --input-fasta {input.fasta} "
    " --input-bam {input.bams} "
    " --output {output} "
    " --threads {threads} "
    " --separator {params.separator} "
    " 2> {log}"
65
66
67
68
69
70
71
72
shell:
    "SemiBin train_self "
    " --output {params.output_dir} "
    " --threads {threads} "
    " --data {params.data} "
    " --data-split {params.data_split} "
    " {params.extra} "
    " 2> {log}"
128
129
130
131
132
133
134
135
136
137
shell:
    "SemiBin bin "
    " --input-fasta {input.fasta} "
    " --output {params.output_dir} "
    " --threads {threads} "
    " --data {params.data} "
    " --model {input.model} "
    " --minfasta-kbs {params.min_bin_kbs}"
    " {params.extra} "
    " 2> {log}"
158
159
script:
    "../scripts/parse_semibin.py"
30
31
32
33
34
35
36
37
38
39
40
shell:
    " mkdir -p {params.outdir} 2> {log} "
    " ; "
    " prefetch "
    " --output-directory {params.outdir} "
    " -X 999999999 "
    " --progress "
    " --log-level info "
    " {wildcards.sra_run} &>> {log} "
    " ; "
    " vdb-validate {params.outdir}/{wildcards.sra_run}/{wildcards.sra_run}.sra &>> {log} "
SnakeMake From line 30 of rules/sra.smk
66
67
68
69
70
71
72
73
74
75
76
77
shell:
    " vdb-validate {params.sra_file} &>> {log} "
    " ; "
    " parallel-fastq-dump "
    " --threads {threads} "
    " --gzip --split-files "
    " --outdir {params.outdir} "
    " --tmpdir {resources.tmpdir} "
    " --skip-technical --split-3 "
    " -s {params.sra_file} &>> {log} "
    " ; "
    " rm -f {params.sra_file} 2>> {log} "
SnakeMake From line 66 of rules/sra.smk
123
124
125
126
127
128
129
run:
    from utils import io

    for i, fraction in enumerate(SRA_read_fractions):
        if fraction == "":
            fraction = "se"
        io.cat_files(input[fraction], output[i])
56
57
58
59
60
61
62
63
shell:
    "inStrain compare "
    " --input {input.profiles} "
    " -o {output} "
    " -p {threads} "
    " -s {input.scaffold_to_genome} "
    " --database_mode "
    " {params.extra} &> {log}"
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.DEBUG,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logging.captureWarnings(True)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Exception hook that writes uncaught exceptions to the log.

    KeyboardInterrupt is passed on to the interpreter's default hook; every
    other exception is logged at ERROR level with its formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        formatted = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(formatted))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts


import pandas as pd

import numpy as np
from utils import genome_dist as gd
import networkx as nx


def get_float(value):
    """Return `value` as a fraction in [0, 1].

    Values up to 1 are returned unchanged. Values in (1, 100] are treated as
    percentages and divided by 100. Negative values and values above 100 fail
    an assertion.
    """
    assert value >= 0

    # already a fraction: nothing to convert
    if value <= 1:
        return value

    assert value <= 100, "it should be a percentage"
    logging.debug(f"Value {value} is > 1, I divede it with 100 to get a float")

    return value / 100


# Clustering parameters from the rule's params and the workflow config.
# Thresholds may be given as fractions or percentages; get_float converts
# them to fractions.
linkage_method = snakemake.params.linkage_method
pre_cluster_threshold = get_float(snakemake.params.pre_cluster_threshold)
threshold = get_float(snakemake.params.threshold)
min_aligned_fraction = get_float(snakemake.config["genome_dereplication"]["overlap"])

# verify ranges
gd.verify_expected_range(pre_cluster_threshold, 0.8, 1, "pre_cluster_threshold")
gd.verify_expected_range(threshold, 0.8, 1, "ANI cluster threshold")
gd.verify_expected_range(min_aligned_fraction, 0.1, 0.95, "min_aligned_fraction")


# load quality
Q = pd.read_csv(snakemake.input.bin_info, sep="\t", index_col=0)
# notes are appended to with "+=" later on, so missing values become empty strings
Q.Additional_Notes = Q.Additional_Notes.fillna("").astype(str)


logging.info("Load distances")
M = gd.load_skani(snakemake.input.dist)

# genome distance to graph
# Only genome pairs passing both the ANI and the aligned-fraction cut-off
# become edges of the pre-clustering graph.
pre_clustering_criteria = (
    f"ANI >= {pre_cluster_threshold} & Align_fraction > {min_aligned_fraction}"
)

logging.info(f"Pre-cluster genomes with the if '{pre_clustering_criteria}'")
G = gd.to_graph(M.query(pre_clustering_criteria))

# Remove self-loops. Use the Graph.selfloop_edges() method where it exists,
# otherwise the module-level nx.selfloop_edges() function; previously the
# removal was silently skipped when the method was missing.
if hasattr(G, "selfloop_edges"):
    selfloops = G.selfloop_edges()
else:
    selfloops = nx.selfloop_edges(G)

# materialise the edge list first: edges must not be removed while iterating
G.remove_edges_from(list(selfloops))


# prepare table for species number
# one row per genome of the quality table; both columns start out empty
mag2Species = pd.DataFrame(index=Q.index, columns=["SpeciesNr", "Species"])
mag2Species.index.name = "genome"
# genomes removed during pre-cluster processing; dropped from the tables afterwards
genomes_to_drop = []

last_species_nr = 1  # start at 1


# each connected component of the graph is one pre-cluster
n_pre_clusters = nx.connected.number_connected_components(G)
logging.info(f"Found {n_pre_clusters} pre-clusters, itterate over them.")
logging.debug(f"Cluster with threshold {threshold} and {linkage_method}-linkage method")
# Process every pre-cluster (connected component of G):
#   1. drop genomes that do not use the most frequent translation table,
#   2. if several completeness models are used, recalibrate completeness and
#      drop genomes that no longer pass the configured quality filter,
#   3. cluster the remaining genomes into species by ANI.
for i, cc in enumerate(nx.connected_components(G)):
    logging.info(f"Precluster {i+1}/{n_pre_clusters} with {len(cc)} genomes")

    Qcc = Q.loc[list(cc)]

    # check translation table
    freq = Qcc["Translation_Table_Used"].value_counts()

    if freq.shape[0] > 1:
        logging.info(
            "Not all genomes use the same translation table,"
            "drop genomes that don't use main translation table."
        )
        logging.info(freq)

        # value_counts() sorts by frequency, so the first entry is the most common
        main_tranlation_table = freq.index[0]

        drop_genomes = Qcc.query(
            "Translation_Table_Used != @main_tranlation_table"
        ).index

        cc = cc - set(drop_genomes)
        Qcc = Qcc.loc[list(cc)]
        genomes_to_drop += list(drop_genomes)
        logging.info(f"Drop {len(drop_genomes) } genomes, keep ({len(cc)})")

    ## Check that the same completeness model is used for all

    freq = Qcc["Completeness_Model_Used"].value_counts()
    if freq.shape[0] > 1:
        logging.info(
            "Not all genomes use the same completeness model. Recalibrate completeness."
        )

        logging.info(freq)

        # genomes that don't use specific model
        non_specific = Qcc.index[
            ~Qcc.Completeness_Model_Used.str.contains("Specific Model")
        ]

        logging.debug(
            f"{len(non_specific)} genomes use general completeness model. Recalibrate completeness and quality score to use lower value"
        )

        logging.debug(
            Qcc.loc[
                non_specific,
                ["Completeness_General", "Completeness_Specific", "Contamination"],
            ]
        )

        # use the lower of the two completeness estimates for these genomes
        Qcc.loc[non_specific, "Completeness"] = Qcc.loc[
            non_specific,
            [
                "Completeness_General",
                "Completeness_Specific",
            ],
        ].min(axis=1)

        # add note
        Q.loc[
            non_specific, "Additional_Notes"
        ] += "Completeness was re-calibrated based on Completeness model used in all genomes of the species."

        # transfer to main quality
        Q.loc[list(cc), "Completeness"] = Qcc.loc[list(cc), "Completeness"]

        # drop low quality genomes

        logging.info("Drop low quality genomes acording to filtercriteria")

        try:
            filter_criteria = snakemake.config["genome_filter_criteria"]
            drop_genomes = Qcc.index.difference(Qcc.query(filter_criteria).index)

        except Exception as e:
            # best effort: if the filter cannot be applied, keep all genomes
            logging.error("Cannot filter low quality genomes")
            logging.exception(e)

            drop_genomes = []

        if len(drop_genomes) > 0:
            cc = cc - set(drop_genomes)
            logging.info(
                f"Drop {len(drop_genomes) } with too low quality genomes, keep {len(cc)}"
            )

            Qcc = Qcc.loc[list(cc)]
            genomes_to_drop += list(drop_genomes)

    if len(cc) <= 1:
        # f-string so that the number of genomes is actually interpolated
        logging.info(
            f"I am left with {len(cc)} genomes in this pre-cluster. No need to cluster."
        )
    else:
        # subset dist matrix
        Mcc = M.loc[
            (M.index.levels[0].intersection(cc), M.index.levels[1].intersection(cc)),
        ]

        # Cluster species
        labels = gd.group_species_linkage(
            Mcc.ANI, threshold=threshold, linkage_method=linkage_method
        )

        logging.debug(f"Got {labels.max()} species cluster for this pre-cluster.")

        # enter values of labels to species table
        # the offset keeps species numbers distinct across pre-clusters
        # NOTE(review): presumably labels are positive integers starting at 1 -- confirm in gd.group_species_linkage
        mag2Species.loc[labels.index, "SpeciesNr"] = labels + last_species_nr
        last_species_nr = mag2Species.SpeciesNr.max()


# remove the genomes discarded during pre-cluster processing from both tables
mag2Species.drop(genomes_to_drop, inplace=True)
Q.drop(genomes_to_drop, inplace=True)


# genomes that did not receive a species number in the loop above
missing_species = mag2Species.index[mag2Species.SpeciesNr.isnull()]
N_missing_species = len(missing_species)

logging.info(
    f"{N_missing_species} genomes were not part of a pre-cluster and are singleton-species."
)

Q.loc[missing_species, "Additional_Notes"] += " Singleton species"

# each singleton gets its own consecutive number after the last assigned one
mag2Species.loc[missing_species, "SpeciesNr"] = (
    np.arange(last_species_nr, last_species_nr + N_missing_species) + 1
)


n_species = mag2Species.SpeciesNr.unique().shape[0]
logging.info(f"Identified {n_species } species in total")

# create propper species names
# zero-pad to the width of the largest number, e.g. "sp007" when the maximum has three digits
n_leading_zeros = len(str(mag2Species.SpeciesNr.max()))
format_int = "sp{:0" + str(n_leading_zeros) + "d}"
mag2Species["Species"] = mag2Species.SpeciesNr.apply(format_int.format)


# calculate quality score


logging.info("Define Quality score defined as Completeness - 5x Contamination")
# recalulate quality score as some completeness might be recalibrated.
Q.eval("Quality_score = Completeness - 5* Contamination", inplace=True)
quality_score = Q.Quality_score

assert (
    not quality_score.isnull().any()
), "I have NA quality values for thq quality score, it seems not all of the values defined in the quality_score_formula are presentfor all entries in tables/Genome_quality.tsv "


# select representative
logging.info("Select representative")
# NOTE(review): presumably picks, per species, the genome with the highest quality score -- confirm in gd.best_genome_from_table
mag2Species["Representative"] = gd.best_genome_from_table(
    mag2Species.Species, quality_score
)

mag2Species.to_csv(snakemake.output.bins2species, sep="\t")

# mag2Species = mag2Species.join(Q)
# write the quality table with updated notes, completeness and quality score
Q.to_csv(snakemake.output.bin_info, sep="\t")
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logging.captureWarnings(True)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts

import pandas as pd
from utils.parsers import read_checkm2_output


def main(samples, completeness_files, bin_table):
    """Combine the CheckM2 results of all samples into a single table.

    Parameters
    ----------
    samples : sequence of str
        Sample names, in the same order as ``completeness_files``.
    completeness_files : sequence of path-like
        One CheckM2 result file per sample, parsed with ``read_checkm2_output``.
    bin_table : path-like
        Destination of the combined, tab-separated table.

    Raises
    ------
    ValueError
        If the number of samples and of completeness files differ.
    """
    # Samples and files are paired by position, so a length mismatch would
    # silently mislabel or drop results.
    if len(samples) != len(completeness_files):
        raise ValueError(
            f"Got {len(samples)} samples but {len(completeness_files)} completeness files"
        )

    df_list = []

    for sample, completeness_file in zip(samples, completeness_files):
        sample_data = read_checkm2_output(completness_table=completeness_file)
        # remember which sample every row comes from
        sample_data["Sample"] = sample

        df_list.append(sample_data)

    df = pd.concat(df_list, axis=0)

    df.to_csv(bin_table, sep="\t")


# Script entry point: forward the sample names, the per-sample completeness
# files and the output path provided by the `snakemake` object to main().
if __name__ == "__main__":
    main(
        samples=snakemake.params.samples,
        completeness_files=snakemake.input.completeness_files,
        bin_table=snakemake.output.bin_table,
    )
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
from utils.parsers_bbmap import parse_pileup_log_file


def parse_map_stats(sample_data, out_tsv):
    """Build the per-sample assembly statistics table and write it to ``out_tsv``.

    Parameters
    ----------
    sample_data : dict
        Maps each sample name to a dict holding the file paths
        ``"contig_stats"``, ``"gene_table"`` and ``"mapping_log"``.
    out_tsv : path-like
        Destination of the combined, tab-separated table.

    Returns
    -------
    pandas.DataFrame
        The combined table, indexed by sample name.

    Raises
    ------
    AssertionError
        If a contig stats file does not contain exactly one row.
    """
    sample_stats = {}
    for sample, files in sample_data.items():
        df = pd.read_csv(files["contig_stats"], sep="\t")

        # Report the number of rows found; indexing the first row here would
        # itself fail with an IndexError for an empty table.
        assert df.shape[0] == 1, "Assumed only one row in file {}; found {}".format(
            files["contig_stats"], df.shape[0]
        )

        # n genes
        genes_df = pd.read_csv(files["gene_table"], index_col=0, sep="\t")
        df["N_Predicted_Genes"] = genes_df.shape[0]

        # mapping stats
        mapping_stats = parse_pileup_log_file(files["mapping_log"])
        df["Assembled_Reads"] = mapping_stats["Mapped reads"]
        df["Percent_Assembled_Reads"] = mapping_stats["Percent mapped"]

        logging.info(f"Stats for sample {sample}\n{df}")

        sample_stats[sample] = df

    # concat on a dict yields a (sample, row) MultiIndex; keep only the sample
    stats_df = pd.concat(sample_stats, axis=0)
    stats_df.index = stats_df.index.get_level_values(0)
    # drop the scaffold columns ("scaf_*") and strip the "ctg_" prefix from
    # the remaining contig columns
    stats_df = stats_df.loc[:, ~stats_df.columns.str.startswith("scaf_")]
    stats_df.columns = stats_df.columns.str.replace("ctg_", "")
    # save
    stats_df.to_csv(out_tsv, sep="\t")
    return stats_df


def main(samples, contig_stats, gene_tables, mapping_logs, combined_stats):
    """Assign the input files to their samples and write the combined stats.

    A file belongs to a sample when one of its directory components equals
    the sample name. If several files of one kind match a sample, the last
    one is used.

    Parameters
    ----------
    samples : iterable of str
        Sample names.
    contig_stats, gene_tables, mapping_logs : iterable of str
        File paths, each located below a directory named after its sample.
    combined_stats : path-like
        Output path handed to ``parse_map_stats``.
    """
    file_groups = {
        "contig_stats": contig_stats,
        "gene_table": gene_tables,
        "mapping_log": mapping_logs,
    }

    sample_data = {}
    for sample in samples:
        # Match the complete directory name: a plain "<sample>/" substring
        # test would also hit samples whose name merely ends with this one
        # (e.g. "S1" inside "XS1/").
        sample_dir = "/%s/" % sample
        sample_data[sample] = {}
        for key, paths in file_groups.items():
            for path in paths:
                # the leading "/" lets a sample directory at the very start of
                # a relative path match as well
                if sample_dir in "/" + path:
                    sample_data[sample][key] = path

    parse_map_stats(sample_data, combined_stats)


# Script entry point: forward the sample names, the three groups of input
# files and the output path provided by the `snakemake` object to main().
if __name__ == "__main__":
    main(
        samples=snakemake.params.samples,
        contig_stats=snakemake.input.contig_stats,
        gene_tables=snakemake.input.gene_tables,
        mapping_logs=snakemake.input.mapping_logs,
        combined_stats=snakemake.output.combined_contig_stats,
    )
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
import os, gc
from utils.parsers_bbmap import read_coverage_binned, combine_coverages


# Header-less, tab-separated table: the first column becomes the index and
# the next column the values, giving a Series that maps contig -> genome.
contig2genome = pd.read_csv(
    snakemake.input.contig2genome, header=None, index_col=0, sep="\t"
).iloc[:, 0]


# sum counts
logging.info("Loading counts and coverage per contig")

# NOTE(review): layout of both returned tables is defined by
# utils.parsers_bbmap.combine_coverages; the code below presumably relies on
# contigs being the columns of both — confirm against the helper.
combined_cov, Counts_contigs = combine_coverages(
    snakemake.input.coverage_files, snakemake.params.samples
)

combined_cov = combined_cov.T

# Put the genome of every row as a categorical first column.
combined_cov.insert(
    0, "Genome", value=pd.Categorical(contig2genome.loc[combined_cov.index].values)
)

logging.info(f"Saving coverage to {snakemake.output.coverage_contigs}")

combined_cov.reset_index().to_parquet(snakemake.output.coverage_contigs)

logging.info("Sum counts per genome")

# Columns are grouped by the genome they map to and summed, then the table is
# transposed.
# NOTE(review): `axis=1` in DataFrame.groupby is deprecated in recent pandas
# releases — confirm the pinned pandas version still accepts it.
Counts_genome = Counts_contigs.groupby(contig2genome, axis=1).sum().T
Counts_genome.index.name = "Sample"

logging.info(f"Saving counts to {snakemake.output.counts}")

Counts_genome.reset_index().to_parquet(snakemake.output.counts)
# Release the per-contig tables before the binned coverage is loaded.
del Counts_genome, combined_cov, Counts_contigs
gc.collect()

# Binned coverage
logging.info("Loading binned coverage")
binCov = {}
# Files and sample names are paired by position.
for i, cov_file in enumerate(snakemake.input.binned_coverage_files):
    sample = snakemake.params.samples[i]

    binCov[sample] = read_coverage_binned(cov_file)

# One column per sample.
binCov = pd.DataFrame.from_dict(binCov)

logging.info("Add genome information to it")
# The first index level is looked up in contig2genome, i.e. it is expected to
# hold the contig name.
binCov.insert(
    0,
    "Genome",
    value=pd.Categorical(contig2genome.loc[binCov.index.get_level_values(0)].values),
)

gc.collect()
logging.info(f"Saving combined binCov to {snakemake.output.binned_cov}")
binCov.reset_index().to_parquet(snakemake.output.binned_cov)

# Median coverage
logging.info("Calculate median coverage")
# Median over all rows of a genome, computed per sample column, then transposed.
Median_abund = binCov.groupby("Genome").median().T
del binCov
gc.collect()
logging.info(f"Saving mediuan coverage {snakemake.output.median_abund}")
Median_abund.reset_index().to_parquet(snakemake.output.median_abund)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict

# Annotation columns belonging to each database. One parquet table is written
# per database for which at least one input file provides columns.
db_columns = {
    "kegg": ["ko_id", "kegg_hit"],
    "peptidase": [
        "peptidase_id",
        "peptidase_family",
        "peptidase_hit",
        "peptidase_RBH",
        "peptidase_identity",
        "peptidase_bitScore",
        "peptidase_eVal",
    ],
    "pfam": ["pfam_hits"],
    "cazy": ["cazy_ids", "cazy_hits", "cazy_subfam_ec", "cazy_best_hit"],
    # "heme": ["heme_regulatory_motif_count"],
}

# database name -> list of per-file tables
Tables = defaultdict(list)

for file in snakemake.input:
    df = pd.read_csv(file, index_col=0, sep="\t")

    # drop un-annotated genes
    df = df.query("rank!='E'")

    # change index from 'subset1_Gene111' ->  simply 'Gene111'
    # Gene name to nr
    df.index = (
        df.index.str.split("_", n=1, expand=True)
        .get_level_values(1)
        .str[len("Gene") :]
        .astype(np.int64)
    )
    df.index.name = "GeneNr"

    # select columns, drop na rows and append to list
    for db, cols in db_columns.items():
        # A file may provide only some of the columns of a database; selecting
        # the full list would then raise a KeyError, so keep only the columns
        # that are present (in their configured order).
        available_cols = [col for col in cols if col in df.columns]

        if available_cols:
            Tables[db].append(df[available_cols].dropna(axis=0, how="all"))

    del df

out_dir = Path(snakemake.output[0])
out_dir.mkdir()

for db in Tables:
    combined = pd.concat(Tables[db], axis=0)

    combined.sort_index(inplace=True)

    combined.reset_index().to_parquet(out_dir / (db + ".parquet"))
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of script
import numpy as np
import pandas as pd
import gc, os


import h5py

import h5py

import psutil


def measure_memory(write_log_entry=True):
    """Return the resident set size of the current process in MB.

    When ``write_log_entry`` is true the value is also written to the log
    at INFO level.
    """
    rss_bytes = psutil.Process().memory_info().rss
    rss_mb = rss_bytes / (1024 * 1024)

    if not write_log_entry:
        return rss_mb

    logging.info(f"The process is currently using {rss_mb: 7.0f} MB of RAM")
    return rss_mb


logging.info("Start")
measure_memory()

N_samples = len(snakemake.input.covstats)

logging.info("Read gene info")

gene_info = pd.read_table(snakemake.input.info)

# Gene name is only first part of first column
gene_info.index = gene_info["#Name"].str.split(" ", n=1, expand=True)[0]
gene_info.index.name = "GeneName"
gene_info.drop("#Name", axis=1, inplace=True)

# Sort the genes by name (once) and remember how many there are.
gene_info.sort_index(inplace=True)
N_genes = gene_info.shape[0]

# Per-gene summary columns, initialised to zero.
gene_info[
    ["Samples_nz_coverage", "Samples_nz_counts", "Sum_coverage", "Max_coverage"]
] = 0


# gene_list= gene_info.index


logging.info("Open hdf files for writing")

# One row per sample, one column per gene.
gene_matrix_shape = (N_samples, N_genes)

# Both HDF5 files hold a single gzip-compressed dataset called "data" that is
# filled row by row while the per-sample files are read.
with h5py.File(snakemake.output.cov, "w") as hdf_cov_file, h5py.File(
    snakemake.output.counts, "w"
) as hdf_counts_file:
    combined_cov = hdf_cov_file.create_dataset(
        "data", shape=gene_matrix_shape, fillvalue=0, compression="gzip"
    )
    combined_counts = hdf_counts_file.create_dataset(
        "data", shape=gene_matrix_shape, fillvalue=0, compression="gzip"
    )

    # add sample names attribute (stored as byte strings)
    sample_names = np.array(list(snakemake.params.samples)).astype("S")
    combined_cov.attrs["sample_names"] = sample_names
    combined_counts.attrs["sample_names"] = sample_names

    gc.collect()

    # sample name -> dict of per-sample summary statistics
    Summary = {}

    logging.info("Start reading files")
    # NOTE(review): the returned value is not used afterwards; the call still
    # writes the memory usage to the log.
    initial_mem_uage = measure_memory()

    # Sample names and coverage files are paired by position.
    for i, sample in enumerate(snakemake.params.samples):
        logging.info(f"Read coverage file for sample {i+1} / {N_samples}")
        sample_cov_file = snakemake.input.covstats[i]

        data = pd.read_parquet(
            sample_cov_file, columns=["GeneName", "Reads", "Median_fold"]
        ).set_index("GeneName")

        assert (
            data.shape[0] == N_genes
        ), f"I only have {data.shape[0]} /{N_genes} in the file {sample_cov_file}"

        # The rows are written positionally into the matrices below, so the
        # genes have to come in sorted order.
        # NOTE(review): only the length and the sort order are checked, not
        # that the gene names equal those of gene_info — confirm upstream
        # guarantees this.
        assert (
            data.index.is_monotonic_increasing
        ), f"data is not sorted by index in {sample_cov_file}"

        # downcast data
        # median is int
        Median_fold = pd.to_numeric(data.Median_fold, downcast="integer")
        Reads = pd.to_numeric(data.Reads, downcast="integer")

        # delete intermediate data and release mem
        del data

        # get summary statistics per sample
        # (debug messages are not emitted at the INFO level configured at the
        # top of this script)
        logging.debug("Extract Summary statistics")

        Summary[sample] = {
            "Sum_coverage": Median_fold.sum(),
            "Total_counts": Reads.sum(),
            "Genes_nz_counts": (Reads > 0).sum(),
            "Genes_nz_coverage": (Median_fold > 0).sum(),
        }

        # get gene wise stats; the Series are combined with gene_info by
        # their GeneName index
        gene_info["Samples_nz_counts"] += (Reads > 0) * 1
        gene_info["Samples_nz_coverage"] += (Median_fold > 0) * 1
        gene_info["Sum_coverage"] += Median_fold

        gene_info["Max_coverage"] = np.fmax(gene_info["Max_coverage"], Median_fold)

        # write the values of this sample into row i of both matrices
        combined_cov[i, :] = Median_fold.values
        combined_counts[i, :] = Reads.values

        del Median_fold, Reads
        gc.collect()

        # logs the memory usage after every sample; the value itself is unused
        current_mem_uage = measure_memory()


logging.info("All samples processed")
gc.collect()

# Sample summary: one row per sample, one column per statistic.
logging.info("Save sample Summary")
sample_summary = pd.DataFrame(Summary).transpose()
sample_summary.to_csv(snakemake.output.sample_info, sep="\t")


logging.info("Save gene Summary")

# Shrink every column to the smallest numeric dtype: "GC" is downcast as
# float, all other columns as integer.
for col in gene_info.columns:
    downcast_to = "float" if col == "GC" else "integer"
    gene_info[col] = pd.to_numeric(gene_info[col], downcast=downcast_to)

gene_table = gene_info.reset_index()
gene_table.to_parquet(snakemake.output.gene_info)
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts

import pandas as pd
import numpy as np
from utils.taxonomy import tax2table

from glob import glob

# Collect the GTDB-Tk summary tables found in the classify folder.
gtdb_classify_folder = snakemake.input.folder

taxonomy_files = glob(f"{gtdb_classify_folder}/gtdbtk.*.summary.tsv")

N_taxonomy_files = len(taxonomy_files)
logging.info(f"Found {N_taxonomy_files} gtdb taxonomy files.")

# Exactly one or two summary files are accepted.
if not (1 <= N_taxonomy_files <= 2):
    raise Exception(
        f"Found {N_taxonomy_files} number of taxonomy files 'gtdbtk.*.summary.tsv' in {gtdb_classify_folder} expect 1 or 2."
    )


# Stack the summary tables on top of each other.
summary_tables = [pd.read_table(file, index_col=0) for file in taxonomy_files]
DT = pd.concat(summary_tables, axis=0)

# NOTE(review): this table is written with the default comma separator while
# the taxonomy table below is tab-separated — confirm this is intended.
DT.to_csv(snakemake.output.combined)

# Split the classification strings into a taxonomy table.
Tax = tax2table(DT.classification, remove_prefix=True)
Tax.to_csv(snakemake.output.taxonomy, sep="\t")
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import sys
import re


def main(jgi_file):
    """Print a reduced version of a tab-separated depth table to stdout.

    The first line of ``jgi_file`` is the header. Kept are the columns
    ``contigName``, ``contigLen`` and ``totalAvgDepth`` followed by every
    column whose name ends with ``.bam``, in header order. Lines whose first
    field is empty are skipped.
    """
    base_columns = ["contigName", "contigLen", "totalAvgDepth"]

    with open(jgi_file) as handle:
        rows = (raw_line.rstrip().split("\t") for raw_line in handle)

        header_fields = next(rows, None)
        if header_fields is None:
            # empty file: nothing to print
            return

        # column name -> position in the input table
        position_of = {name: pos for pos, name in enumerate(header_fields)}
        wanted = base_columns + [
            name for name in header_fields if name.endswith(".bam")
        ]
        print("\t".join(wanted))

        for fields in rows:
            if fields[0] == "":
                continue
            print("\t".join(fields[position_of[name]] for name in wanted))


if __name__ == "__main__":
    if "snakemake" in globals():
        # Run from a Snakemake rule: the table goes to the output file and
        # everything written to stderr goes to the log file.
        import contextlib
        import traceback

        with open(snakemake.log[0], "w") as log, open(
            snakemake.output[0], "w"
        ) as outf:
            # The redirect context managers restore the original streams on
            # exit, so sys.stdout / sys.stderr never point to closed files.
            with contextlib.redirect_stderr(log), contextlib.redirect_stdout(outf):
                try:
                    main(snakemake.input[0])
                except Exception:
                    # Record the traceback in the log while it is still open,
                    # then let the error propagate.
                    traceback.print_exc()
                    raise

    else:
        # Run from the command line: the table is printed to STDOUT.
        import argparse
        import logging

        logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.DEBUG)

        class CustomFormatter(
            argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter
        ):
            pass

        desc = (
            "Converting jgi_summarize_bam_contig_depths output to format used by VAMB"
        )
        epi = """DESCRIPTION:
        Output format: contigName<tab>contigLen<tab>totalAvgDepth<tab>SAMPLE1.sort.bam<tab>Sample2.sort.bam<tab>...
        Output written to STDOUT
        """
        parser = argparse.ArgumentParser(
            description=desc, epilog=epi, formatter_class=CustomFormatter
        )
        parser.add_argument(
            "jgi_file",
            metavar="jgi_file",
            type=str,
            help="jgi_summarize_bam_contig_depths output table",
        )
        parser.add_argument("--version", action="version", version="0.0.1")

        args = parser.parse_args()
        main(args.jgi_file)
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd

annotation_file = snakemake.input.annotations
module_output_table = snakemake.output[0]

from mag_annotator.database_handler import DatabaseHandler
from mag_annotator.summarize_genomes import build_module_net, make_module_coverage_frame

# Annotation table, indexed by its first column.
annotations = pd.read_csv(annotation_file, sep="\t", index_col=0)


# get db_locs and read in dbs
database_handler = DatabaseHandler(logger=logging, config_loc=snakemake.input.config)

# The module step form is required to build the module networks.
dram_sheets = database_handler.config["dram_sheets"]
if "module_step_form" not in dram_sheets:
    raise ValueError(
        "Module step form location must be set in order to summarize genomes"
    )

module_steps_form = pd.read_csv(dram_sheets["module_step_form"], sep="\t")

# One network per module.
all_module_nets = {}
for module, module_df in module_steps_form.groupby("module"):
    all_module_nets[module] = build_module_net(module_df)

# Module coverage, grouped by the "fasta" column of the annotations.
module_coverage_frame = make_module_coverage_frame(
    annotations, all_module_nets, groupby_column="fasta"
)

module_coverage_frame.to_csv(module_output_table, sep="\t")
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


import pyfastx


# Split the predicted genes into kept and too-short ones.
# Protein (faa) and nucleotide (fna) records are read in lockstep.
# NOTE(review): this assumes both files list the same genes in the same
# order; the record names are not compared — confirm upstream guarantees it.
# NOTE(review): the loop unpacks (name, seq, comment) triples — confirm the
# pinned pyfastx version yields the comment for FASTA input.
faa_iterator = pyfastx.Fastx(snakemake.input.faa, format="fasta")
fna_iterator = pyfastx.Fastx(snakemake.input.fna, format="fasta")


with open(snakemake.output.faa, "w") as out_faa, open(
    snakemake.output.fna, "w"
) as out_fna, open(snakemake.output.short, "w") as out_short:
    for name, seq, comment in fna_iterator:
        # advance the protein iterator by one record for every gene
        protein = next(faa_iterator)

        # include gene and corresponding protein if gene passes length threshold
        # or annotation contains prodigal info that it's complete
        if (len(seq) >= snakemake.params.minlength_nt) or ("partial=00" in comment):
            out_fna.write(f">{name} {comment}\n{seq}\n")
            # protein is written as ">name comment" followed by the sequence
            out_faa.write(">{0} {2}\n{1}\n".format(*protein))

        else:
            # for rejected genes only the protein record is kept
            out_short.write(">{0} {2}\n{1}\n".format(*protein))
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
from glob import glob
from numpy import log

from utils.parsers import load_quality


Q = load_quality(snakemake.input.quality)

stats = pd.read_csv(snakemake.input.stats, index_col=0, sep="\t")
stats["logN50"] = log(stats.N50)

# merge table but only shared Bins and non overlapping columns
Q = Q.join(stats.loc[Q.index, stats.columns.difference(Q.columns)])
del stats

n_all_bins = Q.shape[0]

filter_criteria = snakemake.params["filter_criteria"]
logging.info(f"Filter genomes according to criteria:\n {filter_criteria}")


Q = Q.query(filter_criteria)

logging.info(f"Retain {Q.shape[0]} genomes from {n_all_bins}")


## GUNC

# The GUNC filter is applied only when the rule provides a `gunc` input.
if hasattr(snakemake.input, "gunc"):
    gunc = pd.read_table(snakemake.input.gunc, index_col=0)
    gunc = gunc.loc[Q.index]

    bad_genomes = gunc.index[gunc["pass.GUNC"] == False]
    logging.info(f"{len(bad_genomes)} Don't pass gunc filtering")

    # Q is the result of a query; build a new table instead of modifying
    # that result in place.
    Q = Q.drop(bad_genomes)
else:
    logging.info(" Don't filter based on gunc")


if Q.shape[0] == 0:
    logging.error(
        f"No bins passed filtering criteria! Bad luck!. You might want to tweek the filtering criteria. Also check the {snakemake.input.quality}"
    )
    # sys.exit is always available, unlike the `exit` helper of the site module
    sys.exit(1)

# output Q together with quality
Q.to_csv(snakemake.output.info, sep="\t")


# filter path genomes for skani

# Keep the table two-dimensional: squeezing would turn a single-column or
# single-row table into a Series (or scalar) and break the selection below.
F = pd.read_table(snakemake.input.paths, index_col=0)

# paths of the retained genomes, taken from the first column
F = F.loc[Q.index].iloc[:, 0]
F.to_csv(snakemake.output.paths, index=False, header=False)
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of script

import pandas as pd
from utils import gene_scripts

# if MAGs are renamed I need to obtain the old contig names
# otherwise not
# if MAGs are renamed I need to obtain the old contig names
# otherwise not
if snakemake.params.renamed_contigs:
    contigs2bins = pd.read_csv(
        snakemake.input.contigs2bins, index_col=0, sep="\t", header=None
    )

    contigs2bins.columns = ["Bin"]
    old2newID = pd.read_csv(snakemake.input.old2newID, index_col=0, sep="\t").squeeze()

    # translate the bin of every contig to its new genome name; contigs of
    # bins without a new name are dropped
    contigs2genome = contigs2bins.join(old2newID, on="Bin").dropna().drop("Bin", axis=1)
else:
    # `squeeze=False` was the default and the keyword no longer exists in
    # pandas >= 2.0, so it is simply omitted.
    contigs2genome = pd.read_csv(
        snakemake.input.contigs2mags, index_col=0, sep="\t", header=None
    )
    contigs2genome.columns = ["MAG"]

# load orf_info
orf_info = pd.read_parquet(snakemake.input.orf_info)


# recreate Contig name `Sample_ContigNr` and Gene names `Gene0004`
orf_info["Contig"] = orf_info.Sample + "_" + orf_info.ContigNr.astype(str)
orf_info["Gene"] = gene_scripts.geneNr_to_string(orf_info.GeneNr)

# Join genomes on contig
orf_info = orf_info.join(contigs2genome, on="Contig")

# remove genes not on genomes
orf_info = orf_info.dropna(axis=0)


# count genes per genome in a matrix
gene2genome = pd.to_numeric(
    orf_info.groupby(["Gene", "MAG"]).size(), downcast="unsigned"
).unstack(fill_value=0)

# save as parquet
gene2genome.reset_index().to_parquet(snakemake.output[0])
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


## Start

import pandas as pd
import numpy as np

from utils import gene_scripts

# Input columns: ClusterID, GeneID and an empty third column; only the first
# two are read.
orf2gene = pd.read_csv(
    snakemake.input.cluster_attribution, header=None, sep="\t", usecols=[0, 1]
)

orf2gene.columns = ["Representative", "ORF"]

# split orf names in sample, contig_nr, and orf_nr
orf_info = gene_scripts.split_orf_to_index(orf2gene["ORF"])

# Number the representatives 1..N in order of first appearance.
representative_names = orf2gene["Representative"].unique()
n_representatives = len(representative_names)
gene_numbers = np.arange(1, n_representatives + 1, dtype=np.uint)

map_names = pd.Series(gene_numbers, index=representative_names)


# Every ORF gets the number of its representative.
orf_info["GeneNr"] = orf2gene["Representative"].map(map_names)


orf_info.to_parquet(snakemake.output.cluster_attribution)


# Save name of representatives together with their number
map_names.index.name = "Representative"
map_names.name = "GeneNr"
map_names.to_csv(snakemake.output.rep2genenr, sep="\t")
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.DEBUG,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Excepthook replacement that records uncaught exceptions in the log.

    Keyboard interrupts are handed to Python's default hook; any other
    exception is written to the log as a single ERROR record that contains
    the fully formatted traceback.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        trace_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logging.error("Uncaught exception: " + "".join(trace_lines))
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install exception handler
sys.excepthook = handle_exception


# start of script
import argparse
import os, sys
import shutil
import warnings

import pandas as pd
from Bio import SeqIO


def get_fasta_of_bins(cluster_attribution, contigs_file, out_folder):
    """
    Creates individual fasta files for each bin using the contigs fasta and the cluster attribution.

    An already existing `out_folder` is deleted and created again, so it only
    contains the files written by this call.

    input:
    - cluster_attribution:  tab separated file of "contig_fasta_header    bin" (no header line)
    - contigs_file:         fasta file of contigs
    - out_folder:           folder for the bin fastas  {out_folder}/{binid}.fasta

    raises:
    - AssertionError if the attribution file does not have exactly two columns,
      if the fasta file contains no contigs, or if no bins are found.
    """
    # create outdir, removing any previous content
    if os.path.exists(out_folder):
        shutil.rmtree(out_folder)
    os.makedirs(out_folder)

    # everything is read as string so that numeric looking names are kept as is
    CA = pd.read_csv(cluster_attribution, header=None, sep="\t", dtype=str)

    # f-string instead of `+`, so the message also works for non-str paths
    assert (
        CA.shape[1] == 2
    ), f"File should have only two columns {cluster_attribution}"

    CA.columns = ["Contig", "Bin"]

    # NOTE: uniqueness of the contig names is not checked here.

    logging.info(f"index fasta file {contigs_file} for fast access")
    contig_fasta_dict = SeqIO.index(str(contigs_file), "fasta")

    try:
        assert len(contig_fasta_dict) > 0, "No contigs in your fasta"

        unique_bins = CA.Bin.unique()

        assert len(unique_bins) >= 1, "No bins found"

        for binid in unique_bins:
            bin_contig_names = CA.loc[CA.Bin == binid, "Contig"].tolist()
            out_file = os.path.join(out_folder, f"{binid}.fasta")

            assert (
                len(bin_contig_names) >= 1
            ), f"No contigs found for bin {binid} in {cluster_attribution}"

            if len(bin_contig_names) == 1:
                warnings.warn(f"single contig bin Bin : {binid} {bin_contig_names}")

            logging.debug(f"Found {len(bin_contig_names)} contigs {bin_contig_names}")

            # look up all records first, so a missing contig fails before
            # the output file is created
            fasta_contigs = [contig_fasta_dict[c] for c in bin_contig_names]
            SeqIO.write(fasta_contigs, out_file, "fasta")
    finally:
        # the index keeps a handle on the fasta file; release it even on error
        contig_fasta_dict.close()


if __name__ == "__main__":
    if "snakemake" not in globals():
        # Standalone use: take the three paths from the command line.
        p = argparse.ArgumentParser()
        p.add_argument("--cluster-attribution", required=True)
        # `dest` has to match the parameter name of get_fasta_of_bins,
        # because the parsed arguments are passed on as keyword arguments.
        p.add_argument("--contigs", dest="contigs_file", required=True)
        p.add_argument("--out-folder", required=True)
        args = vars(p.parse_args())
        get_fasta_of_bins(**args)
    else:
        # Run from a Snakemake rule: paths come from the injected `snakemake` object.
        get_fasta_of_bins(
            snakemake.input.cluster_attribution,
            snakemake.input.contigs,
            snakemake.output[0],
        )