Metagenome-Assembled Genome Analysis of Gut Microbiota

Version: v2.16.2

MAGs_IBD

MAGs_IBD is an easy-to-use metagenomic pipeline based on Snakemake. It handles all steps from QC and assembly, through binning, to annotation.

You can start using atlas with three commands:

 mamba install -y -c bioconda -c conda-forge metagenome-atlas={latest_version}
 atlas init --db-dir databases path/to/fastq/files
 atlas run all

where {latest_version} should be replaced by the latest release of metagenome-atlas (v2.16.2 at the time of writing).
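For example, with the version shown above, the quick start becomes the following (path/to/fastq/files stands for the directory containing your raw reads, and databases is the folder where atlas downloads its reference databases):

 mamba install -y -c bioconda -c conda-forge metagenome-atlas=2.16.2
 atlas init --db-dir databases path/to/fastq/files
 atlas run all

If you only want to run part of the workflow, atlas also accepts individual targets, e.g. atlas run qc or atlas run genomes.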

Development/Extensions

Here are some ideas I am working on, or want to work on, when I have time. If you want to contribute or have ideas of your own, let me know via a feature request issue.

  • Optimized MAG recovery (e.g. Spacegraphcats)

  • Integration of viruses/plasmids, which for now live as extensions

  • Add statistics and visualisations as in atlas_analyze

  • Implementation of most rules as Snakemake wrappers

  • Cloud execution

  • Update to new Snakemake version and use cool reports.

Code Snippets

import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logging.captureWarnings(True)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts

from common_report import *

import os, sys
import pandas as pd
import plotly.express as px


labels = {
    "Percent_Assembled_Reads": "Percent of Assembled Reads",
    "contig_bp": "Total BP",
    "n_contigs": "Contigs (count)",
    "N_Predicted_Genes": "Predicted Genes (count)",
    "N50": "N50-number",
    "L50": "N50-length (bp)",
    "N90": "N90-number",
    "L90": "N90-length (bp)",
}


PLOT_PARAMS = dict(labels=labels)


def make_plots(combined_stats):
    ## Make figures with PLOTLY
    # load and rename data
    df = pd.read_csv(combined_stats, sep="\t", index_col=0)
    df.sort_index(ascending=True, inplace=True)
    df.index.name = "Sample"
    df["Sample"] = df.index

    # create plots store in div
    div = {}

    fig = px.strip(df, y="Percent_Assembled_Reads", hover_name="Sample", **PLOT_PARAMS)
    fig.update_yaxes(range=[0, 100])
    div["Percent_Assembled_Reads"] = fig.to_html(**HTML_PARAMS)

    fig = px.strip(df, y="N_Predicted_Genes", hover_name="Sample", **PLOT_PARAMS)
    div["N_Predicted_Genes"] = fig.to_html(**HTML_PARAMS)

    fig = px.scatter(df, y="L50", x="N50", hover_name="Sample", **PLOT_PARAMS)
    div["N50"] = fig.to_html(**HTML_PARAMS)

    fig = px.scatter(df, y="L90", x="N90", hover_name="Sample", **PLOT_PARAMS)
    div["N90"] = fig.to_html(**HTML_PARAMS)

    fig = px.scatter(
        df, y="contig_bp", x="n_contigs", hover_name="Sample", **PLOT_PARAMS
    )
    div["Total"] = fig.to_html(**HTML_PARAMS)

    return div


# main


div = make_plots(combined_stats=snakemake.input.combined_contig_stats)


make_html(
    div=div,
    report_out=snakemake.output.report,
    html_template_file=os.path.join(reports_dir, "template_assembly_report.html"),
)
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logging.captureWarnings(True)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts


from common_report import *

import pandas as pd
import plotly.express as px


from utils.taxonomy import tax2table


def make_plots(bin_table):
    div = {}

    div["input_file"] = bin_table

    # Prepare data
    df = pd.read_table(bin_table)

    if snakemake.config["bin_quality_asesser"].lower() == "busco":
        df["Bin Id"] = df["Input_file"].str.replace(".fasta", "", regex=False)

        logging.info("No taxonomic information available, use busco Dataset")

        lineage_name = "Dataset"
        hover_data = [
            "Scores_archaea_odb10",
            "Scores_bacteria_odb10",
            "Scores_eukaryota_odb10",
        ]
        size_name = None

    elif snakemake.config["bin_quality_asesser"].lower() == "checkm":
        df = df.join(
            tax2table(df["Taxonomy (contained)"], remove_prefix=True).fillna("NA")
        )

        lineage_name = "phylum"
        size_name = "Genome size (Mbp)"
        hover_data = ["genus"]

    elif snakemake.config["bin_quality_asesser"].lower() == "checkm2":
        df["Bin Id"] = df.index

        lineage_name = "Translation_Table_Used"
        hover_data = [
            "Completeness_Model_Used",
            "Coding_Density",
            "Contig_N50",
            "GC_Content",
            "Additional_Notes",
        ]
        size_name = "Genome_Size"
    else:
        raise Exception(
            f"bin_quality_asesser '{snakemake.config['bin_quality_asesser']}' in the config file is not understood"
        )

    df.index = df["Bin Id"]

    div[
        "QualityScore"
    ] = "<p>Quality score is calculated as: Completeness - 5 x Contamination.</p>"

    # 2D plot
    fig = px.scatter(
        data_frame=df,
        y="Completeness",
        x="Contamination",
        color=lineage_name,
        size=size_name,
        hover_data=hover_data,
        hover_name="Bin Id",
    )
    fig.update_yaxes(range=(50, 102))
    fig.update_xaxes(range=(-0.2, 10.1))
    div["2D"] = fig.to_html(**HTML_PARAMS)

    ## By sample
    fig = px.strip(
        data_frame=df,
        y="Quality_score",
        x="Sample",
        color=lineage_name,
        hover_data=hover_data,
        hover_name="Bin Id",
    )
    fig.update_yaxes(range=(50, 102))
    div["bySample"] = fig.to_html(**HTML_PARAMS)

    # By Phylum
    fig = px.strip(
        data_frame=df,
        y="Quality_score",
        x=lineage_name,
        hover_data=hover_data,
        hover_name="Bin Id",
    )
    fig.update_yaxes(range=(50, 102))
    div["byPhylum"] = fig.to_html(**HTML_PARAMS)

    return div


# main


div = make_plots(bin_table=snakemake.input.bin_table)


make_html(
    div=div,
    report_out=snakemake.output.report,
    html_template_file=os.path.join(reports_dir, "template_bin_report.html"),
    wildcards=snakemake.wildcards,
)
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts

from common_report import *


import pandas as pd
import plotly.express as px
from plotly import subplots
import plotly.graph_objs as go
import numpy as np


labels = {"Total_Reads": "Total Reads", "Total_Bases": "Total Bases"}


PLOT_PARAMS = dict(labels=labels)


import zipfile


def get_stats_from_zips(zips, samples):
    # def get_read_stats(samples, step):
    quality_pe = pd.DataFrame()
    quality_se = pd.DataFrame()
    for zfile, sample in zip(zips, samples):
        zf = zipfile.ZipFile(zfile)

        # single end only
        if "boxplot_quality.txt" in zf.namelist():
            with zf.open("boxplot_quality.txt") as f:
                df = pd.read_csv(f, index_col=0, sep="\t")
                quality_se[sample] = df.mean_1
        else:
            if "se/boxplot_quality.txt" in zf.namelist():
                with zf.open("se/boxplot_quality.txt") as f:
                    df = pd.read_csv(f, index_col=0, sep="\t")
                    quality_se[sample] = df.mean_1

            if "pe/boxplot_quality.txt" in zf.namelist():
                with zf.open("pe/boxplot_quality.txt") as f:
                    df = pd.read_csv(f, index_col=0, sep="\t")
                    df.columns = [df.columns, [sample] * df.shape[1]]

                    quality_pe = pd.concat(
                        (quality_pe, df[["mean_1", "mean_2"]]), axis=1
                    )

    return quality_pe, quality_se


def get_pe_read_quality_plot(df, quality_range, color_range):
    fig = subplots.make_subplots(cols=2)

    for i, sample in enumerate(df["mean_1"].columns):
        fig.append_trace(
            go.Scatter(
                x=df.index,
                y=df["mean_1"][sample].values,
                type="scatter",
                name=sample,
                legendgroup=sample,
                marker=dict(color=color_range[i]),
            ),
            1,
            1,
        )

        fig.append_trace(
            dict(
                x=df.index,
                y=df["mean_2"][sample].values,
                type="scatter",
                name=sample,
                legendgroup=sample,
                showlegend=False,
                marker=dict(color=color_range[i]),
            ),
            1,
            2,
        )

    fig.update_layout(
        yaxis=dict(range=quality_range, autorange=True, title="Average quality score"),
        xaxis1=dict(title="Position forward read"),
        xaxis2=dict(autorange="reversed", title="Position reverse read"),
    )

    return fig


def draw_se_read_quality(df, quality_range, color_range):
    fig = subplots.make_subplots(cols=1)

    for i, sample in enumerate(df.columns):
        fig.append_trace(
            go.Scatter(
                x=df.index,
                y=df[sample].values,
                type="scatter",
                name=sample,
                legendgroup=sample,
                marker=dict(color=color_range[i]),
            ),
            1,
            1,
        )

    fig.update_layout(
        yaxis=dict(range=quality_range, autorange=True, title="Average quality score"),
        xaxis=dict(title="Position read"),
    )
    return fig


def make_plots(
    samples, zipfiles_QC, read_counts, read_length, min_quality, insert_size_stats
):
    div = {}

    ## Quality along read

    N = len(samples)
    color_range = [
        "hsl(" + str(h) + ",50%" + ",50%)" for h in np.linspace(0, 360, N + 1)
    ]

    # load quality profiles for QC and low
    Quality_QC_pe, Quality_QC_se = get_stats_from_zips(zipfiles_QC, samples)
    # Quality_raw_pe, Quality_raw_se = get_stats_from_zips(zipfiles_QC,samples)

    # determine the range of quality values and whether reads are paired
    max_quality = 1 + np.nanmax((Quality_QC_pe.max().max(), Quality_QC_se.max().max()))
    quality_range = [min_quality, max_quality]

    paired = Quality_QC_pe.shape[0] > 0

    # create plots if paired or not

    if paired:
        div["quality_QC"] = get_pe_read_quality_plot(
            Quality_QC_pe, quality_range, color_range
        ).to_html(**HTML_PARAMS)

    #     div["quality_raw"] = get_pe_read_quality_plot(
    #         Quality_raw_pe, quality_range, color_range
    #     ).to_html(**HTML_PARAMS)

    else:
        div["quality_QC"] = draw_se_read_quality(
            Quality_QC_se, quality_range, color_range
        ).to_html(**HTML_PARAMS)

    #     div["quality_raw"] = draw_se_read_quality(
    #         Quality_raw_se, quality_range, color_range
    #     ).to_html(**HTML_PARAMS)

    # Total reads plot

    df = pd.read_csv(read_counts, index_col=[0, 1], sep="\t")

    try:
        df.drop("clean", axis=0, level=1, inplace=True)
    except KeyError:
        pass

    data_qc = df.query('Step=="QC"')

    for var in ["Total_Reads", "Total_Bases"]:
        fig = px.strip(data_qc, y=var, **PLOT_PARAMS)
        fig.update_yaxes(range=(0, data_qc[var].max() * 1.1))
        div[var] = fig.to_html(**HTML_PARAMS)

    ## reads plot across different steps

    total_reads = df.Total_Reads.unstack()
    fig = px.bar(data_frame=total_reads, barmode="group", labels={"value": "Reads"})

    fig.update_yaxes(title="Number of reads")
    fig.update_xaxes(tickangle=45)
    # fig.update_layout(hovermode="x unified")

    div["Reads"] = fig.to_html(**HTML_PARAMS)

    ## Read length plot

    data_length = pd.read_table(read_length, index_col=0).T
    data_length.index.name = "Sample"

    fig = px.bar(
        data_frame=data_length,
        x="Median",
        error_x="Max",
        error_x_minus="Min",
        hover_data=["Median", "Max", "Min", "Avg", "Std_Dev", "Mode"],
    )

    fig.update_xaxes(title="Read length")

    div["Length"] = fig.to_html(**HTML_PARAMS)

    ### Insert insert_size_stats
    if insert_size_stats is None:
        div[
            "Insert"
        ] = "<p>Insert size information is not available for single end reads.</p>"
    else:
        data_insert = pd.read_table(insert_size_stats, index_col=0)
        data_insert.index.name = "Sample"

        fig = px.bar(
            data_frame=data_insert,
            x="Mean",
            error_x="STDev",
            hover_data=["Mean", "Median", "Mode", "PercentOfPairs"],
            labels={"PercentOfPairs": "Percent of pairs"},
        )

        fig.update_xaxes(title="Insert size")

        div["Insert"] = fig.to_html(**HTML_PARAMS)

    return div


# If paired we have information about insert size
if type(snakemake.input.read_length_stats) == str:
    read_length_path = snakemake.input.read_length_stats
    insert_size_stats = None
else:
    read_length_path, insert_size_stats = snakemake.input.read_length_stats

div = make_plots(
    samples=snakemake.params.samples,
    zipfiles_QC=snakemake.input.zipfiles_QC,
    read_counts=snakemake.input.read_counts,
    read_length=read_length_path,
    min_quality=snakemake.params.min_quality,
    insert_size_stats=insert_size_stats,
)

make_html(
    div=div,
    report_out=snakemake.output.report,
    html_template_file=os.path.join(reports_dir, "template_QC_report.html"),
)
shell:
    """
    reformat.sh {params.inputs} \
        interleaved={params.interleaved} \
        {params.outputs} \
        iupacToN=t \
        touppercase=t \
        qout=33 \
        overwrite=true \
        verifypaired={params.verifypaired} \
        addslash=t \
        trimreaddescription=t \
        threads={threads} \
        pigz=t unpigz=t \
        -Xmx{resources.java_mem}G 2> {log}
    """
run:
    # make symlink
    assert len(input) == len(
        output
    ), "Input and ouput files have not same number, can not create symlinks for all."
    for i in range(len(input)):
        os.symlink(os.path.abspath(input[i]), output[i])
shell:
    " bbnorm.sh {params.inputs} "
    " {params.outputs} "
    " {params.tmpdir} "
    " tossbadreads=t "
    " hist={output.histin} "
    " histout={output.histout} "
    " mindepth={params.mindepth} "
    " k={params.k} "
    " target={params.target} "
    " prefilter=t "
    " threads={threads} "
    " -Xmx{resources.java_mem}G &> {log} "
shell:
    "tadpole.sh -Xmx{resources.java_mem}G "
    " prefilter={params.prefilter} "
    " prealloc=1 "
    " {params.inputs} "
    " {params.outputs} "
    " mode=correct "
    " aggressive={params.aggressive} "
    " tossjunk={params.tossjunk} "
    " lowdepthfraction={params.lowdepthfraction}"
    " tossdepth={params.tossdepth} "
    " merge=t "
    " shave={params.shave} rinse={params.shave} "
    " threads={threads} "
    " pigz=t unpigz=t "
    " ecc=t ecco=t "
    "&> {log} "
shell:
    """
    bbmerge.sh -Xmx{resources.java_mem}G threads={threads} \
        in1={input[0]} in2={input[1]} \
        outmerged={output[2]} \
        outu={output[0]} outu2={output[1]} \
        {params.flags} k={params.kmer} \
        pigz=t unpigz=t \
        extend2={params.extend2} 2> {log}
    """
shell:
    "cat {input} > {output}"
shell:
    """
    rm -r {params.outdir} 2> {log}

    megahit \
    {params.inputs} \
    --tmp-dir {TMPDIR} \
    --num-cpu-threads {threads} \
    --k-min {params.k_min} \
    --k-max {params.k_max} \
    --k-step {params.k_step} \
    --out-dir {params.outdir} \
    --out-prefix {wildcards.sample}_prefilter \
    --min-contig-len {params.min_contig_len} \
    --min-count {params.min_count} \
    --merge-level {params.merge_level} \
    --prune-level {params.prune_level} \
    --low-local-ratio {params.low_local_ratio} \
    --memory {resources.mem}000000000  \
    {params.preset} >> {log} 2>&1
    """
shell:
    "cp {input} {output}"
shell:
    "cp {input} {output}"
shell:
    "rename.sh "
    " in={input} out={output} ow=t "
    " prefix={wildcards.sample} "
    " minscaf={params.minlength} &> {log} "
shell:
    "stats.sh in={input} format=3 out={output} &> {log}"
run:
    import os
    import pandas as pd

    c = pd.DataFrame()
    for f in input:
        df = pd.read_csv(f, sep="\t")
        assembly_step = os.path.basename(f).replace("_contig_stats.txt", "")
        # the stats file is expected to contain a single row; store it under the assembly step name
        c.loc[assembly_step] = df.iloc[0]

    c.to_csv(output[0], sep="\t")
wrapper:
    "v1.19.0/bio/minimap2/aligner"
shell:
    "pileup.sh ref={input.fasta} in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " concise=t "
    " minmapq={params.minmapq} "
    " secondary={params.pileup_secondary} "
    " 2> {log}"
shell:
    """filterbycoverage.sh in={input.fasta} \
    cov={input.covstats} \
    out={output.fasta} \
    outd={output.removed_names} \
    minc={params.minc} \
    minp={params.minp} \
    minr={params.minr} \
    minl={params.minl} \
    trim={params.trim} \
    -Xmx{resources.java_mem}G 2> {log}"""
shell:
    "cp {input} {output}"
run:
    os.symlink(os.path.relpath(input[0], os.path.dirname(output[0])), output[0])
wrapper:
    "v1.19.0/bio/minimap2/aligner"
shell:
    "pileup.sh "
    " ref={input.fasta} "
    " in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " hist={output.covhist} "
    " concise=t "
    " minmapq={params.minmapq} "
    " secondary={params.pileup_secondary} "
    " bincov={output.bincov} "
    " 2> {log} "
shell:
    "samtools index {input}"
shell:
    """
    prodigal -i {input} -o {output.gff} -d {output.fna} \
        -a {output.faa} -p meta -f gff 2> {log}
    """
run:
    header = [
        "gene_id",
        "Contig",
        "Gene_nr",
        "Start",
        "Stop",
        "Strand",
        "Annotation",
    ]
    with open(output.tsv, "w") as tsv:
        tsv.write("\t".join(header) + "\n")
        with open(input.faa) as fin:
            gene_idx = 0
            for line in fin:
                if line[0] == ">":
                    text = line[1:].strip().split(" # ")
                    old_gene_name = text[0]
                    text.remove(old_gene_name)
                    old_gene_name_split = old_gene_name.split("_")
                    gene_nr = old_gene_name_split[-1]
                    contig_nr = old_gene_name_split[-2]
                    sample = "_".join(
                        old_gene_name_split[: len(old_gene_name_split) - 2]
                    )
                    tsv.write(
                        "{gene_id}\t{sample}_{contig_nr}\t{gene_nr}\t{text}\n".format(
                            text="\t".join(text),
                            gene_id=old_gene_name,
                            i=gene_idx,
                            sample=sample,
                            gene_nr=gene_nr,
                            contig_nr=contig_nr,
                        )
                    )
                    gene_idx += 1
script:
    "../scripts/combine_contig_stats.py"
script:
    "../report/assembly_report.py"
shell:
    "pileup.sh "
    " ref={input.fasta} "
    " in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " minmapq={params.minmapq} "
    " secondary={params.pileup_secondary} "
    " 2> {log} "
run:
    with open(input[0]) as fi, open(output[0], "w") as fo:
        # header
        next(fi)
        for line in fi:
            toks = line.strip().split("\t")
            print(toks[0], toks[1], sep="\t", file=fo)
run:
    from utils.parsers_bbmap import combine_coverages

    combined_cov, _ = combine_coverages(
        input.covstats, get_alls_samples_of_group(wildcards), "Avg_fold"
    )

    combined_cov.T.to_csv(output[0], sep="\t")
shell:
    """
    concoct -c {params.Nexpected_clusters} \
        --coverage_file {input.coverage} \
        --composition_file {input.fasta} \
        --basename {params.basename} \
        --read_length {params.read_length} \
        --length_threshold {params.min_length} \
        --converge_out \
        --iterations {params.niterations}
    """
run:
    with open(input[0]) as fin, open(output[0], "w") as fout:
        for line in fin:
            fout.write(line.replace(",", "\t"))
shell:
    """
    jgi_summarize_bam_contig_depths --outputDepth {output} {input.bam} \
        &> {log}
    """
shell:
    """
    metabat2 -i {input.contigs} \
        --abdFile {input.depth_file} \
        --minContig {params.min_contig_len} \
        --numThreads {threads} \
        --maxEdges {params.sensitivity} \
        --saveCls --noBinOut \
        -o {output} \
        &> {log}
    """
shell:
    """
    mkdir {output[0]} 2> {log}
    run_MaxBin.pl -contig {input.fasta} \
        -abund {input.abund} \
        -out {params.output_prefix} \
        -min_contig_length {params.mcl} \
        -thread {threads} \
        -prob_threshold {params.pt} \
        -max_iteration {params.mi} >> {log}

    mv {params.output_prefix}.summary {output[0]}/.. 2>> {log}
    mv {params.output_prefix}.marker {output[0]}/..  2>> {log}
    mv {params.output_prefix}.marker_of_each_bin.tar.gz {output[0]}/..  2>> {log}
    mv {params.output_prefix}.log {output[0]}/..  2>> {log}

    """
run:
    import pandas as pd
    import numpy as np


    d = pd.read_csv(input[0], index_col=0, squeeze=True, header=None, sep="\t")

    assert (
        type(d) == pd.Series
    ), "expect the input to be a two column file: {}".format(input[0])

    old_cluster_ids = list(d.unique())
    if 0 in old_cluster_ids:
        old_cluster_ids.remove(0)

    map_cluster_ids = dict(
        zip(
            old_cluster_ids,
            utils.gen_names_for_range(
                len(old_cluster_ids),
                prefix="{sample}_{binner}_".format(**wildcards),
            ),
        )
    )

    new_d = d.map(map_cluster_ids)
    new_d.dropna(inplace=True)
    if new_d.shape[0] == 0:
        logger.warning(
            f"No bins detected with binner {wildcards.binner} in sample {wildcards.sample}.\n"
            "I add longest contig to make the pipline continue"
        )

        new_d[f"{wildcards.sample}_0"] = "{sample}_{binner}_1".format(**wildcards)

    new_d.to_csv(output[0], sep="\t", header=False)
run:
    (bin_ids,) = glob_wildcards(params.file_name)
    print("found {} bins".format(len(bin_ids)))
    with open(output[0], "w") as out_file:
        for binid in bin_ids:
            with open(params.file_name.format(binid=binid)) as bin_file:
                for line in bin_file:
                    if line.startswith(">"):
                        fasta_header = line[1:].strip().split()[0]
                        out_file.write(f"{fasta_header}\t{binid}\n")
            os.remove(params.file_name.format(binid=binid))
script:
    "get_fasta_of_bins.py"
shell:
    "cp {input} {output}"
shell:
    " DAS_Tool --outputbasename {params.output_prefix} "
    " --bins {params.scaffolds2bin} "
    " --labels {params.binner_names} "
    " --contigs {input.contigs} "
    " --search_engine diamond "
    " --proteins {input.proteins} "
    " --write_bin_evals "
    " --megabin_penalty {params.megabin_penalty}"
    " --duplicate_penalty {params.duplicate_penalty} "
    " --threads {threads} "
    " --debug "
    " --score_threshold {params.score_threshold} &> {log} "
    " ; mv {params.output_prefix}_DASTool_contig2bin.tsv {output.cluster_attribution} &>> {log}"
run:
    from utils.genome_stats import get_many_genome_stats

    filenames = list(Path(input[0]).glob("*" + params.extension))

    get_many_genome_stats(filenames, output[0], threads)
run:
    try:
        from utils.io import pandas_concat

        pandas_concat(input, output[0])

    except Exception as e:
        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
shell:
    " checkm2 predict "
    " --threads {threads} "
    " {params.lowmem} "
    " --force "
    " --allmodels "
    " -x .fasta "
    " --tmpdir {resources.tmpdir} "
    " --input {input.fasta_dir} "
    " --output-directory {params.dir} "
    " &> {log[0]} "
    ";\n"
    " cp {params.dir}/quality_report.tsv {output.table} 2>> {log[0]} ; "
    " mv {params.dir}/protein_files {output.faa} 2>> {log[0]} ; "
shell:
    " mkdir {output.folder} 2> {log}"
    " ;\n"
    " gunc run "
    " --threads {threads} "
    " --gene_calls "
    " --db_file {input.db} "
    " --input_dir {input.fasta_dir} "
    " --temp_dir {resources.tmpdir} "
    " --file_suffix {params.extension} "
    " --out_dir {output.folder} &>> {log} "
    " ;\n "
    " cp {output.folder}/*.tsv {output.table} 2>> {log}"
    shell:
        " busco -i {input.fasta_dir} "
        " --auto-lineage-prok "
        " -m genome "
        " --out_path {params.tmpdir} "
        " -o output "
        " --download_path {input.db} "
        " -c {threads} "
        " --offline &> {log} "
        " ; "
        " mv {params.tmpdir}/output/batch_summary.txt {output} 2>> {log}"

"""
# fetch also output/logs/busco.log


## Combine bin stats
localrules:
    build_bin_report,
    combine_checkm2,
    combine_gunc,


rule combine_gunc:
    input:
        expand(
            "{sample}/binning/{{binner}}/gunc_output.tsv",
            sample=SAMPLES,
        ),
    output:
        bin_table="Binning/{binner}/gunc_report.tsv",
    params:
        samples=SAMPLES,
    log:
        "logs/binning/{binner}/combine_gunc.log",
    run:
        try:
            from utils.io import pandas_concat

            pandas_concat(input, output[0])

        except Exception as e:
            import traceback

            with open(log[0], "w") as logfile:
                traceback.print_exc(file=logfile)

            raise e


rule combine_checkm2:
    input:
        completeness_files=expand(
            "{sample}/binning/{{binner}}/checkm2_report.tsv",
            sample=SAMPLES,
        ),
    output:
        bin_table="Binning/{binner}/checkm2_quality_report.tsv",
    params:
        samples=SAMPLES,
    log:
        "logs/binning/combine_stats_{binner}.log",
    script:
        "../scripts/combine_checkm2.py"


localrules:
    get_bin_filenames,


rule get_bin_filenames:
    input:
        dirs=expand(
            "{sample}/binning/{{binner}}/bins",
            sample=SAMPLES,
        ),
        protein_dirs=expand(
            "{sample}/binning/{{binner}}/faa",
            sample=SAMPLES,
        ),
    output:
        filenames="Binning/{binner}/paths.tsv",
    run:
        import pandas as pd
        from pathlib import Path
        from utils import io


        def get_list_of_files(dirs, pattern):
            fasta_files = []

            # search for fasta files (.f*) in all bin folders
            for dir in dirs:
                dir = Path(dir)
                fasta_files += list(dir.glob(pattern))

            filenames = pd.DataFrame(fasta_files, columns=["Filename"])
            filenames.index = filenames.Filename.apply(io.simplify_path)
            filenames.index.name = "Bin"

            filenames.sort_index(inplace=True)

            return filenames


        fasta_filenames = get_list_of_files(input.dirs, "*.f*")
        faa_filenames = get_list_of_files(input.protein_dirs, "*.faa")

        assert all(
            faa_filenames.index == fasta_filenames.index
        ), "faa index and faa index are nt the same"

        faa_filenames.columns = ["Proteins"]

        filenames = pd.concat((fasta_filenames, faa_filenames), axis=1)

        filenames.to_csv(output.filenames, sep="\t")

        """
run:
    try:
        from utils.io import pandas_concat

        pandas_concat(input, output[0])

    except Exception as e:
        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
run:
    import pandas as pd
    from pathlib import Path
    from utils import io


    def get_list_of_files(dirs, pattern):
        fasta_files = []

        # search for fasta files (.f*) in all bin folders
        for dir in dirs:
            dir = Path(dir)
            fasta_files += list(dir.glob(pattern))

        filenames = pd.DataFrame(fasta_files, columns=["Filename"])
        filenames.index = filenames.Filename.apply(io.simplify_path)
        filenames.index.name = "Bin"

        filenames.sort_index(inplace=True)

        return filenames


    fasta_filenames = get_list_of_files(input.dirs, "*.f*")
    faa_filenames = get_list_of_files(input.protein_dirs, "*.faa")

    assert all(
        faa_filenames.index == fasta_filenames.index
    ), "faa index and faa index are nt the same"

    faa_filenames.columns = ["Proteins"]

    filenames = pd.concat((fasta_filenames, faa_filenames), axis=1)

    filenames.to_csv(output.filenames, sep="\t")

    """
    rule merge_bin_info:
        input:
            stats ="Binning/{binner}/genome_stats.tsv",
            gunc= "Binning/{binner}/gunc_report.tsv",
            quality= "Binning/{binner}/checkm2_quality_report.tsv"
        output:
            "Binning/{binner}/combined_bin_info.tsv"

    """
script:
    "../report/bin_report.py"
run:
    from utils.io import cat_files

    cat_files(input, output[0], gzip=True)
script:
    "../scripts/filter_genomes.py"
shell:
    """
    cd-hit-est -i {input} -T {threads} \
    -M {resources.mem}000 -o {params.prefix} \
    -c {params.identity} -n 9  -d 0 {params.extra} \
    -aS {params.coverage} -aL {params.coverage} &> {log}

    mv {params.prefix} {output[0]} 2>> {log}
    """
run:
    with open(output[0], "w") as fout:
        fout.write(f"ORF\tLength\tIdentity\tRepresentative\n")
        Clusters = parse_cd_hit_file(input[0])
        write_cd_hit_clusters(Clusters, fout)
run:
    import pandas as pd
    import numpy as np

    from utils import gene_scripts

    # cd hit format ORF\tLength\tIdentity\tRepresentative\n
    orf2gene = pd.read_csv(input.orf2gene, sep="\t")

    # rename gene repr to Gene0000XX

    # split orf names in sample, contig_nr, and orf_nr
    orf_info = gene_scripts.split_orf_to_index(orf2gene.ORF)

    # rename representative

    representative_names = orf2gene.Representative.unique()

    map_names = pd.Series(
        index=representative_names,
        data=np.arange(1, len(representative_names) + 1, dtype=np.uint),
    )


    orf_info["GeneNr"] = orf2gene.Representative.map(map_names)


    orf_info.to_parquet(output.cluster_attribution)


    # Save name of representatives
    map_names.index.name = "Representative"
    map_names.name = "GeneNr"
    map_names.to_csv(output.rep2genenr, sep="\t")
shell:
    " reformat.sh in={input} "
    " fastaminlen={params.min_length} "
    " out={output} "
    " overwrite=true "
    " -Xmx{resources.java_mem}G 2> {log} "
run:
    import gzip as gz

    with gz.open(output[0], "wt") as fout:
        for sample, input_fasta in zip(params.samples, input.fasta):
            with open(input_fasta) as fin:
                for line in fin:
                    # if line is a header add sample name
                    if line[0] == ">":
                        line = f">{sample}{params.seperator}" + line[1:]
                    # write each line to the combined file
                    fout.write(line)
shell:
    "minimap2 -I {params.index_size} -t {threads} -d {output} {input} 2> {log}"
shell:
    "samtools dict {input} | cut -f1-3 > {output} 2> {log}"
shell:
    """minimap2 -t {threads} -ax sr {input.mmi} {input.fq} | grep -v "^@" | cat {input.dict} - | samtools view -F 3584 -b - > {output.bam} 2>{log}"""
shell:
    "samtools sort {input} -T {params.prefix} --threads {threads} -m 3G -o {output} 2>{log}"
shell:
    "jgi_summarize_bam_contig_depths "
    " --outputDepth {output} "
    " {input.bam} &> {log} "
script:
    "../scripts/convert_jgi2vamb_coverage.py"
shell:
    "vamb --outdir {output} "
    " -m {params.mincontig} "
    " --minfasta {params.minfasta} "
    " -o '{params.separator}' "
    " --jgi {input.coverage} "
    " --fasta {input.fasta} "
    "2> {log}"
script:
    "../scripts/parse_vamb.py"
run:
    shell(
        "wget -O {output} 'https://zenodo.org/record/{ZENODO_ARCHIVE}/files/{wildcards.filename}' "
    )
    if not FILES[wildcards.filename] == md5(output[0]):
        raise OSError(2, "Invalid checksum", output[0])
run:
    shell(
        "wget -O {output.tar} 'https://zenodo.org/record/{ZENODO_ARCHIVE}/files/{CHECKM_ARCHIVE}' "
    )
    if not FILES[CHECKM_ARCHIVE] == md5(output.tar):
        raise OSError(2, "Invalid checksum", CHECKM_ARCHIVE)

    shell("tar -zxf {output.tar} --directory {params.path}")
shell:
    "checkm data setRoot {params.database_dir} &> {log} "
shell:
    " wget {GTDB_DATA_URL} -O {output} &> {log} "
shell:
    'tar -xzvf {input} -C "{GTDBTK_DATA_PATH}" --strip 1 2> {log}; '
shell:
    " checkm2 database --download --path {output} "
    " &>> {log}"
shell:
    "gunc download_db {resources.tmpdir} -db {wildcards.gunc_database} &> {log} ;"
    "mv {resources.tmpdir}/gunc_db_{wildcards.gunc_database}*.dmnd {output} 2>> {log}"
shell:
    "busco -q --download_path {output} --download prokaryota &> {log}"
shell:
    " DRAM-setup.py prepare_databases "
    " --output_dir {output.dbdir} "
    " --threads {threads} "
    " --verbose "
    " --skip_uniref "
    " &> {log} "
    " ; "
    " DRAM-setup.py export_config --output_file {output.config}"
SnakeMake From line 28 of rules/dram.smk
shell:
    "DRAM-setup.py import_config --config_loc {input} &> {log}"
SnakeMake From line 53 of rules/dram.smk
shell:
    " DRAM.py annotate "
    " --input_fasta {input.fasta}"
    " --output_dir {output.outdir} "
    " --threads {threads} "
    " --min_contig_size {params.min_contig_size} "
    " {params.extra} "
    " --verbose &> {log}"
SnakeMake From line 78 of rules/dram.smk
run:
    from utils import io

    for i, annotation_file in enumerate(DRAM_ANNOTATON_FILES):
        input_files = [
            os.path.join(dram_folder, annotation_file) for dram_folder in input
        ]

        io.pandas_concat(
            input_files, output[i], sep="\t", index_col=0, axis=0, disk_based=True
        )
shell:
    " DRAM.py distill "
    " --input_file {input[0]}"
    " --output_dir {output} "
    "  &> {log}"
SnakeMake From line 137 of rules/dram.smk
script:
    "../scripts/DRAM_get_all_modules.py"
SnakeMake From line 157 of rules/dram.smk
script:
    "../scripts/filter_genes.py"
run:
    from utils.io import cat_files

    cat_files(input.faa, output.faa)
    cat_files(input.fna, output.fna)
    cat_files(input.short, output.short)
run:
    from utils.io import cat_files

    cat_files(input.faa, output.faa)
    cat_files(input.fna, output.fna)
    cat_files(input.short, output.short)
shell:
    """
    mkdir -p {params.tmpdir} {output} 2>> {log}
    mmseqs createdb {input.faa} {params.db} &> {log}

    mmseqs {params.clustermethod} -c {params.coverage} \
    --min-seq-id {params.minid} {params.extra} \
    --threads {threads} {params.db} {params.clusterdb} {params.tmpdir}  &>>  {log}

    rm -fr  {params.tmpdir} 2>> {log}
    """
shell:
    """
    mmseqs createtsv {params.db} {params.db} {params.clusterdb} {output.cluster_attribution}  &> {log}

    mkdir {output.rep_seqs_db} 2>> {log}

    mmseqs result2repseq {params.db} {params.clusterdb} {output.rep_seqs_db}/db  &>> {log}

    mmseqs result2flat {params.db} {params.db} {output.rep_seqs_db}/db {output.rep_seqs}  &>> {log}

    """
shell:
    " filterbyname.sh "
    " in={input.all}"
    " names={input.names}"
    " include=t"
    " out={output} "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
script:
    "../scripts/generate_orf_info.py"
script:
    "../scripts/rename_genecatalog.py"
shell:
    "stats.sh gcformat=4 gc={output} in={input} &> {log}"
wrapper:
    "v1.19.0/bio/minimap2/index"
shell:
    "cat {input} > {output} 2> {log}"
wrapper:
    "v1.19.0/bio/minimap2/aligner"
shell:
    " pileup.sh "
    " in={input.bam}"
    " covstats={output.covstats} "
    " rpkm={output.rpkm} "
    " secondary=t "
    " minmapq={params.minmapq} "
    " -Xmx{resources.java_mem}G "
    " 2> {log} "
run:
    try:
        import pandas as pd
        from utils.parsers_bbmap import read_pileup_coverage

        data = read_pileup_coverage(
            input[0],
            coverage_measure="Median_fold",
            other_columns=["Avg_fold", "Covered_percent", "Read_GC", "Std_Dev"],
        )
        data.index.name = "GeneName"
        data.sort_index(inplace=True)

        # rpkm = pd.read_csv(input[1],sep='\t',skiprows=4,usecols=["#Name","RPKM"],index_col=0).sort_index()

        data.reset_index().to_parquet(output[0])

    except Exception as e:
        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
script:
    "../scripts/combine_gene_coverages.py"
script:
    "../scripts/split_genecatalog.py"
shell:
    """
    emapper.py -m diamond --no_annot --no_file_comments \
        --data_dir {params.data_dir} --cpu {threads} -i {input.faa} \
        -o {params.prefix} --override 2> {log}
    """
shell:
    """

    if [ {params.copyto_shm} == "t" ] ;
    then
        cp {EGGNOG_DIR}/eggnog.db {params.data_dir}/eggnog.db 2> {log}
        cp {EGGNOG_DIR}/eggnog_proteins.dmnd {params.data_dir}/eggnog_proteins.dmnd 2>> {log}
    fi

    emapper.py --annotate_hits_table {input.seed} --no_file_comments \
      --override -o {params.prefix} --cpu {threads} --data_dir {params.data_dir} 2>> {log}
    """
run:
    try:

        import pandas as pd

        Tables = [
            pd.read_csv(file, index_col=None, header=None, sep="\t")
            for file in input
        ]

        combined = pd.concat(Tables, axis=0)

        del Tables

        combined.columns = EGGNOG_HEADER

        #           combined.sort_values("Gene",inplace=True)

        combined.to_parquet(output[0], index=False)
    except Exception as e:

        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
run:
    try:
        import pandas as pd

        df = pd.read_table(input[0], index_col=None)

        df.to_parquet(output[0], index=False)

    except Exception as e:

        import traceback

        with open(log[0], "w") as logfile:
            traceback.print_exc(file=logfile)

        raise e
shell:
    " rm -rf {params.outdir} &> {log[0]};"
    "\n"
    " DRAM.py annotate_genes "
    " --input_faa {input.faa}"
    " --config_loc {input.config} "
    " --output_dir {params.outdir} "
    " --threads {threads} "
    " {params.extra} "
    " --log_file_path {log[1]} "
    " --verbose &>> {log[0]}"
script:
    "../scripts/combine_dram_gene_annotations.py"
script:
    "../scripts/gene2genome.py"
shell:
    " DIR=$(dirname $(readlink -f $(which DAS_Tool))) "
    ";"
    " ruby {params.script_dir}/rules/scg_blank_diamond.rb diamond"
    " {input} "
    " $DIR\/db/{params.key}.all.faa "
    " $DIR\/db/{params.key}.scg.faa "
    " $DIR\/db/{params.key}.scg.lookup "
    " {threads} "
    " 2> {log} "
    " ; "
    " mv {input[0]}.scg {output}"
    shell:
        " rm -r {params.working_dir} 2> {log}"
        ";"
        " dRep compare "
        " --genomes {input.paths} "
        " --S_ani {params.ANI} "
        " --S_algorithm fastANI "
        " --cov_thresh {params.overlap} "
        " --processors {threads} "
        " {params.greedy_options} "
        " {params.working_dir} "
        " &>> {log} "
"""


rule dereplicate:
    input:
        paths="Binning/{binner}/filtered_bins_paths.txt".format(
            binner=config["final_binner"]
        ),
        quality="Binning/{binner}/filtered_quality.csv".format(
            binner=config["final_binner"]
        ),
    output:
        genomes=temp(directory("genomes/Dereplication/dereplicated_genomes")),
        wdb="genomes/Dereplication/data_tables/Wdb.csv",
        tables="genomes/Dereplication/data_tables/Cdb.csv",
        bdb="genomes/Dereplication/data_tables/Bdb.csv",
    threads: config["large_threads"]
    log:
        "logs/genomes/dereplicate.log",
    conda:
        "../envs/dRep.yaml"
    params:
        # no filtering
        no_filer=" --length 100  --completeness 0 --contamination  100 ",
        ANI=get_drep_ani,
        greedy_options=get_greedy_drep_Arguments,
        overlap=config["genome_dereplication"]["overlap"],
        completeness_weight=config["genome_dereplication"]["score"]["completeness"],
        contamination_weight=config["genome_dereplication"]["score"]["contamination"],
        #not in table
        N50_weight=config["genome_dereplication"]["score"]["N50"],
        size_weight=config["genome_dereplication"]["score"]["length"],
        opt_parameters=config["genome_dereplication"]["opt_parameters"],
        working_dir=lambda wc, output: Path(output.genomes).parent,
    shell:
        " rm -r {params.working_dir} 2> {log}"
        ";"
        " dRep dereplicate "
        " {params.no_filer} "
        " --genomes {input.paths} "
        " --S_algorithm fastANI "
        " {params.greedy_options} "
        " --genomeInfo {input.quality} "
        " --S_ani {params.ANI} "
        " --cov_thresh {params.overlap} "
        " --completeness_weight {params.completeness_weight} "
        " --contamination_weight {params.contamination_weight} "
        " --N50_weight {params.N50_weight} "
        " --size_weight {params.size_weight} "
        " --processors {threads} "
        " --run_tertiary_clustering "
        " {params.opt_parameters} "
        " {params.working_dir} "
        " &> {log} "


localrules:
    parse_drep,


rule parse_drep:
    input:
        cdb="genomes/Dereplication/data_tables/Cdb.csv",
        bdb="genomes/Dereplication/data_tables/Bdb.csv",
        wdb="genomes/Dereplication/data_tables/Wdb.csv",
    output:
        "genomes/clustering/allbins2genome_oldname.tsv",
    run:
        import pandas as pd


        Cdb = pd.read_csv(input.cdb)
        Cdb.set_index("genome", inplace=True)

        Wdb = pd.read_csv(input.wdb)
        Wdb.set_index("cluster", inplace=True)
        genome2cluster = Cdb.secondary_cluster.map(Wdb.genome)


        genome2cluster = genome2cluster.to_frame().reset_index()
        genome2cluster.columns = ["Bin", "Rep"]

        # map to full paths
        file_paths = pd.read_csv(input.bdb, index_col=0).location
        for col in genome2cluster:
            genome2cluster[col + "_path"] = file_paths.loc[genome2cluster[col]].values

        # expected output is inverted columns
        genome2cluster[["Rep_path", "Bin_path"]].to_csv(
            output[0], sep="\t", index=False, header=False
        )


localrules:
    rename_genomes,


checkpoint rename_genomes:
    input:
        genomes="genomes/Dereplication/dereplicated_genomes",
        mapping_file="genomes/clustering/allbins2genome_oldname.tsv",
        genome_info=f"Binning/{config['final_binner']}/filtered_bin_info.tsv",
    output:
        dir=directory("genomes/genomes"),
        mapfile_contigs="genomes/clustering/contig2genome.tsv",
        mapfile_old2mag="genomes/clustering/old2newID.tsv",
        mapfile_allbins2mag="genomes/clustering/allbins2genome.tsv",
        genome_info="genomes/genome_quality.tsv",
    params:
        rename_contigs=config["rename_mags_contigs"],
    shadow:
        "shallow"
    log:
        "logs/genomes/rename_genomes.log",
    script:
        "../scripts/rename_genomes.py"


def get_genome_dir():
    if ("genome_dir" in config) and (config["genome_dir"] is not None):
        genome_dir = config["genome_dir"]
        assert os.path.exists(genome_dir), f"{genome_dir} doesn't exist"

        logger.info(f"Set genomes from {genome_dir}.")

        # check if genomes are present
        genomes = glob_wildcards(os.path.join(genome_dir, "{genome}.fasta")).genome

        if len(genomes) == 0:
            logger.error(f"No genomes found with fasta extension in {genome_dir} ")
            exit(1)

    else:
        genome_dir = "genomes/genomes"

    return genome_dir


genome_dir = get_genome_dir()


def get_all_genomes(wildcards):
    global genome_dir

    if genome_dir == "genomes/genomes":
        checkpoints.rename_genomes.get()

    # check if genomes are present
    genomes = glob_wildcards(os.path.join(genome_dir, "{genome}.fasta")).genome

    if len(genomes) == 0:
        logger.error(
            f"No genomes found with fasta extension in {genome_dir} "
            "You don't have any Metagenome assembled genomes with sufficient quality. "
            "You may want to change the assembly, binning or filtering parameters. "
            "Or focus on the genecatalog workflow only."
        )
        exit(1)

    return genomes


rule get_contig2genomes:
    input:
        genome_dir,
    output:
        "genomes/clustering/contig2genome.tsv",
    run:
        from glob import glob

        fasta_files = glob(input[0] + "/*.f*")

        with open(output[0], "w") as out_contigs:
            for fasta in fasta_files:
                bin_name, ext = os.path.splitext(os.path.split(fasta)[-1])
                # if gz remove also fasta extension
                if ext == ".gz":
                    bin_name = os.path.splitext(bin_name)[0]

                # write names of contigs in mapping file
                with open(fasta) as f:
                    for line in f:
                        if line[0] == ">":
                            header = line[1:].strip().split()[0]
                            out_contigs.write(f"{header}\t{bin_name}\n")


# alternative way to get to contigs2genomes for quantification with external genomes


ruleorder: get_contig2genomes > rename_genomes


# rule predict_genes_genomes:
#     input:
#         dir= genomes_dir
#     output:
#         directory("genomes/annotations/genes")
#     conda:
#         "%s/prodigal.yaml" % CONDAENV
#     log:
#         "logs/genomes/prodigal.log"
#     shadow:
#         "shallow"
#     threads:
#         config.get("threads", 1)
#     script:
#         "predict_genes_of_genomes.py"


rule predict_genes_genomes:
    input:
        os.path.join(genome_dir, "{genome}.fasta"),
    output:
        fna="genomes/annotations/genes/{genome}.fna",
        faa="genomes/annotations/genes/{genome}.faa",
        gff=temp("genomes/annotations/genes/{genome}.gff"),
    conda:
        "%s/prodigal.yaml" % CONDAENV
    log:
        "logs/genomes/prodigal/{genome}.txt",
    threads: 1
    resources:
        mem=config["simplejob_mem"],
        time=config["runtime"]["simplejob"],
    shell:
        """
run:
    import pandas as pd


    Cdb = pd.read_csv(input.cdb)
    Cdb.set_index("genome", inplace=True)

    Wdb = pd.read_csv(input.wdb)
    Wdb.set_index("cluster", inplace=True)
    genome2cluster = Cdb.secondary_cluster.map(Wdb.genome)


    genome2cluster = genome2cluster.to_frame().reset_index()
    genome2cluster.columns = ["Bin", "Rep"]

    # map to full paths
    file_paths = pd.read_csv(input.bdb, index_col=0).location
    for col in genome2cluster:
        genome2cluster[col + "_path"] = file_paths.loc[genome2cluster[col]].values

    # expected output is inverted columns
    genome2cluster[["Rep_path", "Bin_path"]].to_csv(
        output[0], sep="\t", index=False, header=False
    )
run:
    from glob import glob

    fasta_files = glob(input[0] + "/*.f*")

    with open(output[0], "w") as out_contigs:
        for fasta in fasta_files:
            bin_name, ext = os.path.splitext(os.path.split(fasta)[-1])
            # if gz remove also fasta extension
            if ext == ".gz":
                bin_name = os.path.splitext(bin_name)[0]

            # write names of contigs in mapping file
            with open(fasta) as f:
                for line in f:
                    if line[0] == ">":
                        header = line[1:].strip().split()[0]
                        out_contigs.write(f"{header}\t{bin_name}\n")
shell:
    "cat {input} > {output}"
shell:
    "cat {input}/*{params.ext} > {output}"
wrapper:
    "v1.19.0/bio/minimap2/index"
wrapper:
    "v1.19.0/bio/minimap2/aligner"
wrapper:
    "v1.19.0/bio/bwa-mem2/index"
wrapper:
    "v1.19.0/bio/bwa-mem2/mem"
shell:
    "mv {input} {output} > {log}"
wrapper:
    "v1.19.0/bio/samtools/stats"
wrapper:
    "v1.19.1/bio/multiqc"
shell:
    "pileup.sh in={input.bam} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " covstats={output.covstats} "
    " fastaorf={input.orf} outorf={output.orf} "
    " concise=t "
    " physical=t "
    " minmapq={params.minmapq} "
    " bincov={output.bincov} "
    " 2> {log}"
script:
    "../scripts/combine_coverage_MAGs.py"
import argparse
import os, sys
import shutil
import warnings

import pandas as pd
from Bio import SeqIO


def get_fasta_of_bins(cluster_attribution, contigs, out_folder):
    """
    Creates individual fasta files for each bin using the contigs fasta and the cluster attribution.

    input:
    - cluster attribution file:   tab-separated file of "contig_fasta_header    bin"
    - contigs:                    fasta file of contigs
    - out_folder:                 output folder for the bin fastas: {out_folder}/{binid}.fasta
    """
    # create outdir
    if os.path.exists(out_folder):
        shutil.rmtree(out_folder)
    os.makedirs(out_folder)

    CA = pd.read_csv(cluster_attribution, header=None, index_col=1, sep="\t")

    assert CA.shape[1] == 1, "File should have only two columns " + cluster_attribution
    CA = CA.iloc[:, 0]
    CA.index = CA.index.astype("str")
    # exclude cluster 0 which is unclustered at least for metabat
    CA = CA.loc[CA != "0"]

    contigs = SeqIO.to_dict(SeqIO.parse(contigs, "fasta"))

    for id in CA.index.unique():
        bin_contig_names = CA.loc[id]
        out_file = os.path.join(out_folder, "{id}.fasta".format(id=id))
        if isinstance(bin_contig_names, str):
            warnings.warn("Bin with a single contig: " + out_file)
            bin_contig_names = [bin_contig_names]
        bin_contigs = [contigs[c] for c in bin_contig_names]
        SeqIO.write(bin_contigs, out_file, "fasta")


if __name__ == "__main__":
    if "snakemake" not in globals():
        p = argparse.ArgumentParser()
        p.add_argument("--cluster-attribution")
        p.add_argument("--contigs")
        p.add_argument("--out-folder")
        args = vars(p.parse_args())
        get_fasta_of_bins(**args)
    else:
        with open(snakemake.log[0], "w") as log:
            sys.stderr = sys.stdout = log

            get_fasta_of_bins(
                snakemake.input.cluster_attribution,
                snakemake.input.contigs,
                snakemake.output[0],
            )
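
Outside of Snakemake the script exposes the same function through argparse (see the __main__ block above); it can also be called directly from Python. A short usage sketch with placeholder paths:

# hypothetical paths, for illustration only
get_fasta_of_bins(
    cluster_attribution="sample1/binning/cluster_attribution.tsv",
    contigs="sample1/assembly/contigs.fasta",
    out_folder="sample1/binning/bins",
)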
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk identify "
    "--genes --genome_dir {params.gene_dir} "
    " --out_dir {params.outdir} "
    "--extension {params.extension} "
    "--cpus {threads} &> {log[0]}"
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk align --identify_dir {params.outdir} --out_dir {params.outdir} "
    "--cpus {threads} &> {log[0]}"
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk classify --genome_dir {input.genome_dir} --align_dir {params.outdir} "
    " --mash_db {params.mashdir} "
    "--out_dir {params.outdir} "
    " --tmpdir {resources.tmpdir} "
    "--extension {params.extension} "
    "--cpus {threads} &> {log[0]}"
script:
    "../scripts/combine_taxonomy.py"
shell:
    'export GTDBTK_DATA_PATH="{GTDBTK_DATA_PATH}" ; '
    "gtdbtk infer --msa_file {input} "
    " --out_dir {params.outdir} "
    " --prefix {wildcards.msa} "
    " --cpus {threads} "
    "--tmpdir {resources.tmpdir} > {log[0]} 2> {log[1]}"
script:
    "../scripts/root_tree.py"
SnakeMake From line 130 of rules/gtdbtk.smk
shell:
    "reformat.sh "
    " {params.inputs} "
    " interleaved={params.interleaved} "
    " {params.outputs} "
    " {params.extra} "
    " overwrite=true "
    " verifypaired={params.verifypaired} "
    " threads={threads} "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
SnakeMake From line 92 of rules/qc.smk
script:
    "../scripts/get_read_stats.py"
SnakeMake From line 130 of rules/qc.smk
shell:
    "clumpify.sh "
    " {params.inputs} "
    " {params.outputs} "
    " overwrite=true"
    " dedupe=t "
    " dupesubs={params.dupesubs} "
    " optical={params.only_optical}"
    " threads={threads} "
    " pigz=t unpigz=t "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
SnakeMake From line 174 of rules/qc.smk
shell:
    " bbduk.sh {params.inputs} "
    " {params.ref} "
    " interleaved={params.interleaved} "
    " {params.outputs} "
    " stats={output.stats} "
    " overwrite=true "
    " qout=33 "
    " trd=t "
    " {params.hdist} "
    " {params.k} "
    " {params.ktrim} "
    " {params.mink} "
    " trimq={params.trimq} "
    " qtrim={params.qtrim} "
    " threads={threads} "
    " minlength={params.minlength} "
    " maxns={params.maxns} "
    " minbasefrequency={params.minbasefrequency} "
    " ecco={params.error_correction_pe} "
    " prealloc={params.prealloc} "
    " pigz=t unpigz=t "
    " -Xmx{resources.java_mem}G "
    " 2> {log}"
shell:
    "bbsplit.sh"
    " -Xmx{resources.java_mem}G "
    " {params.refs_in} "
    " threads={threads}"
    " k={params.k}"
    " local=t "
    " 2> {log}"
SnakeMake From line 331 of rules/qc.smk
shell:
    """
    if [ "{params.paired}" = true ] ; then
        bbsplit.sh in1={input[0]} in2={input[1]} \
            outu1={output[0]} outu2={output[1]} \
            basename="{params.contaminant_folder}/%_R#.fastq.gz" \
            maxindel={params.maxindel} minratio={params.minratio} \
            minhits={params.minhits} ambiguous={params.ambiguous} refstats={output.stats} \
            threads={threads} k={params.k} local=t \
            pigz=t unpigz=t ziplevel=9 \
            -Xmx{resources.java_mem}G 2> {log}
    fi

    bbsplit.sh in={params.input_single}  \
        outu={params.output_single} \
        basename="{params.contaminant_folder}/%_se.fastq.gz" \
        maxindel={params.maxindel} minratio={params.minratio} \
        minhits={params.minhits} ambiguous={params.ambiguous} refstats={output.stats} append=t \
        interleaved=f threads={threads} k={params.k} local=t \
        pigz=t unpigz=t ziplevel=9 \
        -Xmx{resources.java_mem}G 2>> {log}
    """
SnakeMake From line 386 of rules/qc.smk
run:
    import shutil
    import pandas as pd

    for i in range(len(MULTIFILE_FRACTIONS)):
        with open(output[i], "wb") as outFile:
            with open(input.clean_reads[i], "rb") as infile1:
                shutil.copyfileobj(infile1, outFile)
                if hasattr(input, "rrna_reads"):
                    with open(input.rrna_reads[i], "rb") as infile2:
                        shutil.copyfileobj(infile2, outFile)

    # append to sample table
    sample_table = load_sample_table(params.sample_table)
    qc_header = [f"Reads_QC_{fraction}" for fraction in MULTIFILE_FRACTIONS]
    sample_table.loc[wildcards.sample, qc_header] = output
    sample_table.to_csv(params.sample_table, sep="\t")
shell:
    " bbmerge.sh "
    " -Xmx{resources.java_mem}G "
    " threads={threads} "
    " {params.inputs} "
    " {params.flags} k={params.kmer} "
    " extend2={params.extend2} "
    " ihist={output.ihist} merge=f "
    " mininsert0=35 minoverlap0=8 "
    " prealloc=t prefilter=t "
    " minprob={params.minprob} 2> {log} \n  "
    """
    readlength.sh {params.inputs} out={output.read_length} 2>> {log}
    """
SnakeMake From line 477 of rules/qc.smk
shell:
    """
    readlength.sh in={input[0]} out={output.read_length} 2> {log}
    """
SnakeMake From line 511 of rules/qc.smk
run:
    import pandas as pd
    import os
    from utils.parsers_bbmap import parse_comments

    stats = pd.DataFrame()

    for length_file in input:
        sample = length_file.split(os.path.sep)[0]
        data = parse_comments(length_file)
        data = pd.Series(data)[
            ["Reads", "Bases", "Max", "Min", "Avg", "Median", "Mode", "Std_Dev"]
        ]
        stats[sample] = data

    stats.to_csv(output[0], sep="\t")
run:
    import pandas as pd
    import os
    from utils.parsers_bbmap import parse_comments

    stats = pd.DataFrame()

    for insert_file in input:
        sample = insert_file.split(os.path.sep)[0]
        data = parse_comments(insert_file)
        data = pd.Series(data)[
            ["Mean", "Median", "Mode", "STDev", "PercentOfPairs"]
        ]
        stats[sample] = data

    stats.T.to_csv(output[0], sep="\t")
run:
    import pandas as pd

    # DataFrame.append was removed in pandas 2.0; collect the tables and concatenate them
    all_read_counts = pd.concat(
        [
            pd.read_csv(read_stats_file, index_col=[0, 1], sep="\t")
            for read_stats_file in input.read_count_files
        ]
    )
    all_read_counts.to_csv(output.read_stats, sep="\t")
run:
    import pandas as pd

    # DataFrame.append was removed in pandas 2.0; collect the tables and concatenate them
    stats = pd.concat(
        [pd.read_csv(f, index_col=[0, 1], sep="\t") for f in input]
    )

    stats.to_csv(output[0], sep="\t")
script:
    "../report/qc_report.py"
SnakeMake From line 677 of rules/qc.smk
shell:
    "SemiBin generate_sequence_features_multi"
    " --input-fasta {input.fasta} "
    " --input-bam {input.bams} "
    " --output {params.output_dir} "
    " --threads {threads} "
    " --separator {params.separator} "
    " 2> {log}"
shell:
    "SemiBin train_self "
    " --output {params.output_dir} "
    " --threads {threads} "
    " --data {input.data} "
    " --data-split {input.data_split} "
    " {params.extra} "
    " 2> {log}"
shell:
    "SemiBin bin "
    " --input-fasta {input.fasta} "
    " --output {params.output_dir} "
    " --threads {threads} "
    " --data {input.data} "
    " --model {input.model} "
    " --minfasta-kbs {params.min_bin_kbs}"
    " {params.extra} "
    " 2> {log}"
script:
    "../scripts/parse_semibin.py"
shell:
    " mkdir -p {params.outdir} 2> {log} "
    " ; "
    " prefetch "
    " --output-directory {params.outdir} "
    " -X 999999999 "
    " --progress "
    " --log-level info "
    " {wildcards.sra_run} &>> {log} "
    " ; "
    " vdb-validate {params.outdir}/{wildcards.sra_run}/{wildcards.sra_run}.sra &>> {log} "
SnakeMake From line 30 of rules/sra.smk
shell:
    " vdb-validate {params.sra_file} &>> {log} "
    " ; "
    " parallel-fastq-dump "
    " --threads {threads} "
    " --gzip --split-files "
    " --outdir {params.outdir} "
    " --tmpdir {resources.tmpdir} "
    " --skip-technical --split-3 "
    " -s {params.sra_file} &>> {log} "
    " ; "
    " rm -f {params.sra_file} 2>> {log} "
SnakeMake From line 66 of rules/sra.smk
run:
    from utils import io

    for i, fraction in enumerate(SRA_read_fractions):
        if fraction == "":
            fraction = "se"
        io.cat_files(input[fraction], output[i])
shell:
    "inStrain compare "
    " --input {input.profiles} "
    " -o {output} "
    " -p {threads} "
    " -s {input.scaffold_to_genome} "
    " --database_mode "
    " {params.extra} &> {log}"
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logging.captureWarnings(True)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts

import pandas as pd
from utils.parsers import read_checkm2_output


def main(samples, completeness_files, bin_table):
    # DataFrame.append was removed in pandas 2.0; collect per-sample tables and concatenate
    sample_tables = []

    for i, sample in enumerate(samples):
        sample_data = read_checkm2_output(completness_table=completeness_files[i])
        sample_data["Sample"] = sample

        sample_tables.append(sample_data)

    pd.concat(sample_tables).to_csv(bin_table, sep="\t")


if __name__ == "__main__":
    main(
        samples=snakemake.params.samples,
        completeness_files=snakemake.input.completeness_files,
        bin_table=snakemake.output.bin_table,
    )
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
from utils.parsers_bbmap import parse_pileup_log_file


def parse_map_stats(sample_data, out_tsv):
    # DataFrame.append was removed in pandas 2.0; collect per-sample rows and concatenate
    sample_rows = []
    for sample in sample_data.keys():
        df = pd.read_csv(sample_data[sample]["contig_stats"], sep="\t")
        assert df.shape[0] == 1, "Assumed only one row in file {}; found {}".format(
            sample_data[sample]["contig_stats"], df.shape[0]
        )
        df = df.iloc[0]
        df.name = sample
        genes_df = pd.read_csv(sample_data[sample]["gene_table"], index_col=0, sep="\t")
        df["N_Predicted_Genes"] = genes_df.shape[0]

        mapping_stats = parse_pileup_log_file(sample_data[sample]["mapping_log"])

        df["Assembled_Reads"] = mapping_stats["Mapped reads"]
        df["Percent_Assembled_Reads"] = mapping_stats["Percent mapped"]

        sample_rows.append(df)

    stats_df = pd.concat(sample_rows, axis=1).T
    stats_df = stats_df.loc[:, ~stats_df.columns.str.startswith("scaf_")]
    stats_df.columns = stats_df.columns.str.replace("ctg_", "")
    stats_df.to_csv(out_tsv, sep="\t")
    return stats_df


def main(samples, contig_stats, gene_tables, mapping_logs, combined_stats):
    sample_data = {}
    for sample in samples:
        sample_data[sample] = {}
        for c_stat in contig_stats:
            # underscore version was for simplified local testing
            # if "%s_" % sample in c_stat:
            if "%s/" % sample in c_stat:
                sample_data[sample]["contig_stats"] = c_stat
        for g_table in gene_tables:
            # if "%s_" % sample in g_table:
            if "%s/" % sample in g_table:
                sample_data[sample]["gene_table"] = g_table
        for mapping_log in mapping_logs:
            # if "%s_" % sample in mapping_log:
            if "%s/" % sample in mapping_log:
                sample_data[sample]["mapping_log"] = mapping_log

    parse_map_stats(sample_data, combined_stats)


if __name__ == "__main__":
    main(
        samples=snakemake.params.samples,
        contig_stats=snakemake.input.contig_stats,
        gene_tables=snakemake.input.gene_tables,
        mapping_logs=snakemake.input.mapping_logs,
        combined_stats=snakemake.output.combined_contig_stats,
    )
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
import os, gc
from utils.parsers_bbmap import read_coverage_binned, combine_coverages


contig2genome = pd.read_csv(
    snakemake.input.contig2genome, header=None, index_col=0, sep="\t"
).iloc[:, 0]


# sum counts
logging.info("Loading counts and coverage per contig")

combined_cov, Counts_contigs = combine_coverages(
    snakemake.input.coverage_files, snakemake.params.samples
)

combined_cov = combined_cov.T

combined_cov.insert(
    0, "Genome", value=pd.Categorical(contig2genome.loc[combined_cov.index].values)
)

logging.info(f"Saving coverage to {snakemake.output.coverage_contigs}")

combined_cov.reset_index().to_parquet(snakemake.output.coverage_contigs)

logging.info("Sum counts per genome")

Counts_genome = Counts_contigs.groupby(contig2genome, axis=1).sum().T
Counts_genome.index.name = "Sample"

logging.info(f"Saving counts to {snakemake.output.counts}")

Counts_genome.reset_index().to_parquet(snakemake.output.counts)
del Counts_genome, combined_cov, Counts_contigs
gc.collect()

# Binned coverage
logging.info("Loading binned coverage")
binCov = {}
for i, cov_file in enumerate(snakemake.input.binned_coverage_files):
    sample = snakemake.params.samples[i]

    binCov[sample] = read_coverage_binned(cov_file)

binCov = pd.DataFrame.from_dict(binCov)

logging.info("Add genome information to it")
binCov.insert(
    0,
    "Genome",
    value=pd.Categorical(contig2genome.loc[binCov.index.get_level_values(0)].values),
)

gc.collect()
logging.info(f"Saving combined binCov to {snakemake.output.binned_cov}")
binCov.reset_index().to_parquet(snakemake.output.binned_cov)

# Median coverage
logging.info("Calculate median coverage")
Median_abund = binCov.groupby("Genome").median().T
del binCov
gc.collect()
logging.info(f"Saving mediuan coverage {snakemake.output.median_abund}")
Median_abund.reset_index().to_parquet(snakemake.output.median_abund)
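
The genome-level tables are written as parquet files with the index moved into a column (hence the reset_index calls above). A minimal sketch of reading the counts back, assuming a placeholder path for the rule's counts output:

import pandas as pd

# hypothetical path; use the file written to snakemake.output.counts
counts = pd.read_parquet("genomes/counts/counts_genomes.parquet").set_index("Sample")
# rows are samples, columns are genomes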
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict

db_columns = {
    "kegg": ["ko_id", "kegg_hit"],
    "peptidase": [
        "peptidase_id",
        "peptidase_family",
        "peptidase_hit",
        "peptidase_RBH",
        "peptidase_identity",
        "peptidase_bitScore",
        "peptidase_eVal",
    ],
    "pfam": ["pfam_hits"],
    "cazy": ["cazy_ids", "cazy_hits", "cazy_subfam_ec", "cazy_best_hit"],
    # "heme": ["heme_regulatory_motif_count"],
}

Tables = defaultdict(list)

for file in snakemake.input:
    df = pd.read_csv(file, index_col=0, sep="\t")

    # drop un-annotated genes
    df = df.query("rank!='E'")

    # change index from 'subset1_Gene111' ->  simply 'Gene111'
    # Gene name to nr
    df.index = (
        df.index.str.split("_", n=1, expand=True)
        .get_level_values(1)
        .str[len("Gene") :]
        .astype(np.int64)
    )
    df.index.name = "GeneNr"

    # select columns, drop na rows and append to list
    for db in db_columns:
        cols = db_columns[db]

        if not df.columns.intersection(cols).empty:

            Tables[db].append(df[cols].dropna(axis=0, how="all"))

    del df

out_dir = Path(snakemake.output[0])
out_dir.mkdir()

for db in Tables:

    combined = pd.concat(Tables[db], axis=0)

    combined.sort_index(inplace=True)

    combined.reset_index().to_parquet(out_dir / (db + ".parquet"))
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of script
import numpy as np
import pandas as pd
import gc, os


import h5py

import psutil


def measure_memory(write_log_entry=True):
    mem_usage = psutil.Process().memory_info().rss / (1024 * 1024)

    if write_log_entry:
        logging.info(f"The process is currently using {mem_usage: 7.0f} MB of RAM")

    return mem_usage


logging.info("Start")
measure_memory()

N_samples = len(snakemake.input.covstats)

logging.info("Read gene info")

gene_info = pd.read_table(snakemake.input.info)

# Gene name is only first part of first column
gene_info.index = gene_info["#Name"].str.split(" ", n=1, expand=True)[0]
gene_info.index.name = "GeneName"
gene_info.drop("#Name", axis=1, inplace=True)

# Sort
gene_info.sort_index(inplace=True)
N_genes = gene_info.shape[0]

gene_info[
    ["Samples_nz_coverage", "Samples_nz_counts", "Sum_coverage", "Max_coverage"]
] = 0


logging.info("Open hdf files for writing")

gene_matrix_shape = (N_samples, N_genes)

with h5py.File(snakemake.output.cov, "w") as hdf_cov_file, h5py.File(
    snakemake.output.counts, "w"
) as hdf_counts_file:
    combined_cov = hdf_cov_file.create_dataset(
        "data", shape=gene_matrix_shape, fillvalue=0, compression="gzip"
    )
    combined_counts = hdf_counts_file.create_dataset(
        "data", shape=gene_matrix_shape, fillvalue=0, compression="gzip"
    )

    # add sample names as an attribute
    sample_names = np.array(list(snakemake.params.samples)).astype("S")
    combined_cov.attrs["sample_names"] = sample_names
    combined_counts.attrs["sample_names"] = sample_names

    gc.collect()

    Summary = {}

    logging.info("Start reading files")
    initial_mem_usage = measure_memory()

    for i, sample in enumerate(snakemake.params.samples):
        logging.info(f"Read coverage file for sample {i+1} / {N_samples}")
        sample_cov_file = snakemake.input.covstats[i]

        data = pd.read_parquet(
            sample_cov_file, columns=["GeneName", "Reads", "Median_fold"]
        ).set_index("GeneName")

        assert (
            data.shape[0] == N_genes
        ), f"I only have {data.shape[0]} /{N_genes} in the file {sample_cov_file}"

        # per-sample files must be sorted by gene name in the same order as gene_info
        assert (
            data.index.is_monotonic_increasing
        ), f"data is not sorted by index in {sample_cov_file}"

        # downcast data
        # median is int
        Median_fold = pd.to_numeric(data.Median_fold, downcast="integer")
        Reads = pd.to_numeric(data.Reads, downcast="integer")

        # delete intermediate data and release memory
        del data

        # get summary statistics per sample
        logging.debug("Extract Summary statistics")

        Summary[sample] = {
            "Sum_coverage": Median_fold.sum(),
            "Total_counts": Reads.sum(),
            "Genes_nz_counts": (Reads > 0).sum(),
            "Genes_nz_coverage": (Median_fold > 0).sum(),
        }

        # get gene wise stats
        gene_info["Samples_nz_counts"] += (Reads > 0) * 1
        gene_info["Samples_nz_coverage"] += (Median_fold > 0) * 1
        gene_info["Sum_coverage"] += Median_fold

        gene_info["Max_coverage"] = np.fmax(gene_info["Max_coverage"], Median_fold)

        combined_cov[i, :] = Median_fold.values
        combined_counts[i, :] = Reads.values

        del Median_fold, Reads
        gc.collect()

        current_mem_usage = measure_memory()


logging.info("All samples processed")
gc.collect()

logging.info("Save sample Summary")
pd.DataFrame(Summary).T.to_csv(snakemake.output.sample_info, sep="\t")


logging.info("Save gene Summary")

# downcast
for col in gene_info.columns:
    if col == "GC":
        gene_info[col] = pd.to_numeric(gene_info[col], downcast="float")
    else:
        gene_info[col] = pd.to_numeric(gene_info[col], downcast="integer")

gene_info.reset_index().to_parquet(snakemake.output.gene_info)
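
The coverage and count matrices are stored as plain HDF5 datasets named "data", with the sample names attached as a byte-string attribute (see the create_dataset calls above). A minimal sketch of reading one matrix back, with a placeholder path:

import h5py
import pandas as pd

# hypothetical path; use the file written to snakemake.output.cov
with h5py.File("Genecatalog/counts/median_coverage.h5", "r") as hdf:
    data = hdf["data"][:]  # samples x genes matrix
    samples = [s.decode() for s in hdf["data"].attrs["sample_names"]]

coverage = pd.DataFrame(data, index=samples)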
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of scripts

import pandas as pd
import numpy as np
from utils.taxonomy import tax2table

from glob import glob

gtdb_classify_folder = snakemake.input.folder

taxonomy_files = glob(f"{gtdb_classify_folder}/gtdbtk.*.summary.tsv")

N_taxonomy_files = len(taxonomy_files)
logging.info(f"Found {N_taxonomy_files} gtdb taxonomy files.")

if (0 == N_taxonomy_files) or (N_taxonomy_files > 2):
    raise Exception(
        f"Found {N_taxonomy_files} number of taxonomy files 'gtdbtk.*.summary.tsv' in {gtdb_classify_folder} expect 1 or 2."
    )


DT = pd.concat([pd.read_table(file, index_col=0) for file in taxonomy_files], axis=0)

DT.to_csv(snakemake.output.combined)

Tax = tax2table(DT.classification, remove_prefix=True)
Tax.to_csv(snakemake.output.taxonomy, sep="\t")
import os
import sys
import re


def main(jgi_file):
    # parsing input
    header = {}
    col2keep = ["contigName", "contigLen", "totalAvgDepth"]
    with open(jgi_file) as inF:
        for i, line in enumerate(inF):
            line = line.rstrip().split("\t")
            if i == 0:
                header = {x: ii for ii, x in enumerate(line)}
                col2keep += [x for x in line if x.endswith(".bam")]
                print("\t".join(col2keep))
                continue
            elif line[0] == "":
                continue
            # contig ID
            contig = line[header["contigName"]]
            # collect per-sample info
            out = []
            for col in col2keep:
                out.append(line[header[col]])
            print("\t".join(out))


if __name__ == "__main__":
    if "snakemake" in globals():
        with open(snakemake.log[0], "w") as log:
            sys.stderr = log

            with open(snakemake.output[0], "w") as outf:
                sys.stdout = outf

                main(snakemake.input[0])

    else:
        import argparse
        import logging

        logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.DEBUG)

        class CustomFormatter(
            argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter
        ):
            pass

        desc = (
            "Converting jgi_summarize_bam_contig_depths output to format used by VAMB"
        )
        epi = """DESCRIPTION:
        Output format: contigName<tab>contigLen<tab>totalAvgDepth<tab>SAMPLE1.sort.bam<tab>Sample2.sort.bam<tab>...
        Output written to STDOUT
        """
        parser = argparse.ArgumentParser(
            description=desc, epilog=epi, formatter_class=CustomFormatter
        )
        argparse.ArgumentDefaultsHelpFormatter
        parser.add_argument(
            "jgi_file",
            metavar="jgi_file",
            type=str,
            help="jgi_summarize_bam_contig_depths output table",
        )
        parser.add_argument("--version", action="version", version="0.0.1")

        args = parser.parse_args()
        main(args.jgi_file)
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd

annotation_file = snakemake.input[0]
module_output_table = snakemake.output[0]

from mag_annotator.database_handler import DatabaseHandler
from mag_annotator.summarize_genomes import build_module_net, make_module_coverage_frame

annotations = pd.read_csv(annotation_file, sep="\t", index_col=0)


# get db_locs and read in dbs
database_handler = DatabaseHandler(logger=logging)


if "module_step_form" not in database_handler.config["dram_sheets"]:
    raise ValueError(
        "Module step form location must be set in order to summarize genomes"
    )

module_steps_form = pd.read_csv(
    database_handler.config["dram_sheets"]["module_step_form"], sep="\t"
)

all_module_nets = {
    module: build_module_net(module_df)
    for module, module_df in module_steps_form.groupby("module")
}

module_coverage_frame = make_module_coverage_frame(
    annotations, all_module_nets, groupby_column="fasta"
)

module_coverage_frame.to_csv(module_output_table, sep="\t")
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


import pyfastx


faa_iterator = pyfastx.Fastx(snakemake.input.faa, format="fasta")
fna_iterator = pyfastx.Fastx(snakemake.input.fna, format="fasta")


with open(snakemake.output.faa, "w") as out_faa, open(
    snakemake.output.fna, "w"
) as out_fna, open(snakemake.output.short, "w") as out_short:
    for name, seq, comment in fna_iterator:
        protein = next(faa_iterator)

        # include gene and corresponding protein if gene passes length threshold
        # or annotation contains prodigal info that it's complete
        if (len(seq) >= snakemake.params.minlength_nt) or ("partial=00" in comment):
            out_fna.write(f">{name} {comment}\n{seq}\n")
            out_faa.write(">{0} {2}\n{1}\n".format(*protein))

        else:
            out_short.write(">{0} {2}\n{1}\n".format(*protein))
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
from glob import glob
from numpy import log

from utils.parsers import load_quality


Q = load_quality(snakemake.input.quality)

stats = pd.read_csv(snakemake.input.stats, index_col=0, sep="\t")
stats["logN50"] = log(stats.N50)

# merge table but only shared Bins and non overlapping columns
Q = Q.join(stats.loc[Q.index, stats.columns.difference(Q.columns)])
del stats

n_all_bins = Q.shape[0]

filter_criteria = snakemake.params["filter_criteria"]
logging.info(f"Filter genomes according to criteria:\n {filter_criteria}")


Q = Q.query(filter_criteria)

logging.info(f"Retain {Q.shape[0]} genomes from {n_all_bins}")


## GUNC

if hasattr(snakemake.input, "gunc"):

    gunc = pd.read_table(snakemake.input.gunc, index_col=0)
    gunc = gunc.loc[Q.index]

    bad_genomes = gunc.index[gunc["pass.GUNC"] == False]
    logging.info(f"{len(bad_genomes)} Don't pass gunc filtering")

    Q.drop(bad_genomes, inplace=True)
else:
    logging.info(" Don't filter based on gunc")


if Q.shape[0] == 0:
    logging.error(
        f"No bins passed the filtering criteria! You might want to tweak the filtering criteria. Also check {snakemake.input.quality}"
    )
    exit(1)

# output Q together with quality
Q.to_csv(snakemake.output.info, sep="\t")


# output quality for derepliation
D = Q.copy()

D.index.name = "genome"

D.columns = D.columns.str.lower()
# fasta extension is needed even if otherwise stated https://github.com/MrOlm/drep/issues/169
D.index += ".fasta"
D = D[["completeness", "contamination"]]
D.to_csv(snakemake.output.quality_for_derep)

# filter the genome paths down to the retained genomes

F = pd.read_table(snakemake.input.paths, index_col=0)

# select the path column for the retained genomes only
F = F.loc[Q.index].iloc[:, 0]
F.to_csv(snakemake.output.paths, index=False, header=False)
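
The filter_criteria parameter is evaluated with pandas DataFrame.query on the merged quality/assembly-statistics table, so it can reference any of its columns, including the logN50 column added above. An illustrative expression (the column names are assumptions and must match the quality table in use):

# example only; adjust column names to the quality table in use
filter_criteria = "Completeness - 5 * Contamination > 50 and N50 > 10000"
filtered = Q.query(filter_criteria)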
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

#### Beginning of script

import pandas as pd
from utils import gene_scripts

# if MAGs were renamed, map the old contig names to the new genome names;
# otherwise the contig-to-MAG mapping can be used directly
if snakemake.params.renamed_contigs:
    contigs2bins = pd.read_csv(
        snakemake.input.contigs2bins, index_col=0, sep="\t", header=None
    )

    contigs2bins.columns = ["Bin"]
    # the `squeeze` argument was removed from read_csv in pandas 2.0; use .squeeze("columns")
    old2newID = pd.read_csv(
        snakemake.input.old2newID, index_col=0, sep="\t"
    ).squeeze("columns")

    contigs2genome = contigs2bins.join(old2newID, on="Bin").dropna().drop("Bin", axis=1)
else:
    contigs2genome = pd.read_csv(
        snakemake.input.contigs2mags, index_col=0, sep="\t", header=None
    )
    contigs2genome.columns = ["MAG"]

# load orf_info
orf_info = pd.read_parquet(snakemake.input.orf_info)


# recreate Contig name `Sample_ContigNr` and Gene names `Gene0004`
orf_info["Contig"] = orf_info.Sample + "_" + orf_info.ContigNr.astype(str)
orf_info["Gene"] = gene_scripts.geneNr_to_string(orf_info.GeneNr)

# Join genomes on contig
orf_info = orf_info.join(contigs2genome, on="Contig")

# remove genes not on genomes
orf_info = orf_info.dropna(axis=0)


# count genes per genome in a matrix
gene2genome = pd.to_numeric(
    orf_info.groupby(["Gene", "MAG"]).size(), downcast="unsigned"
).unstack(fill_value=0)

# save as parquet
gene2genome.reset_index().to_parquet(snakemake.output[0])
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


## Start

import pandas as pd
import numpy as np

from utils import gene_scripts

# ClusterID    GeneID    (empty third column)
orf2gene = pd.read_csv(
    snakemake.input.cluster_attribution, header=None, sep="\t", usecols=[0, 1]
)

orf2gene.columns = ["Representative", "ORF"]

# split orf names in sample, contig_nr, and orf_nr
orf_info = gene_scripts.split_orf_to_index(orf2gene.ORF)

# rename representative

representative_names = orf2gene.Representative.unique()

map_names = pd.Series(
    index=representative_names,
    data=np.arange(1, len(representative_names) + 1, dtype=np.uint),
)


orf_info["GeneNr"] = orf2gene.Representative.map(map_names)


orf_info.to_parquet(snakemake.output.cluster_attribution)


# Save name of representatives
map_names.index.name = "Representative"
map_names.name = "GeneNr"
map_names.to_csv(snakemake.output.rep2genenr, sep="\t")
import os, sys
import logging, traceback


logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


# beginning of script

import datetime
import shutil
import os


timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%X")


def get_read_stats(fraction, params_in):
    "get read stats by running reformat.sh"

    from snakemake.shell import shell

    subfolder = os.path.join(snakemake.params.folder, fraction)
    tmp_file = os.path.join(subfolder, "read_stats.tmp")
    shell(
        f" mkdir -p {subfolder} 2>> {snakemake.log[0]} "
        " ; "
        f" reformat.sh {params_in} "
        f" bhist={subfolder}/base_hist.txt "
        f" qhist={subfolder}/quality_by_pos.txt "
        f" lhist={subfolder}/readlength.txt "
        f" gchist={subfolder}/gc_hist.txt "
        " gcbins=auto "
        f" bqhist={subfolder}/boxplot_quality.txt "
        f" threads={snakemake.threads} "
        " overwrite=true "
        f" -Xmx{snakemake.resources.java_mem}G "
        f" 2> >(tee -a {snakemake.log[0]} {tmp_file} ) "
    )
    content = open(tmp_file).read()
    pos = content.find("Input:")
    if pos == -1:
        raise Exception("Didn't find read number in file:\n\n" + content)
    else:
        # the relevant line looks like: "Input:    123 reads   1234 bases"
        n_reads, _, n_bases = content[pos:].split()[1:4]

        os.remove(tmp_file)
    return int(n_reads), int(n_bases)


if len(snakemake.input) >= 2:
    n_reads_pe, n_bases_pe = get_read_stats(
        "pe", "in1={0} in2={1}".format(*snakemake.input)
    )

    n_reads_pe = n_reads_pe / 2

    headers = [
        "Sample",
        "Step",
        "Total_Reads",
        "Total_Bases",
        "Reads_pe",
        "Bases_pe",
        "Reads_se",
        "Bases_se",
        "Timestamp",
    ]

    if os.path.exists(snakemake.params.single_end_file):
        n_reads_se, n_bases_se = get_read_stats(
            "se", "in=" + snakemake.params.single_end_file
        )
    else:
        n_reads_se, n_bases_se = 0, 0

    values = [
        n_reads_pe + n_reads_se,
        n_bases_pe + n_bases_se,
        n_reads_pe,
        n_bases_pe,
        n_reads_se,
        n_bases_se,
    ]
else:
    headers = [
        "Sample",
        "Step",
        "Total_Reads",
        "Total_Bases",
        "Reads",
        "Bases",
        "Timestamp",
    ]
    # single-end input: the totals equal the per-fraction numbers, so repeat the tuple
    values = 2 * get_read_stats("", "in=" + snakemake.input[0])

with open(snakemake.output.read_counts, "w") as f:
    f.write("\t".join(headers) + "\n")
    f.write(
        "\t".join(
            [snakemake.wildcards.sample, snakemake.wildcards.step]
            + [str(v) for v in values]
            + [timestamp]
        )
        + "\n"
    )

shutil.make_archive(snakemake.params.folder, "zip", snakemake.params.folder)
shutil.rmtree(snakemake.params.folder)
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

from utils.fasta import parse_fasta_headers
from utils.utils import gen_names_for_range
from glob import glob
import pandas as pd


fasta_files = glob(f"{snakemake.input[0]}/*{snakemake.params.extension}")

if len(fasta_files) > 0:
    Bin_names = gen_names_for_range(
        N=len(fasta_files), prefix=f"{snakemake.wildcards.sample}_SemiBin_"
    )

    mappings = []

    for bin_name, fasta in zip(Bin_names, fasta_files):
        contigs = parse_fasta_headers(fasta)

        mappings.append(pd.Series(data=bin_name, index=contigs))

    pd.concat(mappings, axis=0).to_csv(
        snakemake.output[0], sep="\t", header=False, index=True
    )

else:
    logging.warning(
        f"No bins found in {snakemake.input[0]} add longest contig as bin to make atlas continue."
    )

    with open(snakemake.output[0], "w") as outf:
        outf.write("{sample}_0\t{sample}_SemiBin_1\n".format(**snakemake.wildcards))
import os, sys
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.DEBUG,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
from utils.utils import gen_names_for_range
from utils.fasta import parse_fasta_headers

bin_dir = os.path.join(snakemake.input[0], "bins")
fasta_extension = snakemake.params.fasta_extension
separator = snakemake.params.separator

cluster_output_path = snakemake.params.output_path

vamb_cluster_file = os.path.join(snakemake.input[0], "clusters.tsv")
output_clusters = snakemake.output.renamed_clusters


big_bins = []

for file in os.listdir(bin_dir):
    bin_name, extension = os.path.splitext(file)

    logging.debug(f"Found file {bin_name} with extension {extension}")

    if extension == fasta_extension:
        big_bins.append(bin_name)


logging.info(
    f"Found {len(big_bins)} bins created by Vamb (above size limit)\n"
    f"E.g. {big_bins[:5]}"
)


logging.info(f"Load vamb cluster file {vamb_cluster_file}")
clusters_contigs = pd.read_table(vamb_cluster_file, header=None)

clusters_contigs.columns = ["OriginalName", "Contig"]


clusters = clusters_contigs.Contig.str.rsplit(separator, n=1, expand=True)
clusters.columns = ["Sample", "Contig"]

clusters["BinID"] = clusters_contigs.OriginalName.str.rsplit(
    separator, n=1, expand=True
)[1]
clusters["OriginalName"] = clusters_contigs.OriginalName

clusters["Large_enough"] = clusters.OriginalName.isin(big_bins)

del clusters_contigs

logging.info(f"Write reformated table to {output_culsters}")
clusters.to_csv(output_culsters, sep="\t", index=False)

clusters = clusters.query("Large_enough")

clusters["SampleBin"] = clusters.Sample + "_vamb_" + clusters.BinID

logging.info(f"Write cluster_attribution for samples")
for sample, cl in clusters.groupby("Sample"):
    sample_output_path = cluster_output_path.format(sample=sample)

    logging.debug(f"Write file {sample_output_path}")
    cl[["Contig", "SampleBin"]].to_csv(
        sample_output_path, sep="\t", index=False, header=False
    )


samples_without_bins = set(snakemake.params.samples).difference(set(clusters.Sample))

if len(samples_without_bins) > 0:
    logging.warning(
        "The following samples did't yield bins, I add longest contig to make the pipline continue:\n"
        + "\n".join(samples_without_bins)
    )

    for sample in samples_without_bins:
        sample_output_path = cluster_output_path.format(sample=sample)
        with open(sample_output_path, "w") as fout:
            fout.write(f"{sample}_0\t{sample}_vamb_1\n")
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


import pandas as pd
from utils.gene_scripts import geneNr_to_string


# Start


map_genenr = pd.read_csv(snakemake.input.rep2genenr, index_col=0, sep="\t").squeeze()


# from gene Nr to gene name
rep2gene = geneNr_to_string(map_genenr)

logging.info(
    f"Collect and rename representative genes according to:\n {rep2gene.head()}"
)

assert rep2gene.shape[0] > 0


with open(snakemake.output[0], "w") as fout:
    with open(snakemake.input.fasta, "r") as fin:
        for line in fin:
            if line[0] == ">":
                gene_name = line[1:].strip().split(" ")[0]

                gene_id = rep2gene.loc[gene_name]

                fout.write(f">{gene_id} {gene_name}\n")

            else:
                fout.write(line)
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception


# start


from snakemake.io import glob_wildcards

from atlas import utils
import pandas as pd


# load mapping file  with old names
# tsv with format
# path/to/rep.fasta path/to/bin.fasta

mapping = pd.read_csv(
    snakemake.input.mapping_file, sep="\t", usecols=[0, 1], header=None
)
mapping.columns = ["Rep_path", "Bin_path"]

assert (
    mapping.Bin_path.is_unique
), "The second column of {snakemake.input.mapping_file} should be unique"

# go from path to id
mapping[["Representative", "Bin"]] = mapping.applymap(
    lambda x: os.path.basename(x).replace(".fasta", "")
)
mapping.set_index("Bin", inplace=True)


# standardize names of representatives
# MAG001 ....
representatives = mapping.Representative.unique()
old2new_name = dict(
    zip(representatives, utils.gen_names_for_range(len(representatives), prefix="MAG"))
)
mapping["MAG"] = mapping.Representative.map(old2new_name)


# write cluster attribution
mapping[["MAG", "Representative"]].to_csv(
    snakemake.output.mapfile_allbins2mag, sep="\t", header=True
)

# write out old2new ids
old2new = mapping.loc[representatives, "MAG"]
old2new.index.name = "Representative"
old2new.to_csv(snakemake.output.mapfile_old2mag, sep="\t", header=True)

#### Write genomes and contig to genome mapping file

output_dir = snakemake.output.dir
mapfile_contigs = snakemake.output.mapfile_contigs
rename_contigs = snakemake.params.rename_contigs


os.makedirs(output_dir)

with open(mapfile_contigs, "w") as out_contigs:
    for rep, row in mapping.loc[representatives].iterrows():
        fasta_in = row.Rep_path
        new_name = row.MAG

        fasta_out = os.path.join(output_dir, f"{row.MAG}.fasta")

        # write names of contigs in mapping file
        with open(fasta_in) as ffi, open(fasta_out, "w") as ffo:
            Nseq = 0
            for line in ffi:
                # if header line
                if line[0] == ">":
                    Nseq += 1

                    if rename_contigs:
                        new_header = f"{row.MAG}_{Nseq}"
                    else:
                        new_header = line[1:].strip().split()[0]

                    # write to contig to mapping file
                    out_contigs.write(f"{new_header}\t{row.MAG}\n")
                    # write to fasta file
                    ffo.write(f">{new_header}\n")
                else:
                    ffo.write(line)


def rename_quality(quality_in, quality_out, old2new_name):
    Q = pd.read_csv(quality_in, index_col=0, sep="\t")

    Q = Q.loc[old2new_name.keys()].rename(index=old2new_name)

    Q.to_csv(quality_out, sep="\t")


rename_quality(
    quality_in=snakemake.input.genome_info,
    quality_out=snakemake.output.genome_info,
    old2new_name=old2new_name,
)
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

# start
import ete3

T = ete3.Tree(snakemake.input.tree, quoted_node_names=True, format=1)

try:
    T.unroot()
    if len(T) > 2:
        T.set_outgroup(T.get_midpoint_outgroup())

except Exception as e:
    logging.error("Failed to root tree, keep unrooted. Reason was:\n\n" + str(e))


T.write(outfile=snakemake.output.tree)
import sys, os
import logging, traceback

logging.basicConfig(
    filename=snakemake.log[0],
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def handle_exception(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    logging.error(
        "".join(
            [
                "Uncaught exception: ",
                *traceback.format_exception(exc_type, exc_value, exc_traceback),
            ]
        )
    )


# Install exception handler
sys.excepthook = handle_exception

## start


from utils import fasta

# split the input fasta into subsets of at most `subset_size` sequences,
# simplifying the sequence headers
fasta.split(
    snakemake.input[0],
    snakemake.params.subset_size,
    snakemake.output[0],
    simplify_headers=True,
)
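
`utils.fasta.split` belongs to the pipeline's own helper package and is not shown on this page. A minimal sketch of the kind of splitting it performs (chunking a fasta into numbered subsets of at most `subset_size` records, optionally simplifying headers) might look like the following; this is an illustration, not the actual helper.

    import os

    def split_fasta(fasta_in, subset_size, out_dir, simplify_headers=True):
        """Write sequences from fasta_in into numbered subset files in out_dir."""
        os.makedirs(out_dir, exist_ok=True)
        n_seq, n_file, out = 0, 0, None
        with open(fasta_in) as fin:
            for line in fin:
                if line.startswith(">"):
                    # open a new subset file every `subset_size` sequences
                    if n_seq % subset_size == 0:
                        if out is not None:
                            out.close()
                        n_file += 1
                        out = open(os.path.join(out_dir, f"subset_{n_file}.fasta"), "w")
                    n_seq += 1
                    if simplify_headers:
                        line = line.split()[0] + "\n"   # keep only the sequence ID
                out.write(line)
        if out is not None:
            out.close()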
__author__ = "Christopher Schröder, Patrik Smeds"
__copyright__ = "Copyright 2020, Christopher Schröder, Patrik Smeds"
__email__ = "[email protected], [email protected]"
__license__ = "MIT"

from os import path
from snakemake.shell import shell

log = snakemake.log_fmt_shell(stdout=True, stderr=True)

# Check inputs/arguments.
if len(snakemake.input) == 0:
    raise ValueError("A reference genome has to be provided.")
elif len(snakemake.input) > 1:
    raise ValueError("Please provide exactly one reference genome as input.")

valid_suffixes = {".0123", ".amb", ".ann", ".bwt.2bit.64", ".pac"}


def get_valid_suffix(filename):
    for suffix in valid_suffixes:
        if filename.endswith(suffix):
            return suffix


prefixes = set()
for s in snakemake.output:
    suffix = get_valid_suffix(s)
    if suffix is None:
        raise ValueError(f"{s} cannot be generated by bwa-mem2 index (invalid suffix).")
    prefixes.add(s[: -len(suffix)])

if len(prefixes) != 1:
    raise ValueError("Output files must share common prefix up to their file endings.")
(prefix,) = prefixes

shell("bwa-mem2 index -p {prefix} {snakemake.input[0]} {log}")
__author__ = "Christopher Schröder, Johannes Köster, Julian de Ruiter"
__copyright__ = (
    "Copyright 2020, Christopher Schröder, Johannes Köster and Julian de Ruiter"
)
__email__ = "[email protected] [email protected], [email protected]"
__license__ = "MIT"


import tempfile
from os import path
from snakemake.shell import shell
from snakemake_wrapper_utils.java import get_java_opts
from snakemake_wrapper_utils.samtools import get_samtools_opts


# Extract arguments.
extra = snakemake.params.get("extra", "")
log = snakemake.log_fmt_shell(stdout=False, stderr=True)
sort = snakemake.params.get("sort", "none")
sort_order = snakemake.params.get("sort_order", "coordinate")
sort_extra = snakemake.params.get("sort_extra", "")
samtools_opts = get_samtools_opts(snakemake)
java_opts = get_java_opts(snakemake)


# Derive the index prefix from the index files provided as `idx` in the rule input.
index = snakemake.input.idx
if isinstance(index, str):
    index = path.splitext(index)[0]
else:
    index = path.splitext(index[0])[0]


# Check inputs/arguments.
if not isinstance(snakemake.input.reads, str) and len(snakemake.input.reads) not in {
    1,
    2,
}:
    raise ValueError("input must have 1 (single-end) or 2 (paired-end) elements")

if sort_order not in {"coordinate", "queryname"}:
    raise ValueError(f"Unexpected value for sort_order ({sort_order})")

# Determine which pipe command to use for converting to bam or sorting.
if sort == "none":
    # Simply convert to bam using samtools view.
    pipe_cmd = "samtools view {samtools_opts}"

elif sort == "samtools":
    # Sort alignments using samtools sort.
    pipe_cmd = "samtools sort {samtools_opts} {sort_extra} -T {tmpdir}"

    # Add name flag if needed.
    if sort_order == "queryname":
        sort_extra += " -n"

elif sort == "picard":
    # Sort alignments using picard SortSam.
    pipe_cmd = "picard SortSam {java_opts} {sort_extra} --INPUT /dev/stdin --TMP_DIR {tmpdir} --SORT_ORDER {sort_order} --OUTPUT {snakemake.output[0]}"

else:
    raise ValueError(f"Unexpected value for params.sort ({sort})")


with tempfile.TemporaryDirectory() as tmpdir:
    shell(
        "(bwa-mem2 mem"
        " -t {snakemake.threads}"
        " {extra}"
        " {index}"
        " {snakemake.input.reads}"
        " | " + pipe_cmd + ") {log}"
    )
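
A matching mapping rule might look like the sketch below; sample names and paths are hypothetical, and `params.sort` selects one of the three pipe commands above.

    rule bwa_mem2_mem:
        input:
            reads=["reads/sample_R1.fastq.gz", "reads/sample_R2.fastq.gz"],
            idx=multiext("genomes/all_bins",
                         ".0123", ".amb", ".ann", ".bwt.2bit.64", ".pac"),
        output:
            "mapped/sample.bam",
        params:
            sort="samtools",            # "none", "samtools" or "picard"
            sort_order="coordinate",    # or "queryname"
            extra="",
        threads: 8
        log:
            "logs/bwa_mem2_mem/sample.log",
        wrapper:
            "v1.19.0/bio/bwa-mem2/mem"  # version tag is illustrative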
__author__ = "Tom Poorten"
__copyright__ = "Copyright 2017, Tom Poorten"
__email__ = "[email protected]"
__license__ = "MIT"


from snakemake.shell import shell
from snakemake_wrapper_utils.samtools import get_samtools_opts, infer_out_format


samtools_opts = get_samtools_opts(snakemake, parse_output=False)
extra = snakemake.params.get("extra", "")
log = snakemake.log_fmt_shell(stdout=False, stderr=True)
sort = snakemake.params.get("sorting", "none")
sort_extra = snakemake.params.get("sort_extra", "")

out_ext = infer_out_format(snakemake.output[0])

pipe_cmd = ""
if out_ext != "PAF":
    # Add option for SAM output
    extra += " -a"

    # Determine which pipe command to use for converting to bam or sorting.
    if sort == "none":

        if out_ext != "SAM":
            # Simply convert to output format using samtools view.
            pipe_cmd = f"| samtools view -h {samtools_opts}"

    elif sort in ["coordinate", "queryname"]:

        # Add name flag if needed.
        if sort == "queryname":
            sort_extra += " -n"

        # Sort alignments.
        pipe_cmd = f"| samtools sort {sort_extra} {samtools_opts}"

    else:
        raise ValueError(f"Unexpected value for params.sort: {sort}")


shell(
    "(minimap2"
    " -t {snakemake.threads}"
    " {extra} "
    " {snakemake.input.target}"
    " {snakemake.input.query}"
    " {pipe_cmd}"
    " > {snakemake.output[0]}"
    ") {log}"
)
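
Used from a rule, this aligner wrapper could be driven as sketched below; the preset, paths and wrapper version tag are illustrative, and `params.sorting` maps onto the branches above.

    rule minimap2_align:
        input:
            target="genomes/all_bins.mmi",      # index (or fasta) to map against
            query="reads/sample.fastq.gz",
        output:
            "mapped/sample.minimap2.bam",       # .sam or .paf outputs also work
        params:
            extra="-x sr",                      # preset is illustrative
            sorting="coordinate",               # "none", "coordinate" or "queryname"
            sort_extra="",
        threads: 8
        log:
            "logs/minimap2/sample.log",
        wrapper:
            "v1.19.0/bio/minimap2/aligner"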
__author__ = "Tom Poorten"
__copyright__ = "Copyright 2017, Tom Poorten"
__email__ = "[email protected]"
__license__ = "MIT"

from snakemake.shell import shell

extra = snakemake.params.get("extra", "")
log = snakemake.log_fmt_shell(stdout=True, stderr=True)

shell(
    "(minimap2 -t {snakemake.threads} {extra} "
    "-d {snakemake.output[0]} {snakemake.input.target}) {log}"
)
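
The corresponding indexing rule is short; paths and the wrapper version tag are illustrative.

    rule minimap2_index:
        input:
            target="genomes/all_bins.fasta",
        output:
            "genomes/all_bins.mmi",
        params:
            extra="",
        threads: 4
        log:
            "logs/minimap2_index.log",
        wrapper:
            "v1.19.0/bio/minimap2/index"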
__author__ = "Julian de Ruiter"
__copyright__ = "Copyright 2017, Julian de Ruiter"
__email__ = "[email protected]"
__license__ = "MIT"


from snakemake.shell import shell
from snakemake_wrapper_utils.samtools import get_samtools_opts


bed = snakemake.input.get("bed", "")
if bed:
    bed = "-t " + bed

samtools_opts = get_samtools_opts(
    snakemake, parse_write_index=False, parse_output=False, parse_output_format=False
)


extra = snakemake.params.get("extra", "")
region = snakemake.params.get("region", "")
log = snakemake.log_fmt_shell(stdout=False, stderr=True)


shell(
    "samtools stats {samtools_opts} {extra} {snakemake.input.bam} {bed} {region} > {snakemake.output} {log}"
)
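
A rule feeding this wrapper only needs a BAM input named `bam` (an optional `bed` input and a `region` param restrict the statistics); everything else below is hypothetical.

    rule samtools_stats:
        input:
            bam="mapped/sample.bam",
        output:
            "stats/sample.samtools_stats.txt",
        params:
            extra="",
        log:
            "logs/samtools_stats/sample.log",
        wrapper:
            "v1.19.0/bio/samtools/stats"       # version tag is illustrative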
__author__ = "Julian de Ruiter"
__copyright__ = "Copyright 2017, Julian de Ruiter"
__email__ = "[email protected]"
__license__ = "MIT"


from os import path

from snakemake.shell import shell


extra = snakemake.params.get("extra", "")
# Set this to True if MultiQC should be given the input files directly
# instead of the directories that contain them
use_input_files_only = snakemake.params.get("use_input_files_only", False)

if not use_input_files_only:
    input_data = set(path.dirname(fp) for fp in snakemake.input)
else:
    input_data = set(snakemake.input)

output_dir = path.dirname(snakemake.output[0])
output_name = path.basename(snakemake.output[0])
log = snakemake.log_fmt_shell(stdout=True, stderr=True)

shell(
    "multiqc"
    " {extra}"
    " --force"
    " -o {output_dir}"
    " -n {output_name}"
    " {input_data}"
    " {log}"
)
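
Finally, a report-aggregation rule could call the MultiQC wrapper as sketched below; inputs and paths are hypothetical, and by default the wrapper passes the directories of the inputs to MultiQC (see `use_input_files_only` above).

    rule multiqc:
        input:
            expand("stats/{sample}.samtools_stats.txt", sample=["sample1", "sample2"]),
        output:
            "reports/multiqc.html",
        params:
            extra="",
            use_input_files_only=False,
        log:
            "logs/multiqc.log",
        wrapper:
            "v1.19.0/bio/multiqc"              # version tag is illustrative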

Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/HamzaMbareche/MAGs_IBD
Name: mags_ibd
Version: v2.16.2
Downloaded: 0
Copyright: Public Domain
License: BSD 3-Clause "New" or "Revised" License
