
StaG Metagenomic Workflow Collaboration (mwc)

StaG mwc logo

The StaG Metagenomic Workflow Collaboration (mwc) project provides a metagenomics analysis workflow suitable for microbiome research and general metagenomic analyses.

Please visit https://stag-mwc.readthedocs.io for the full documentation.

Usage

Step 0: Install conda and Snakemake

Conda and Snakemake are required to run StaG-mwc. Most users will want to install Miniconda and then install Snakemake into their base environment. When running StaG-mwc with the --use-conda or --use-singularity flag, all other dependencies are managed automatically: conda installs the required versions of all tools, and the Singularity images used by the workflow already contain all required dependencies, so there is no need to combine the two flags.
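
For example, a common way to install Snakemake into the base environment of an existing Miniconda installation (the exact channel configuration may differ on your system) is:

conda install -n base -c conda-forge -c bioconda snakemake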

Step 1: Clone workflow

To use StaG-mwc, you need a local copy of the workflow repository. Start by making a clone of the repository:

git clone git@github.com:ctmrbio/stag-mwc
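
If you do not have SSH keys configured for GitHub, cloning over HTTPS should work just as well:

git clone https://github.com/ctmrbio/stag-mwc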

If you use StaG-mwc in a publication, please credit the authors by citing either the URL of this repository or the project's DOI. Also, don't forget to cite the publications of the other tools used in your workflow.

Step 2: Configure workflow

Configure the workflow according to your needs by editing the file config/config.yaml. The most common changes are setting the paths to the input and output folders, and selecting which steps of the workflow to include when it is run.
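
Individual settings can also be overridden on the command line with Snakemake's --config flag, for example (outdir here is only an illustrative key name; use the actual keys defined in config/config.yaml):

snakemake --use-conda --cores 4 --config outdir=my_results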

Step 3: Execute workflow

Test your configuration by performing a dry-run via

snakemake --use-conda -n

Execute the workflow locally via

snakemake --use-conda --cores N

This will run the workflow locally using N cores. It is also possible to run it in a cluster environment by using one of the available profiles, or by creating your own. For example, to run on CTMR's Gandalf cluster:

snakemake --profile profiles/ctmr_gandalf

Make sure you specify the Slurm account and partition in profiles/ctmr_gandalf/config.yaml. Refer to the official Snakemake documentation for further details on how to run Snakemake workflows on other types of cluster resources.
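
It can be useful to combine the cluster profile with a dry run first, to verify that the profile and cluster settings are picked up correctly before any jobs are submitted:

snakemake --profile profiles/ctmr_gandalf --dry-run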

Note that in all examples above, --use-conda can be replaced with --use-singularity to run in Singularity containers instead of using a locally installed conda. Read more about it under the Running section in the docs.
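
For example, the local execution command above then becomes:

snakemake --use-singularity --cores N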

Testing

A very basic continuous integration test is currently in place. It merely validates the workflow syntax by letting Snakemake build the dependency graph with all outputs activated. Suggestions for how to improve the automated testing of StaG-mwc are very welcome!

Contributing

Refer to the contributing guidelines in CONTRIBUTING.md for instructions on how to contribute to StaG-mwc.

If you intend to modify or further develop this workflow, you are welcome to fork this repository. Please consider sharing potential improvements via a pull request.

Citing

If you find StaG-mwc useful in your research, please cite the Zenodo DOI: https://zenodo.org/badge/latestdoi/125840716

Logo attribution

Animal vector created by Patrickss - Freepik.com

Code Snippets

shell:
    """
    TEMPDIR="{resources.tmpdir}/{wildcards.sample}"
    mkdir -pv $TEMPDIR >> {log.stdout}
    mkdir -pv {params.outdir} >> {log.stdout}
    cat {input.read1} {input.read2} > $TEMPDIR/{wildcards.sample}_concat.fq.gz

    humann \
        --input $TEMPDIR/{wildcards.sample}_concat.fq.gz \
        --output $TEMPDIR \
        --nucleotide-database {params.nucleotide_db} \
        --protein-database {params.protein_db} \
        --output-basename {wildcards.sample} \
        --threads {threads} \
        --taxonomic-profile {input.taxonomic_profile} \
        {params.extra} \
        >> {log.stdout} \
        2> {log.stderr}

    cp $TEMPDIR/{wildcards.sample}*.tsv {params.outdir}
    cp $TEMPDIR/{wildcards.sample}_humann_temp/{wildcards.sample}.log {log.log}
    rm -rfv $TEMPDIR >> {log.stdout}
    """
shell:
    """
    humann_renorm_table \
        --input {input.genefamilies} \
        --output {output.genefamilies} \
        --units {params.method} \
        --mode {params.mode} \
        > {log.stdout} \
        2> {log.stderr}

    humann_renorm_table \
        --input {input.pathabundance} \
        --output {output.pathabundance} \
        --units {params.method} \
        --mode {params.mode} \
        >> {log.stdout} \
        2>> {log.stderr}
    """
shell:
    """
    humann_join_tables \
        --input {params.output_dir} \
        --output {output.genefamilies} \
        --file_name {params.genefamilies} \
        > {log.stdout} \
        2> {log.stderr}

    humann_join_tables \
        --input {params.output_dir} \
        --output {output.pathcoverage} \
        --file_name {params.pathcoverage} \
        >> {log.stdout} \
        2>> {log.stderr}

    humann_join_tables \
        --input {params.output_dir} \
        --output {output.pathabundance} \
        --file_name {params.pathabundance} \
        >> {log.stdout} \
        2>> {log.stderr}
    """
shell:
    """
    bbmap.sh \
        threads={threads} \
        minid={params.min_id} \
        path={params.db_path} \
        in1={input.read1} \
        in2={input.read2} \
        out={output.sam} \
        covstats={output.covstats} \
        rpkm={output.rpkm} \
        bamscript={output.bamscript} \
        {params.extra} \
        > {log.stdout} \
        2> {log.stderr}

    sed -i 's/_sorted//g' {output.bamscript}

    ./{output.bamscript} 2>> {log.stderr} >> {log.stdout}
    """
shell:
    """
    workflow/scripts/make_count_table.py \
        --annotation-file {params.annotations} \
        --columns {params.columns} \
        --outdir {params.outdir} \
        {input} \
        2> {log}
    """
shell:
    """
    featureCounts \
        -a {params.annotations} \
        -o {output.counts} \
        -t {params.feature_type} \
        -g {params.attribute_type} \
        -T {threads} \
        {params.extra} \
        {input.bams} \
        > {log} \
        2>> {log}
    cut \
        -f1,7- \
        {output.counts}  \
        | sed '1d' \
        | sed 's|\t\w\+/bbmap/{params.dbname}/|\t|g' \
        > {output.counts_table}
    """
wrapper:
    "0.23.1/bio/bowtie2/align"
shell:
    """
    pileup.sh \
        in={input.bam} \
        out={output.covstats} \
        rpkm={output.rpkm} \
        2> {log}
    """
shell:
    """
    workflow/scripts/make_count_table.py \
        --annotation-file {params.annotations} \
        --columns {params.columns} \
        --outdir {params.outdir} \
        {input} \
        2> {log}
    """
shell:
    """
    featureCounts \
        -a {params.annotations} \
        -o {output.counts} \
        -t {params.feature_type} \
        -g {params.attribute_type} \
        -T {threads} \
        {params.extra} \
        {input.bams} \
        > {log} \
        2>> {log} \
    && \
    cut \
        -f1,7- \
        {output.counts}  \
        | sed '1d' \
        | sed 's|\t\w\+/bowtie2/{params.dbname}/|\t|g' \
        > {output.counts_table}
    """
shell:
    """
    multiqc {OUTDIR} \
        --filename {output.report} \
        --force \
        2> {log}
    """
shell:
    """
    bbcountunique.sh \
      in={input} \
      out={output.txt} \
      interval={params.interval} \
      > {log.stdout} \
      2> {log.stderr}

    workflow/scripts/plot_bbcountunique.py \
      {output.txt} \
      {output.pdf} \
      >> {log.stdout} \
      2>> {log.stderr}
    """
shell:
    """
    sketch.sh \
        in={input} \
        out={output} \
        name0={wildcards.sample} \
        2> {log}
    """
shell:
    """
    comparesketch.sh \
        format=3 \
        out={output} \
        alltoall \
        {input} \
        2> {log}
    """
shell:
    """
    workflow/scripts/plot_sketch_comparison_heatmap.py \
        --outfile {output.heatmap} \
        --clustered {output.clustered} \
        {input} \
        > {log.stdout} \
        2> {log.stderr}
    """
shell:
    """
    kraken2 \
        --db {params.db} \
        --threads {threads} \
        --output {output.kraken} \
        --classified-out {params.classified} \
        --unclassified-out {params.unclassified} \
        --report  {output.kreport} \
        --paired \
        --confidence {params.confidence} \
        {params.extra} \
        {input.read1} {input.read2} \
        2> {log.stderr}
    pigz \
        --processes {threads} \
        --verbose \
        --force \
        {params.fq_to_compress} \
        2>> {log.stderr}
    """
shell:
    """
    workflow/scripts/plot_proportion_kraken2.py \
        {input} \
        --histogram {output.histogram} \
        --barplot {output.barplot} \
        --table {output.txt} \
        2>&1 > {log}
    """
shell:
    """
    bowtie2 \
        --threads {threads} \
        -x {params.db_path} \
        -1 {input.read1} \
        -2 {input.read2} \
        -S {output.sam} \
        2> {log.stderr}
    """
shell:
    """
    samtools view \
        -b \
        --threads {threads} \
        {input.sam} \
        -o {output.bam} \
        2> {log.stderr}
    """
shell:
    """
    samtools view \
        -b \
        -f 13 \
        -F 256 \
        --threads {threads} \
        {input.bam2} \
        -o {output.unmapped} \
        2> {log.stderr}
    """
shell: 
    """
    samtools sort \
        -n \
        -m 5G \
        --threads {threads} \
        {input.pairs} \
        -o {output.sorted} \
        2> {log.stderr}
    """
shell: 
    """
    samtools fastq \
        --threads {threads} \
        -1 {output.read1} \
        -2 {output.read2} \
        -0 /dev/null \
        -s /dev/null \
        -n \
        {input.sorted_pairs} \
        2> {log.stderr}
    """
shell:
    """
    ln -sv $(readlink -f {input.read1}) {output.read1} >> {log.stderr}
    ln -sv $(readlink -f {input.read2}) {output.read2} >> {log.stderr}
    """
shell:
    """
    workflow/scripts/preprocessing_summary.py \
        {params.fastp_arg} \
        {params.kraken2_arg} \
        {params.bowtie2_arg} \
        --output-table {output.table} \
        > {log.stdout}
    """
shell:
    """
    fastp \
        --in1 {input.read1} \
        --in2 {input.read2} \
        --out1 {output.read1} \
        --out2 {output.read2} \
        --json {output.json} \
        --html {output.html} \
        --thread {threads} \
        {params.extra} \
        > {log.stdout} \
        2> {log.stderr}
    """
shell:
    """
    ln -sv $(readlink -f {input.read1}) {output.read1} >> {log.stderr}
    ln -sv $(readlink -f {input.read2}) {output.read2} >> {log.stderr}
    """
shell:
    """
    kaiju \
        -z {threads} \
        -t {params.nodes} \
        -f {params.db} \
        -i {input.read1} \
        -j {input.read2} \
        -o {output.kaiju} > {log}
    """
shell:
    """
    kaiju2krona \
        -t {params.nodes} \
        -n {params.names} \
        -i {input.kaiju} \
        -o {output.krona} \
        -u
    """
shell:
    """
    ktImportText \
        -o {output.krona_html} \
        {input}
    """
shell:
    """
    kaiju2table \
        -t {params.nodes} \
        -n {params.names} \
        -r {wildcards.level} \
        -l superkingdom,phylum,class,order,family,genus,species \
        -o {output} \
        {input.kaiju} \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/join_tables.py \
        --feature-column {params.feature_column} \
        --value-column {params.value_column} \
        --outfile {output} \
        {input} \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/area_plot.py \
        --table {input} \
        --output {output} \
        --mode kaiju \
        2>&1 > {log}
    """
shell:
    """
    kraken2 \
        --db {params.db} \
        --confidence {params.confidence} \
        --minimum-hit-groups {params.minimum_hit_groups} \
        --threads {threads} \
        --output {output.kraken} \
        --report {output.kreport} \
        --use-names \
        --paired \
        {input.read1} {input.read2} \
        {params.extra} \
        2> {log}
    """
shell:
    """
    workflow/scripts/KrakenTools/kreport2mpa.py \
        --report-file {input.kreport} \
        --output {output.txt} \
        --display-header \
        2>&1 > {log}
    sed --in-place 's|{input.kreport}|taxon_name\treads|g' {output.txt}
    """
shell:
    """
    workflow/scripts/join_tables.py \
        --outfile {output.table} \
        --value-column {params.value_column} \
        --feature-column '{params.feature_column}' \
        {input.txt} \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/area_plot.py \
        --table {input} \
        --output {output} \
        --mode kraken2 \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/KrakenTools/combine_kreports.py \
        --output {output} \
        --report-files {input.kreports} \
        2>> {log} \
        >> {log}
    """
shell:
    """
    workflow/scripts/KrakenTools/kreport2krona.py \
        --report-file {input.kreport} \
        --output {output} \
        2> {log}
    """
shell:
    """
    ktImportText \
        -o {output.krona_html} \
        {input}
    """
shell:
    """
    est_abundance.py \
        --input {input.kreport} \
        --kmer_distr {params.kmer_distrib} \
        --output {output.bracken} \
        --out-report {output.bracken_kreport} \
        --level S \
        --thresh {params.thresh} \
        2>&1 > {log}
    """
shell:
    """
    est_abundance.py \
        --input {input.kreport} \
        --kmer_distr {params.kmer_distrib} \
        --output {output.bracken} \
        --level {wildcards.level} \
        --thresh {params.thresh} \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/KrakenTools/kreport2mpa.py \
        --report-file {input.kreport} \
        --output {output.txt} \
        --display-header \
        2>&1 > {log}
    sed --in-place 's|{input.kreport}|taxon_name\treads|g' {output.txt}
    """
shell:
    """
    workflow/scripts/join_tables.py \
        --outfile {output.table} \
        --value-column {params.value_column} \
        --feature-column {params.feature_column} \
        {input.txt} \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/area_plot.py \
        --table {input} \
        --output {output} \
        --mode kraken2 \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/join_tables.py \
        --outfile {output.table} \
        --value-column {params.value_column} \
        --feature-column {params.feature_column} \
        {input.bracken} \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/KrakenTools/kreport2krona.py \
        --report-file {input.bracken_kreport} \
        --output {output.bracken_krona} \
        2>&1 > {log}
    """
shell:
    """
    ktImportText \
        -o {output.krona_html} \
        {input}
    """
shell:
    """
    {params.filter_bracken} \
        --input-file {input.bracken} \
        --output {output.filtered} \
        {params.include} \
        {params.exclude} \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/join_tables.py \
        --outfile {output.table} \
        --value-column {params.value_column} \
        --feature-column {params.feature_column} \
        {input.bracken} \
        2>&1 > {log}
    """
shell:
    """
    fuse.sh \
        in1={input.read1} \
        in2={input.read2} \
        out={output.fasta} \
        pad=1 \
        fusepairs=t \
        2> {log}
    """
shell:
    """
    krakenuniq \
        --db {params.db} \
        --threads {threads} \
        --output {output.kraken} \
        --report-file {output.kreport} \
        --preload-size {params.preload_size} \
        {input.fasta} \
        {params.extra} \
        2> {log}
    """
shell:
    """
    workflow/scripts/join_tables.py \
        --feature-column rank,taxName \
        --value-column taxReads \
        --outfile {output.combined} \
        --skiplines 2 \
        {input.kreports} \
        2> {log}
    """
shell:
    """
    workflow/scripts/KrakenTools/kreport2mpa.py \
        --report-file {input.kreport} \
        --output {output.txt} \
        --display-header \
        > {log.stdout} \
        2> {log.stderr}
    """
shell:
    """
    workflow/scripts/join_tables.py \
        --outfile {output.table} \
        --value-column {params.value_column} \
        --feature-column '{params.feature_column}' \
        {input.txt} \
        > {log.stdout} \
        2> {log.stderr}
    """
shell:
    """
    awk -v OFS='\\t' '{{
      gsub("\\\\|","\\t",$1);
      print $2,$1;
      }}' {input.kreport} \
      > {output} \
      2> {log.stderr}
    """
shell:
    """
    ktImportText \
        -o {output.krona_html} \
        {input}
    """
shell:
    """
    metaphlan \
        --input_type fastq \
        --nproc 10 \
        --sample_id {wildcards.sample} \
        --samout {output.sam_out} \
        --bowtie2out {output.bt2_out} \
        --bowtie2db {params.bt2_db_dir} \
        --index {params.bt2_index} \
        {input.read1},{input.read2} \
        {output.mpa_out} \
        {params.extra} \
        > {log.stdout} \
        2> {log.stderr}
    """
shell:
    """
    set +o pipefail  # Small samples can produce empty output files failing the pipeline
    sed '/#/d' {input.mpa_out} \
        | grep -E "s__|unclassified" \
        | cut -f1,3 \
        | awk '{{print $2,"\t",$1}}' \
        | sed 's/|\w__/\t/g' \
        | sed 's/k__//' \
        > {output.krona} \
        2> {log}
    """
shell:
    """
    merge_metaphlan_tables.py {input} > {output.txt} 2> {log}
    sed --in-place 's/\.metaphlan//g' {output.txt} 
    """
shell:
    """
    workflow/scripts/area_plot.py \
        --table {input} \
        --output {output} \
        --mode metaphlan4 \
        2>&1 > {log}
    """
shell:
    """
    workflow/scripts/plot_metaphlan_heatmap.py \
        --outfile-prefix {params.outfile_prefix} \
        --level {wildcards.level} \
        --topN {wildcards.topN} \
        --pseudocount {params.pseudocount} \
        --colormap {params.colormap} \
        --method {params.method} \
        --metric {params.metric} \
        --force \
        {input} \
        2> {log}
    """
shell:
    """
    ktImportText \
        -o {output.html_samples} \
        {input} \
        > {log}

    ktImportText \
        -o {output.html_all} \
        -c \
        {input} \
        >> {log}
    """
shell:
    """
    set +o pipefail
    sed '/#.*/d' {input.mpa_combined} | cut -f 1- | head -n1 | tee {output.species} {output.genus} {output.family} {output.order} > /dev/null

    sed '/#.*/d' {input.mpa_combined} | cut -f 1- | grep s__ | sed 's/^.*s__/s__/g' >> {output.species}
    sed '/#.*/d' {input.mpa_combined} | cut -f 1- | grep g__ | sed 's/^.*s__.*//g' | grep g__ | sed 's/^.*g__/g__/g' >> {output.genus}
    sed '/#.*/d' {input.mpa_combined} | cut -f 1- | grep f__ | sed 's/^.*g__.*//g' | grep f__ | sed 's/^.*f__/f__/g' >> {output.family}
    sed '/#.*/d' {input.mpa_combined} | cut -f 1- | grep o__ | sed 's/^.*f__.*//g' | grep o__ | sed 's/^.*o__/o__/g' >> {output.order}
    """
shell:
    """
    sample2markers.py \
         -i {input.sam} \
         -o {params.output_dir} \
         -n 8 \
         > {log.stdout} \
         2> {log.stderr}
    """
shell:
    """
    strainphlan \
         -s {input.consensus_markers} \
         --print_clades_only \
         -d {params.database} \
         -o {params.out_dir} \
         -n {threads} \
         > {log.stdout} \
         2> {log.stderr}

    cd {params.out_dir} && ln -s ../logs/strainphlan/available_clades.txt
    """
shell:
    """
    extract_markers.py \
         -c {params.clade} \
         -o {params.out_dir} \
         -d {params.database} \
         > {log.stdout} \
         2> {log.stderr}
    """
shell:
    """
    echo "please compare your clade_of_interest to list of available clades in available_clades.txt" > {log.stderr}

    strainphlan \
         -s {input.consensus_markers} \
         -m {input.reference_markers} \
         {params.extra} \
         -d {params.database} \
         -o {params.out_dir} \
         -n {threads} \
         -c {params.clade} \
         --phylophlan_mode accurate \
         --mutation_rates \
         > {log.stdout} \
         2>> {log.stderr}
    """
from sys import argv, exit
import argparse
import warnings

from matplotlib import rcParams
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

"""
Generates a pretty areaplot from a collapsed feature table.
"""

__author__ = 'JW Debelius'
__date__ = '2020-02'
__version__ = "0.2"

# Sets up the matplotlib parameters so that saved figures can be edited in
# Illustrator if manual adjustment is required. Because it just makes life
# better
rcParams['pdf.fonttype'] = 42
rcParams['ps.fonttype'] = 42

# Sets up the order of colors to be used in joint plots.
colors_order = ['Reds', 'Blues', 'Greens', 'Purples', "Oranges", 'Greys']

over9 = {'Paired', 'Paired_r', 'Set3', 'Set3_r'}
over8 = over9 | {'Set1', "Pastel1"}

mode_dict = {
    'kaiju': {
        'tax_delim': ';',
        'multi_level': True,
        'tax_col': 'taxon_name',
        'table_drop': [],
        'skip_rows': 0,
    },
    'kraken2': {
        'tax_delim': '|',
        'multi_level': True,
        'tax_col': 'taxon_name',
        'table_drop': [],
        'skip_rows': 0,
    },
    'metaphlan': {
        'tax_delim': '|',
        'multi_level': True,
        'tax_col': 'clade_name',
        'table_drop': ['NCBI_tax_id'],
        'skip_rows': 1,
    },
    'metaphlan4': {
        'tax_delim': '|',
        'multi_level': True,
        'tax_col': 'clade_name',
        'table_drop': [],
        'skip_rows': 1,
    },
    'marker': {
        'tax_delim': ';',
        'multi_level': False,
        'tax_col': 'taxonomy',
        'table_drop': ['sequence'],
        'skip_rows': 0,
    },
}


def extract_label_array(table, tax_col, tax_delim='|'):
    """
    Converts delimited taxonomy strings into a working table

    Parameters
    ----------
    table : DataFrame
        A DataFrame with observation on the rows (biom-style table) with 
        `tax_col` as one of its columns.
    tax_col : str
        The column in `table` containing the taxonomy information
    tax_delim: str, optional
        The delimiter between taxonomic groups

    Returns
    -------
    DataFrame
        The taxonomic strings parsed into n levels
    """
    def f_(x):
        return pd.Series([y.strip() for y in x.split(tax_delim)])

    return table[tax_col].apply(f_)


def level_taxonomy(table, taxa, samples, level, consider_nan=True):
    """
    Gets the taxonomy collapsed to the desired level

    Parameters
    ----------
    table : DataFrame
        A table with observation on the rows (biom-style table) with `samples`
        in its columns.
    taxa: DataFrame
        The taxonomic strings parsed into n levels
    level: list
        The level to which the taxonomy should be summarized 
    samples : list
        The columns from `table` to be included in the analysis
    consider_nan: bool, optional
        Whether the table contains multiple concatenated taxonomic levels, in
        which case considering `nan` will filter the table to retain only the
        levels of interest. This is recommended for kraken/bracken tables, but
        not applicable for some 16S sequences.
    """
    level = level.max()
    if consider_nan:
        leveler = (taxa[level].notna() & taxa[(level + 1)].isna())
    else:
        leveler = (taxa[level].notna())

    cols = list(np.arange(level + 1))

    # Combines the filtered tables
    level_ = pd.concat(
        axis=1, 
        objs=[taxa.loc[leveler, cols],
              (table.loc[leveler, samples] / table.loc[leveler, samples].sum(axis=0))],
    #     # sort=False
    )
    level_.reset_index()
    if taxa.loc[leveler, cols].duplicated().any():
        return level_.groupby(cols).sum()
    else:
        return level_.set_index(cols)


def profile_one_level(collapsed, level, threshold=0.01, count=8):
    """
    Gets upper and lower tables for a single taxonomic level

    Parameters
    ----------
    collapsed: DataFrame
        The counts data with the index as a multi-level index of 
        levels of interest and the columns as samples
    threshold: float, optional
        The minimum relative abundance for an organism to be shown
    count : int, optional
        The maximum number of levels to show for a single group

    Returns
    -------
    DataFrame
        A table of the top taxa for the data of interest
    """
    collapsed['mean'] = collapsed.mean(axis=1)
    collapsed.sort_values(['mean'], ascending=False, inplace=True)
    collapsed['count'] = 1
    collapsed['count'] = collapsed['count'].cumsum()

    thresh_ = (collapsed['mean'] > threshold) & (collapsed['count'] <= count)
    top_taxa = collapsed.loc[thresh_].copy()
    top_taxa.drop(columns=['mean', 'count'], inplace=True)
    for l_ in np.arange(level):
        top_taxa.index = top_taxa.index.droplevel(l_)

    first_ = top_taxa.index[0]

    top_taxa.sort_values(
        [first_],
        ascending=False,
        axis='columns',
        inplace=True,
    )

    upper_ = top_taxa.cumsum()
    lower_ = top_taxa.cumsum() - top_taxa

    return upper_, lower_


def profile_joint_levels(collapsed, lo_, hi_, samples, lo_thresh=0.01, 
    hi_thresh=0.01, lo_count=4, hi_count=5):
    """
    Generates a table of taxonomy using two levels to define grouping

    Parameters
    ----------
    collapsed: DataFrame
        The counts data with the index as a multi-level index of 
        levels of interest and the columns as samples
    lo_, hi_: int
        The numeric identifier for lower (`lo_`) and higher (`hi_`)
        resolution where "low" is defined as having fewer groups.
        (i.e. for taxonomy Phylum is low, family is high)
    lo_thresh, hi_thresh: int, optional
        The minimum relative abundance for an organism to be shown
        at a given level
    lo_count, hi_count : int, optional
        The maximum number of levels to show for a single group. 
        This is to appease the limitations of our eyes and colormaps.

    Returns
    -------
    DataFrame
        A table of the top taxa for the data of interest
    """
    collapsed['mean_hi'] = collapsed.mean(axis=1)
    collapsed.reset_index(inplace=True)
    mean_lo_rep = collapsed.groupby(lo_)['mean_hi'].sum().to_dict()
    collapsed['mean_lo'] = collapsed[lo_].replace(mean_lo_rep)
    collapsed.sort_values(['mean_lo', 'mean_hi'], ascending=False, inplace=True)

    collapsed['count_lo'] = ~collapsed[lo_].duplicated(keep='first') * 1.
    collapsed['count_lo'] = collapsed['count_lo'].cumsum()
    collapsed['count_hi'] = 1
    collapsed['count_hi'] = collapsed.groupby(lo_)['count_hi'].cumsum()

    collapsed['thresh_lo'] = ((collapsed['mean_lo'] > lo_thresh) & 
                              (collapsed['count_lo'] <= lo_count))
    collapsed['thresh_hi'] = ((collapsed['mean_hi'] > hi_thresh) & 
                              (collapsed['count_hi'] <= hi_count))

    top_lo = collapsed.loc[collapsed['thresh_lo']].copy()
    top_lo['other'] = ~(top_lo['thresh_lo'] & ~top_lo['thresh_hi']) * 1
    top_lo['new_name'] = top_lo[lo_].apply(lambda x: 'other %s' % x )
    top_lo.loc[top_lo['thresh_hi'], 'new_name'] = \
        top_lo.loc[top_lo['thresh_hi'], hi_]

    drop_levels = np.arange(hi_)[np.arange(hi_) != lo_]
    top_lo.drop(columns=drop_levels, inplace=True)
    new_taxa = top_lo.groupby([lo_, 'new_name']).sum(dropna=True)
    new_taxa.reset_index(inplace=True)
    new_taxa['mean_lo'] = new_taxa[lo_].replace(mean_lo_rep)
    new_taxa.set_index([lo_, 'new_name'], inplace=True)
    new_taxa.sort_values(['mean_lo', 'count_hi', 'mean_hi'], 
                         ascending=[False, True, False], 
                         inplace=True)

    upper_ = new_taxa.cumsum()[samples]
    upper_.sort_values([upper_.index[0], upper_.index[1]],
                       axis='columns', 
                       inplace=True, ascending=False)
    lower_ = upper_ - new_taxa[upper_.columns]

    upper_.index.set_names(['rough', 'fine'], inplace=True)
    lower_.index.set_names(['rough', 'fine'], inplace=True)

    return upper_, lower_


def define_single_cmap(cmap, top_taxa):
    """
    Gets the colormap a single level table

    """
    # Gets the colormap object
    map_ = mpl.colormaps[cmap]
    # Gets the taxonomic object
    return {tax: map_(i) for i, tax in enumerate(top_taxa.index)}


def define_join_cmap(table):
    """
    Defines a joint colormap for a taxonomic table.
    """
    table['dummy'] = 1
    grouping = table['dummy'].reset_index()

    rough_order = grouping['rough'].unique()

    rough_map = {group: mpl.colormaps[cmap] 
                 for (group,cmap) in zip(*(rough_order, colors_order))} 
    pooled_map = dict([])
    for rough_, fine_ in grouping.groupby('rough')['fine']:
        cmap_ = rough_map[rough_]
        colors = {c: cmap_(200-(i + 1) * 20) for i, c in enumerate(fine_)}
        pooled_map.update(colors)

    table.drop(columns=['dummy'], inplace=True)

    return pooled_map


def plot_area(upper_, lower_, colors, sample_interval=5):
    """
    An in-elegant function to make an area plot

    Yes, you'll get far more control if you do it yourself outside this
    function but it will at least give you a first pass of a stacked area
    plot

    Parameters
    ---------
    upper_, lower_ : DataFrame
        The upper (`upper_`) and lower (`lower_`) limits for the
        area plot. These should already be sorted in the
        desired order.
    colors: dict
        A dictionary mapping the taxonomic label to the
        appropriate matplotlib readable colors. For convenience,
        `define_single_colormap` and `define_joint_colormap`
        are good functions to use to generate this
    sample_interval : int, optional
        The interval for ticks for counting samples.

    Returns
    -------
    Figure
        A 8" x 4" matplotlib figure with the area plot and legend.
    """

    # Gets the figure
    fig_, ax1 = plt.subplots(1,1)
    fig_.set_size_inches((8, 4))
    ax1.set_position((0.15, 0.125, 0.4, 0.75))

    # Plots the area plot
    x = np.arange(0, len(upper_.columns))
    for taxa, hi_ in upper_.iloc[::-1].iterrows():
        lo_ = lower_.loc[taxa]
        cl_ = colors[taxa]


        ax1.fill_between(x=x, y1=1-lo_.values, y2=1-hi_.values, 
                         color=cl_, label=taxa)
    # Adds the legend
    leg_ = ax1.legend()
    leg_.set_bbox_to_anchor((2.05, 1))

    # Sets up the y-axis so the order matches the colormap
    # (accomplished by flipping the axis?)
    ax1.set_ylim((1, 0))
    ax1.set_yticks(np.arange(0, 1.1, 0.25))
    ax1.set_yticklabels(np.arange(1, -0.1, -0.25), size=11)
    ax1.set_ylabel('Relative Abundance', size=13)

    # Sets up x-axis without numeric labels
    ax1.set_xticklabels([])
    ax1.set_xticks(np.arange(0, x.max(), sample_interval))
    ax1.set_xlim((0, x.max() - 0.99))  # Subtract less than 1 to avoid singularity if xmin=xmax=0
    ax1.set_xlabel('Samples', size=13)

    return fig_


def single_area_plot(table, level=3, samples=None, 
    tax_col='taxon_name', cmap='Set3',
    tax_delim='|', multilevel_table=True, abund_thresh=0.1, 
    group_thresh=8):
    """
    Generates an area plot for the table at the specified level of resolution

    Parameters
    ----------
    table : DataFrame
        A pandas dataframe of the original table of data (either containing 
        counts or relative abundance)
    level : int
        The hierarchical level within the table to display as an integer
    cmap : str
        The qualitative colormap to use to generate your plot. Refer to 
        colorbrewer for options. If a selected colormap exceeds the number
        of groups (`--group-thresh`) possible, it will default to Set3.
    samples : list, optional
        The columns from `table` to be included in the analysis. If `samples`
        is None, then all columns in `table` except `tax_col` will be used.
    tax_col : str, optional
        The column in `table` which contains the taxonomic information.
    tax_delim: str, optional
        The delimiter between taxonomic levels, for example "|" or ";".
    multilevel_table: bool, optional
        Whether the table contains multiple concatenated taxonomic levels, in
        which case considering `nan` will filter the table to retain only the
        levels of interest. This is recommended for kraken/bracken tables, but
        not applicable for some 16S sequences.
    abund_thresh: float [0, 1]
        The mean abundance threshold for a sample to be plotted. This is 
        in conjunction with the group threshold (`--group-thresh`) will be 
        used to determine the groups that are shown.
    group_thresh: int, [1, 12]
        The maximum number of groups (colors) to show in the area plot. This 
        is handled in conjunction with the `--abund-thresh` in that 
        to be displayed, a group must have both a mean relative abundance 
        exceeding the `abund-thresh` and must be in the top `group-thresh` 
        groups.

    Returns
    -------
    Figure
        A 8" x 4" matplotlib figure with the area plot and legend.

    Also See
    --------
    make_joint_area_plot

    """

    if group_thresh > 12:
        raise ValueError("You may display at most 12 colors on this plot. "
                         "Please re-consider your plotting choices.")
    elif (group_thresh > 9) & ~(cmap in over9):
        warnings.warn('There are too many colors for your colormap. '
                      'Changing to Set3.')
        cmap = 'Set3'
    elif (group_thresh > 8) & ~(cmap in over8):
        warnings.warn('There are too many colors for your colormap. '
                      'Changing to Set3.')
        cmap = 'Set3'

    # Parses the taxonomy and collapses the table
    taxa = extract_label_array(table, tax_col, tax_delim)

    if samples is None:
        samples = list(table.columns.values)
        samples.remove(tax_col)

    # Gets the appropriate taxonomic level information to go forward
    collapsed = level_taxonomy(table, taxa, samples, np.array([level]), 
                              consider_nan=multilevel_table)

    # Gets the top taxonomic levels
    upper_, lower_, = profile_one_level(collapsed, np.array([level]), 
                                        threshold=abund_thresh, 
                                        count=group_thresh)

    # Gets the colormap 
    cmap = define_single_cmap(cmap, upper_)

    # Plots the data
    fig_ = plot_area(upper_, lower_, cmap)

    return fig_


def joint_area_plot(table, rough_level=2, fine_level=5, samples=None, 
    tax_col='taxon_name', tax_delim='|', 
    multilevel_table=True, abund_thresh_rough=0.1, 
    abund_thresh_fine=0.05, group_thresh_fine=5, 
    group_thresh_rough=5):
    """
    Generates an area plot with nested grouping where the the higher level
    (`rough_level`) in the table (lower resolution/fewer groups) is used to 
    provide the general grouping structure and then within each `rough_level`,
    a number of `fine_level` groups are displayed. 

    Parameters
    ----------
    table : DataFrame
        A DataFrame of the original data, either as counts or relative
        abundance, with the taxonomic information in `tax_col`. The data
        can contain separate count values at multiple levels (i.e. combined
        collapsed phylum, class, etc. levels).
    rough_level, fine_level: int
        The taxonomic levels to be displayed. The `fine_level` groups are
        nested within their `rough_level` groups for display. The
        `rough_level` should be smaller than the `fine_level`.
    samples : list, optional
        The columns from `table` to be included in the analysis. If `samples`
        is None, then all columns in `table` except `tax_col` will be used.
    tax_col : str, optional
        The column in `table` which contains the taxonomic information.
    tax_delim: str, optional
        The delimiter between taxonomic levels, for example "|" or ";".
    multilevel_table: bool, optional
        Whether the table contains multiple concatenated taxonomic levels, in
        which case considering `nan` will filter the table to retain only the
        levels of interest. This is recommended for kraken/bracken tables, but
        not applicable for some 16S sequences.
    abund_thresh_rough, abund_thresh_fine : float [0, 1]
        The mean abundance threshold for a taxonomic group to be plotted for
        the higher level grouping (`abund_thresh_rough`) and sub grouping
        level. This will be used in conjunction with the `group_thresh_rough`
        and `group_thresh_fine` to determine the number of groups to be
        included.
    group_thresh_fine, group_thresh_rough: int, [1, 6]
        The maximum number of taxonomic groups to display for the respective 
        level. If `group_thresh_rough` > 6, then it will be replaced with 
        6 because this is the maximum number of available color groups.

    Returns
    -------
    Figure
        A 8" x 4" matplotlib figure with the area plot and legend.

    Also See
    --------
    single_area_plot

    """

    # Parses the taxonomy and collapses the table
    taxa = extract_label_array(table, tax_col, tax_delim)
    if samples is None:
        samples = list(table.drop(columns=[tax_col]).columns)

    # Gets the appropriate taxonomic level information to go forward
    collapsed = level_taxonomy(table, taxa, samples, 
                               level=np.array([fine_level]), 
                               consider_nan=multilevel_table)
    samples = collapsed.columns

    # Gets the top taxonomic levels
    upper_, lower_, = profile_joint_levels(collapsed, rough_level, fine_level, 
                                           samples=samples,
                                           lo_thresh=abund_thresh_rough, 
                                           lo_count=min(5, group_thresh_rough),
                                           hi_thresh=abund_thresh_fine,
                                           hi_count=group_thresh_fine,
                                           )
    # Gets the colormap 
    cmap = define_join_cmap(upper_)
    upper_.index = upper_.index.droplevel('rough')
    lower_.index = lower_.index.droplevel('rough')

    # Plots the data
    fig_ = plot_area(upper_.astype(float), lower_.astype(float), cmap)

    return fig_


# Sets up the main arguments for argparse.
def create_argparse():
    parser_one = argparse.ArgumentParser(
        description=('A set of functions to generate diagnostic stacked area '
                     'plots from metagenomic outputs.'),
        prog=('area_plotter'),
        )
    parser_one.add_argument(
        '-t', '--table', 
        help=('The abundance table as a TSV classic biom table (features as '
              'rows, samples as columns) containing absolute or relative '
              'abundance for the samples.'),
        required=True,
        )
    parser_one.add_argument(
        '-o', '--output',
        help=('The location for the final figure'),
        required=True,
        )
    parser_one.add_argument(
        '-s', '--samples', 
        help=('A text file with the list of samples to be included (one '
            'per line). If no list is provided, then data from all columns '
            'in the table (except the one specifying taxonomy) will be used.'),
        )
    parser_one.add_argument(
        '--mode', 
        choices=mode_dict.keys(),
        help=('The software that generated the table, to make parsing easier. '
              'Options are kaiju, kraken2, metaphlan, metaphlan4, and marker '
              '(i.e. CTMR amplicon).'),
        )
    parser_one.add_argument(
        '-l', '--level',
        help=('The taxonomic level (as an integer) to plot the data.'),
        default=3,
        type=int,
        )
    parser_one.add_argument(
        '--abund-thresh',
        help=("the minimum abundance required to display a group."),
        default=0.01,
        type=float,
        )
    parser_one.add_argument(
        '--group-thresh',
        help=("The maximum number of groups to be displayed in the graph."),
        default=8,
        type=int,
        )
    parser_one.add_argument(
        '-c', '--colormap',
        help=("The qualitative colormap to use to generate your plot. Refer"
             ' to colorbrewer for options. If a selected colormap exceeds '
             'the number of groups (`--group-thresh`) possible, it will '
             'default to Set3.'),
        default='Set3',
        )
    parser_one.add_argument(
        '--sub-level',
        help=('The second level to use if doing a joint plot'),
        type=int,
        )
    parser_one.add_argument(
        '--sub-abund-thresh',
        help=("the minimum abundance required to display a sub group"),
        default=0.05,
        type=float,
        )
    parser_one.add_argument(
        '--sub-group-thresh',
        help=("the maximum number of sub groups allowed in a joint level plot."),
        default=5,
        type=float,
        )
    parser_one.add_argument(
        '--tax-delim',
        help=("String delimiting taxonomic levels."),
        type=str,
        )
    parser_one.add_argument(
        '--multi-level',
        help=("Whether the table contains multiple concatenated, in which "
             "case considering `nan` will filter the samples to retain only"
             "the levels of interest. This is recommended for most "
             "metagenomic tables, but not applicable for some 16s sequences."),
        )
    parser_one.add_argument(
        "--tax-col",
        help=("The column in `table` containig the taxobnomy information"),
        )
    parser_one.add_argument(
        '--table-drop',
        help=('A comma-separated list describing the columns to drop.'),
        )
    parser_one.add_argument(
        '--skip-rows',
        type=int,  # parse as an integer so the value can be passed to pandas' skiprows
        help=('The number of rows to skip when reading in the feature table.'),
        )

    return parser_one


if __name__ == '__main__':
    parser_one = create_argparse()

    if len(argv) < 2:
        parser_one.print_help()
        exit()

    args = parser_one.parse_args()

    if args.table_drop is not None:
        args.table_drop = [s for s in args.table_drop.split(',')]
    else:
        args.table_drop = []

    mode_defaults = mode_dict.get(args.mode, mode_dict['kraken2'])
    mode_defaults.update({k: v for k, v in args.__dict__.items() 
                         if (k in mode_defaults) and (v)})

    table = pd.read_csv(args.table, sep='\t', 
                        skiprows=mode_defaults['skip_rows'])
    if args.samples is not None:
        with open(args.samples, 'r') as f_:
            samples = f_.read().split('\n')
    else:
        samples = None

    if args.sub_level is not None:
        fig_ = joint_area_plot(
            table.drop(columns=mode_defaults['table_drop']),
            rough_level=args.level,
            fine_level=args.sub_level,
            samples=args.samples,
            tax_delim=mode_defaults['tax_delim'],
            tax_col=mode_defaults['tax_col'],
            multilevel_table=mode_defaults['multi_level'],
            abund_thresh_rough=args.abund_thresh,
            group_thresh_rough=args.group_thresh,
            abund_thresh_fine=args.sub_abund_thresh,
            group_thresh_fine=args.sub_group_thresh,
        )
    else:
        fig_ = single_area_plot(
            table.drop(columns=mode_defaults['table_drop']),
            level=args.level,
            cmap=args.colormap,
            samples=samples,
            tax_delim=mode_defaults['tax_delim'],
            tax_col=mode_defaults['tax_col'],
            multilevel_table=mode_defaults['multi_level'],
            abund_thresh=args.abund_thresh,
            group_thresh=args.group_thresh,
        )

    fig_.savefig(args.output, dpi=300)
__author__ = "Fredrik Boulund"
__date__ = "2020-2022"
__version__ = "1.1"

from sys import argv, exit
from functools import reduce, partial
from pathlib import Path
import argparse

import pandas as pd


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("TABLE", nargs="+",
            help="TSV table with columns headers.")
    parser.add_argument("-f", "--feature-column", dest="feature_column",
            default="name",
            help="Column header of feature column to use, "
                 "typically containing taxa names. "
                 "Select several columns by separating with comma (e.g. name,taxid) "
                 "[%(default)s].")
    parser.add_argument("-c", "--value-column", dest="value_column",
            default="fraction_total_reads",
            help="Column header of value column to use, "
                 "typically containing counts or abundances [%(default)s].")
    parser.add_argument("-o", "--outfile", dest="outfile",
            default="joined_table.tsv",
            help="Outfile name [%(default)s].")
    parser.add_argument("-n", "--fillna", dest="fillna", metavar="FLOAT",
            default=0.0,
            type=float,
            help="Fill NA values in merged table with FLOAT [%(default)s].")
    parser.add_argument("-s", "--skiplines", dest="skiplines", metavar="N",
            default=0,
            type=int,
            help="Skip N lines before parsing header (e.g. for files "
                 "containing comments before the real header) [%(default)s].")

    if len(argv) < 2:
        parser.print_help()
        exit()

    return parser.parse_args()


def main(table_files, feature_column, value_column, outfile, fillna, skiplines):
    feature_columns = feature_column.split(",")

    tables = []
    for table_file in table_files:
        sample_name = Path(table_file).name.split(".")[0]
        tables\
            .append(pd.read_csv(table_file, sep="\t", skiprows=skiplines)\
            .set_index(feature_columns)\
            .rename(columns={value_column: sample_name})\
            .loc[:, [sample_name]])  # Ugly hack to get a single-column DataFrame

    df = tables[0]
    for table in tables[1:]:
        df = df.join(table, how="outer")
    df.fillna(fillna, inplace=True)

    df.to_csv(outfile, sep="\t")


if __name__ == "__main__":
    args = parse_args()
    if len(args.TABLE) < 2:
        print("Need at least two tables to merge!")
        exit(1)
    main(args.TABLE, args.feature_column, args.value_column, args.outfile, args.fillna, args.skiplines)
import os, sys, argparse
import operator
from time import gmtime 
from time import strftime 

#Tree Class 
#usage: tree node used in constructing a taxonomy tree
#   including only the taxonomy levels and genomes identified in the Kraken report
class Tree(object):
    'Tree node.'
    def __init__(self, name, taxid, level_num, level_id, all_reads, lvl_reads, children=None, parent=None):
        self.name = name
        self.taxid = taxid
        self.level_num = level_num
        self.level_id = level_id
        self.tot_all = all_reads
        self.tot_lvl = lvl_reads
        self.all_reads = {}
        self.lvl_reads = {}
        self.children = []
        self.parent = parent
        if children is not None:
            for child in children:
                self.add_child(child)
    def add_child(self,node):
        assert isinstance(node,Tree)
        self.children.append(node)
    def add_reads(self, sample, all_reads, lvl_reads):
        self.all_reads[sample] = all_reads
        self.lvl_reads[sample] = lvl_reads
        self.tot_all += all_reads
        self.tot_lvl += lvl_reads
    def __lt__(self,other):
        return self.tot_all < other.tot_all

####################################################################
#process_kraken_report
#usage: parses a single line in the kraken report and extracts relevant information
#input: kraken report file with the following tab delimited lines 
#   - percent of total reads   
#   - number of reads (including at lower levels)
#   - number of reads (only at this level)
#   - taxonomy classification of level 
#       (U, - (root), - (cellular org), D, P, C, O, F, G, S) 
#   - taxonomy ID (0 = unclassified, 1 = root, 2 = Bacteria...etc)
#   - spaces + name 
#returns:
#   - classification/genome name
#   - taxonomy ID for this classification
#   - level for this classification (number)
#   - level name (U, -, D, P, C, O, F, G, S)
#   - all reads classified at this level and below in the tree
#   - reads classified only at this level
def process_kraken_report(curr_str):
    split_str = curr_str.strip().split('\t')
    if len(split_str) < 5:
        return []
    try:
        int(split_str[1])
    except ValueError:
        return []
    #Extract relevant information
    all_reads =  int(split_str[1])
    level_reads = int(split_str[2])
    level_type = split_str[-3]
    taxid = split_str[-2] 
    #Get name and spaces
    spaces = 0
    name = split_str[-1]
    for char in name:
        if char == ' ':
            name = name[1:]
            spaces += 1 
        else:
            break 
    #Determine which level based on number of spaces
    level_num = int(spaces/2)
    return [name, taxid, level_num, level_type, all_reads, level_reads]

####################################################################
#Main method
def main():
    #Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-r','--report-file','--report-files',
        '--report','--reports', required=True,dest='r_files',nargs='+',
        help='Input kraken report files to combine (separate by spaces)') 
    parser.add_argument('-o','--output', required=True,dest='output',
        help='Output kraken report file with combined information')
    parser.add_argument('--display-headers',required=False,dest='headers',
        action='store_true', default=True,
        help='Include header lines')
    parser.add_argument('--no-headers',required=False,dest='headers',
        action='store_false',default=True,
        help='Do not include header lines')
    parser.add_argument('--sample-names',required=False,nargs='+',
        dest='s_names',default=[],help='Sample names to use as headers in the new report')
    parser.add_argument('--only-combined', required=False, dest='c_only',
        action='store_true', default=False, 
        help='Include only the total combined reads column, not the individual sample cols')
    args=parser.parse_args()


    #Initialize combined values 
    main_lvls = ['U','R','D','K','P','C','O','F','G','S']
    map_lvls = {'kingdom':'K', 'superkingdom':'D','phylum':'P','class':'C','order':'O','family':'F','genus':'G','species':'S'}
    count_samples = 0
    num_samples = len(args.r_files)
    sample_names = args.s_names
    root_node = -1 
    prev_node = -1
    curr_node = -1
    u_reads = {0:0} 
    total_reads = {0:0} 
    taxid2node = {}

    #Check input values 
    if len(sample_names) > 0 and len(sample_names) != num_samples: 
        sys.stderr.write("Number of sample names provided does not match number of reports\n")
        sys.exit(1)
    #Map names
    id2names = {} 
    id2files = {} 
    if len(sample_names) == 0:
        for i in range(num_samples):
            id2names[i+1] = "S" + str(i+1)
            id2files[i+1] = ""
    else:
        for i in range(num_samples):
            id2names[i+1] = sample_names[i] 
            id2files[i+1] = ""

    #################################################
    #STEP 1: READ IN REPORTS
    #Iterate through reports and make combined tree! 
    sys.stdout.write(">>STEP 1: READING REPORTS\n")
    sys.stdout.write("\t%i/%i samples processed" % (count_samples, num_samples))
    sys.stdout.flush()
    for r_file in args.r_files:
        count_samples += 1 
        sys.stdout.write("\r\t%i/%i samples processed" % (count_samples, num_samples))
        sys.stdout.flush()
        id2files[count_samples] = r_file
        #Open File 
        curr_file = open(r_file,'r')
        for line in curr_file: 
            report_vals = process_kraken_report(line)
            if len(report_vals) < 5:
                continue
            [name, taxid, level_num, level_id, all_reads, level_reads] = report_vals
            if level_id in map_lvls:
                level_id = map_lvls[level_id]
            #Total reads 
            total_reads[0] += level_reads
            total_reads[count_samples] = level_reads 
            #Unclassified 
            if level_id == 'U' or taxid == '0':
                u_reads[0] += level_reads
                u_reads[count_samples] = level_reads 
                continue
            #Tree Root 
            if taxid == '1': 
                if count_samples == 1:
                    root_node = Tree(name, taxid, level_num, 'R', 0,0)
                    taxid2node[taxid] = root_node 
                root_node.add_reads(count_samples, all_reads, level_reads) 
                prev_node = root_node
                continue 
            #Move to correct parent
            while level_num != (prev_node.level_num + 1):
                prev_node = prev_node.parent
            #IF NODE EXISTS 
            if taxid in taxid2node: 
                taxid2node[taxid].add_reads(count_samples, all_reads, level_reads) 
                prev_node = taxid2node[taxid]
                continue 
            #OTHERWISE
            #Determine correct level ID
            if level_id == '-' or len(level_id)> 1:
                if prev_node.level_id in main_lvls:
                    level_id = prev_node.level_id + '1'
                else:
                    num = int(prev_node.level_id[-1]) + 1
                    level_id = prev_node.level_id[:-1] + str(num)
            #Add node to tree
            curr_node = Tree(name, taxid, level_num, level_id, 0, 0, None, prev_node)
            curr_node.add_reads(count_samples, all_reads, level_reads)
            taxid2node[taxid] = curr_node
            prev_node.add_child(curr_node)
            prev_node = curr_node 
        curr_file.close()

    sys.stdout.write("\r\t%i/%i samples processed\n" % (count_samples, num_samples))
    sys.stdout.flush()

    #################################################
    #STEP 2: SETUP OUTPUT FILE
    sys.stdout.write(">>STEP 2: WRITING NEW REPORT HEADERS\n")
    o_file = open(args.output,'w') 
    #Lines mapping sample ids to filenames
    if args.headers: 
        o_file.write("#Number of Samples: %i\n" % num_samples) 
        o_file.write("#Total Number of Reads: %i\n" % total_reads[0])
        for i in id2names:
            o_file.write("#")
            o_file.write("%s\t" % id2names[i])
            o_file.write("%s\n" % id2files[i])
        #Report columns
        o_file.write("#perc\ttot_all\ttot_lvl")
        if not args.c_only:
            for i in id2names:
                o_file.write("\t%s_all" % i)
                o_file.write("\t%s_lvl" % i)
        o_file.write("\tlvl_type\ttaxid\tname\n")
    #################################################
    #STEP 3: PRINT TREE
    sys.stdout.write(">>STEP 3: PRINTING REPORT\n")
    #Print line for unclassified reads
    o_file.write("%0.4f\t" % (float(u_reads[0])/float(total_reads[0])*100))
    for i in u_reads:
        if i == 0 or (i > 0 and not args.c_only):
            o_file.write("%i\t" % u_reads[i])
            o_file.write("%i\t" % u_reads[i])
    o_file.write("U\t0\tunclassified\n")
    #Print for all remaining reads 
    all_nodes = [root_node]
    curr_node = -1
    curr_lvl = 0
    prev_node = -1
    while len(all_nodes) > 0:
        #Remove node and insert children
        curr_node = all_nodes.pop()
        if len(curr_node.children) > 0:
            curr_node.children.sort()
            for node in curr_node.children:
                all_nodes.append(node)
        #Print information for this node 
        o_file.write("%0.4f\t" % (float(curr_node.tot_all)/float(total_reads[0])*100))
        o_file.write("%i\t" % curr_node.tot_all)
        o_file.write("%i\t" % curr_node.tot_lvl)
        if not args.c_only:
            for i in range(num_samples):
                if (i+1) not in curr_node.all_reads: 
                    o_file.write("0\t0\t")
                else:
                    o_file.write("%i\t" % curr_node.all_reads[i+1])
                    o_file.write("%i\t" % curr_node.lvl_reads[i+1])
        o_file.write("%s\t" % curr_node.level_id)
        o_file.write("%s\t" % curr_node.taxid)
        o_file.write(" "*curr_node.level_num*2)
        o_file.write("%s\n" % curr_node.name)
    o_file.close() 
####################################################################
if __name__ == "__main__":
    main()
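
For reference, a hypothetical invocation of this report-combining script could look like the line below (script and file names are invented for illustration):

python combine_kreports.py -r sampleA.kreport sampleB.kreport -o combined.kreport --sample-names sampleA sampleB

This writes a single combined report with one _all/_lvl column pair per sample plus the combined totals (or only the totals if --only-combined is given).
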
import os, sys, argparse

####################################################################
#process_kraken_report
#usage: parses a single line in the kraken report and extracts relevant information
#input: kraken report file with the following tab delimited lines
#   - percent of total reads
#   - number of reads (including at lower levels)
#   - number of reads (only at this level)
#   - taxonomy classification of level
#       (U, D, P, C, O, F, G, S, -)
#   - taxonomy ID (0 = unclassified, 1 = root, 2 = Bacteria,...etc)
#   - spaces + name
#returns:
#   - classification/genome name
#   - level name (U, -, D, P, C, O, F, G, S)
#   - reads classified at this level and below in the tree
def process_kraken_report(curr_str):
    split_str = curr_str.strip().split('\t')
    if len(split_str) < 2:
        return []
    try:
        int(split_str[1])
    except ValueError:
        return []
    all_reads = int(split_str[1])
    lvl_reads = int(split_str[2])
    level_type = split_str[-3]
    type2main = {'superkingdom':'D','phylum':'P',
        'class':'C','order':'O','family':'F',
        'genus':'G','species':'S'} 
    if len(level_type) > 1:
        if level_type in type2main:
            level_type = type2main[level_type]
        else:
            level_type = '-'
    #Get name and spaces 
    spaces = 0
    name = split_str[-1]
    for char in name:
        if char == ' ':
            name = name[1:]
            spaces += 1
        else:
            break
    name = name.replace(' ','_')
    #Determine level based on number of spaces
    level_num = int(spaces/2)
    return [name, level_num, level_type, lvl_reads]

###################################################################
#kreport2krona_all
#usage: prints all levels for a kraken report 
#input: kraken report file and output krona file names 
#returns: none 
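#example output line (hypothetical counts): the read count comes first, followed by
#the tab-separated lineage, e.g.
#   "1250\tk__Bacteria\tp__Proteobacteria\tc__Gammaproteobacteria\tg__Escherichia"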
def kreport2krona_all(report_file, out_file):
    #Process report file and output 
    curr_path = [] 
    prev_lvl_num = -1
    r_file = open(report_file, 'r')
    o_file = open(out_file, 'w')
    #Read through report file 
    main_lvls = ['D','P','C','O','F','G','S']
    for line in r_file:
        report_vals = process_kraken_report(line)
        #If header line, skip
        if len(report_vals) < 4: 
            continue
        #Get relevant information from the line 
        [name, level_num, level_type, lvl_reads] = report_vals
        if level_type == 'U':
            o_file.write(str(lvl_reads) + "\tUnclassified\n")
            continue
        #Create level name 
        if level_type not in main_lvls:
            level_type = "x"
        elif level_type == "D":
            level_type = "K"
        level_str = level_type.lower() + "__" + name
        #Determine full string to add
        if prev_lvl_num == -1:
            #First level
            prev_lvl_num = level_num
            curr_path.append(level_str)
            o_file.write(str(lvl_reads) + "\t" + level_str + "\n")
        else:
            o_file.write(str(lvl_reads))
            #Move back if needed
            while level_num != (prev_lvl_num + 1):
                prev_lvl_num -= 1
                curr_path.pop()
            #Print all ancestors of the current level, tab-separated
            for string in curr_path:
                if string[0] != "r": 
                    o_file.write("\t" + string)
            #Print final level and then number of reads
            o_file.write("\t" + level_str + "\n")
            #Update
            curr_path.append(level_str)
            prev_lvl_num = level_num
    o_file.close()
    r_file.close()

###################################################################
#kreport2krona_main
#usage: prints only main taxonomy levels for a kraken report 
#input: kraken report file and output krona file names 
#returns: none 
def kreport2krona_main(report_file, out_file):
    #Process report file and output 
    main_lvls = ['D','P','C','O','F','G','S']
    curr_path = [] 
    prev_lvl_num = -1
    num2path = {} 
    path2reads = {} 
    line_num = -1
    #Read through report file 
    r_file = open(report_file, 'r')
    for line in r_file:
        line_num += 1
        #########################################
        report_vals = process_kraken_report(line)
        #If header line, skip
        if len(report_vals) < 4: 
            continue
        #Get relevant information from the line 
        [name, level_num, level_type, lvl_reads] = report_vals
        if level_type == 'U':
            num2path[line_num] = ["Unclassified"]
            path2reads["Unclassified"] = lvl_reads 
            continue
        #########################################
        #Create level name 
        if level_type not in main_lvls:
            level_type = "x"
        elif level_type == "D":
            level_type = "K"
        level_str = level_type.lower() + "__" + name
        #########################################
        #Determine full string to add
        if prev_lvl_num == -1:
            #First level
            prev_lvl_num = level_num
            curr_path.append(level_str)
            #Save
            if curr_path[-1][0] == "x":
                num2path[line_num] = ""
            else:
                path2reads[curr_path[-1]] = lvl_reads
                num2path[line_num] = []
                for i in curr_path:
                    num2path[line_num].append(i)
            continue
        else:
            #########################################
            #Move back if needed
            while level_num != (prev_lvl_num + 1):
                prev_lvl_num -= 1
                curr_path.pop()
            #Update the list 
            curr_path.append(level_str)
            prev_lvl_num = level_num
            #########################################
            #IF AT NON-TRADITIONAL LEVEL, ADD TO PARENT
            if level_type == "x":
                test_num = len(curr_path) - 1
                while(test_num >= 0):
                    if curr_path[test_num][0] != "x":
                        path2reads[curr_path[test_num]] += lvl_reads 
                        test_num = -1
                    test_num = test_num - 1 
                num2path[line_num] = ""
            #IF AT TRADITIONAL LEVEL, SAVE 
            if level_type != "x":
                path2reads[curr_path[-1]] = lvl_reads
                num2path[line_num] = []
                for i in curr_path:
                    num2path[line_num].append(i)
    r_file.close() 

    #WRITE OUTPUT FILE
    o_file = open(out_file, 'w')
    for i in range(0,line_num+1):
        #Get values
        if i not in num2path:
            continue
        curr_path = num2path[i] 
        if len(curr_path) > 0:
            curr_reads = path2reads[curr_path[-1]] 
            if curr_path[-1][0] != "x":
                o_file.write("%i" % curr_reads)
            for name in curr_path:
                if name[0] != "r" and name[0] != "x":
                    o_file.write("\t%s" % name)
            o_file.write("\n")
    o_file.close()

######################################################################
#Main method
def main():
    #Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--report-file', '--report', required=True,
        dest='r_file', help='Input kraken report file for converting')
    parser.add_argument('-o', '--output', required=True,
        dest='o_file', help='Output krona-report file name')
    parser.add_argument('--intermediate-ranks', action='store_true',
        dest='x_include', default=False, required=False,
        help='Include non-traditional taxonomic ranks in output')
    parser.add_argument('--no-intermediate-ranks', action='store_false',
        dest='x_include', default=False, required=False,
        help='Do not include non-traditional taxonomic ranks in output [default: no intermediate ranks]')
    args=parser.parse_args()

    #Determine which krona report to make 
    if args.x_include:
        kreport2krona_all(args.r_file,args.o_file)
    else:
        kreport2krona_main(args.r_file,args.o_file) 

#################################################################
if __name__ == "__main__":
    main()
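
A hypothetical invocation of this Krona converter (file names invented for illustration):

python kreport2krona.py -r sampleA.kreport -o sampleA.krona.txt

The resulting text file follows Krona's text-import format, so an interactive HTML plot can typically be produced with something like ktImportText sampleA.krona.txt -o sampleA.krona.html, assuming KronaTools is installed.
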
import os, sys, argparse

#process_kraken_report
#usage: parses a single line in the kraken report and extracts relevant information
#input: kraken report file with the following tab delimited lines
#   - percent of total reads
#   - number of reads (including at lower levels)
#   - number of reads (only at this level)
#   - taxonomy classification of level
#       (U, D, P, C, O, F, G, S, -)
#   - taxonomy ID (0 = unclassified, 1 = root, 2 = Bacteria,...etc)
#   - spaces + name
#returns:
#   - classification/genome name
#   - level name (U, -, D, P, C, O, F, G, S)
#   - reads classified at this level and below in the tree
def process_kraken_report(curr_str):
    split_str = curr_str.strip().split('\t')
    if len(split_str) < 4:
        return []
    try:
        int(split_str[1])
    except ValueError:
        return []
    percents = float(split_str[0])
    all_reads = int(split_str[1])
    #Extract relevant information
    try:
        taxid = int(split_str[-3]) 
        level_type = split_str[-2]
        map_kuniq = {'species':'S', 'genus':'G','family':'F',
            'order':'O','class':'C','phylum':'P','superkingdom':'D',
            'kingdom':'K'}
        if level_type not in map_kuniq:
            level_type = '-'
        else:
            level_type = map_kuniq[level_type]
    except ValueError:
        taxid = int(split_str[-2])
        level_type = split_str[-3]
    #Get name and spaces 
    spaces = 0
    name = split_str[-1]
    for char in name:
        if char == ' ':
            name = name[1:]
            spaces += 1
        else:
            break
    name = name.replace(' ','_')
    #Determine level based on number of spaces
    level_num = int(spaces/2)
    return [name, level_num, level_type, all_reads, percents]

#Main method
def main():
    #Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--report-file', '--report', required=True,
        dest='r_file', help='Input kraken report file for converting')
    parser.add_argument('-o', '--output', required=True,
        dest='o_file', help='Output mpa-report file name')
    parser.add_argument('--display-header', action='store_true', 
        dest='add_header', default=False, required=False,
        help='Include header [Kraken report filename] in mpa-report file [default: no header]') 
    parser.add_argument('--read_count', action='store_true',
        dest='use_reads', default=True, required=False,
        help='Use read count for output [default]')
    parser.add_argument('--percentages', action='store_false',
        dest='use_reads', default=True, required=False,
        help='Use percentages for output [instead of reads]')
    parser.add_argument('--intermediate-ranks', action='store_true',
        dest='x_include', default=False, required=False,
        help='Include non-traditional taxonomic ranks in output')
    parser.add_argument('--no-intermediate-ranks', action='store_false',
        dest='x_include', default=False, required=False,
        help='Do not include non-traditional taxonomic ranks in output [default]')
    args=parser.parse_args()

    #Process report file and output 
    curr_path = [] 
    prev_lvl_num = -1
    r_file = open(args.r_file, 'r')
    o_file = open(args.o_file, 'w')
    #Print header
    if args.add_header:
        o_file.write("taxon_name\treads\n")

    #Read through report file 
    main_lvls = ['R','K','D','P','C','O','F','G','S']
    for line in r_file:
        report_vals = process_kraken_report(line)
        #If header line, skip
        if len(report_vals) < 5: 
            continue
        #Get relevant information from the line 
        [name, level_num, level_type, all_reads, percents] = report_vals
        if level_type == 'U':
            continue
        #Create level name 
        if level_type not in main_lvls:
            level_type = "x"
        elif level_type == "K":
            level_type = "k"
        elif level_type == "D":
            level_type = "k"
        level_str = level_type.lower() + "__" + name
        #Determine full string to add
        if prev_lvl_num == -1:
            #First level
            prev_lvl_num = level_num
            curr_path.append(level_str)
        else:
            #Move back if needed
            while level_num != (prev_lvl_num + 1):
                prev_lvl_num -= 1
                curr_path.pop()
            #Print if at non-traditional level and that is requested
            if (level_type == "x" and args.x_include) or level_type != "x":
                #Print all ancestors of current level followed by |
                for string in curr_path:
                    if (string[0] == "x" and args.x_include) or string[0] != "x":
                        if string[0] != "r": 
                            o_file.write(string + "|")
                #Print final level and then number of reads
                if args.use_reads:
                    o_file.write(level_str + "\t" + str(all_reads) + "\n")
                else:
                    o_file.write(level_str + "\t" + str(percents) + "\n")
            #Update
            curr_path.append(level_str)
            prev_lvl_num = level_num
    o_file.close()
    r_file.close()

if __name__ == "__main__":
    main()
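
A hypothetical invocation of this mpa-style converter (file names invented for illustration):

python kreport2mpa.py -r sampleA.kreport -o sampleA.mpa.txt

Each output line then contains a pipe-separated lineage followed by a tab and the read count (or the percentage, if --percentages is given), for example "k__Bacteria|p__Proteobacteria|g__Escherichia" followed by 1250.
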
__author__ = "Fredrik Boulund"
__date__ = "2019-03-07"
__version__ = "2.0.0"

from sys import argv, exit, stderr
from collections import defaultdict
from pathlib import Path
import argparse
import logging
import csv

logging.basicConfig(format="%(levelname)s: %(message)s")


def parse_args():
    desc = "{} Version v{}. Copyright (c) {}.".format(__doc__, __version__, __author__, __date__[:4])
    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument("RPKM", nargs="+",
            help="RPKM file(s) from BBMap pileup.sh.")
    parser.add_argument("-c", "--columns", dest="columns",
            default="",
            help="Comma-separated list of column names to include [all columns].")
    parser.add_argument("-a", "--annotation-file", dest="annotation_file", 
            required=True,
            help="Two-column tab-separated annotation file.")
    parser.add_argument("-o", "--outdir", dest="outdir", metavar="DIR",
            default="",
            help="Directory for output files, will create one output file per selected column.")

    if len(argv) < 2:
        parser.print_help()
        exit(1)

    return parser.parse_args()


def parse_rpkm(rpkm_file):
    read_counts = {}
    with open(rpkm_file) as f:
        firstline = f.readline()
        if not firstline.startswith("#File"):
            logging.error("File does not look like a BBMap pileup.sh RPKM: %s", rpkm_file)
        _ = [f.readline() for l in range(4)] # Skip remaining header lines: #Reads, #Mapped, #RefSequences, Table header
        for line_no, line in enumerate(f, start=1):
            try:
                ref, length, bases, coverage, reads, RPKM, frags, FPKM = line.strip().split("\t")
            except ValueError:
                logging.error("Could not parse RPKM file line %s: %s", line_no, rpkm_file)
                continue
            if int(reads) != 0:
                ref = ref.split()[0]  # Truncate reference header on first space
                read_counts[ref] = int(reads)
    return read_counts


def parse_annotations(annotation_file):
    annotations = defaultdict(dict)
    with open(annotation_file) as f:
        csv_reader = csv.DictReader(f, delimiter="\t")
        for line in csv_reader:
            ref = list(line.values())[0].split()[0]  # Truncate reference header on first space
            for colname, value in list(line.items())[1:]:
                annotations[colname][ref] = value
    return annotations
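
#example annotation file layout (hypothetical names): a tab-separated table with a
#header row; the first column holds the reference names (truncated at the first
#space, matching how the RPKM references are handled) and every other column is
#one annotation that can be selected with --columns:
#   ref         taxonomy        gene_family
#   contig_1    Escherichia     gyrA
#   contig_2    Bacteroides     ABC_transporter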


def merge_counts(annotations, rpkms):
    output_table = {"Unknown": [0 for n in range(len(rpkms))]}
    for annotation in set(annotations.values()):
        output_table[annotation] = [0 for n in range(len(rpkms))]
    for idx, rpkm in enumerate(rpkms):
        for ref, count in rpkm.items():
            try:
                output_table[annotations[ref]][idx] += count
            except KeyError:
                logging.warning("Found no annotation for '%s', assigning to 'Unknown'", ref)
                output_table["Unknown"][idx] += count
    return output_table


def write_table(table_data, sample_names, outfile):
    with open(str(outfile), "w") as outf:
        header = "\t".join(["Annotation"] + [sample_name for sample_name in sample_names]) + "\n"
        outf.write(header)
        for ref, counts in table_data.items():
            outf.write("{}\t{}\n".format(ref, "\t".join(str(count) for count in counts)))


if __name__ == "__main__":
    args = parse_args()

    Path(args.outdir).mkdir(parents=True, exist_ok=True)

    rpkms = []
    for rpkm_file in args.RPKM:
        rpkms.append(parse_rpkm(rpkm_file))

    annotations = parse_annotations(args.annotation_file)

    if args.columns:
        selected_columns = []
        for col in args.columns.split(","):
            if col in annotations:
                selected_columns.append(col)
            else:
                logging.warning("Column %s not found in annotation file!", col)
    else:
        selected_columns = list(annotations.keys())

    for selected_column in selected_columns:
        table_data = merge_counts(annotations[selected_column], rpkms)
        sample_names = [Path(fn).stem.split(".")[0] for fn in args.RPKM]

        table_filename = Path(args.outdir) / "counts.{}.tsv".format(selected_column)
        write_table(table_data, sample_names, table_filename)
        logging.debug("Wrote %s", table_filename)
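
A hypothetical invocation of this count-table script (script and file names invented for illustration), which produces one counts.<column>.tsv file per selected annotation column in the output directory:

python make_count_table.py -a annotations.tsv -o counts sampleA.rpkm.txt sampleB.rpkm.txt
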
__author__ = "Fredrik Boulund"
__date__ = "2022"
__version__ = "0.4"

from sys import argv, exit
from collections import defaultdict
from pathlib import Path
import argparse
import logging

import numpy as np
import pandas as pd
import seaborn as sns

TAXLEVELS = [
    "Kingdom", 
    "Phylum", 
    "Class", 
    "Order", 
    "Family", 
    "Genus", 
    "Species", 
    "Strain",
]

def parse_args():
    desc = f"{__doc__} v{__version__}. {__author__} (c) {__date__}."
    parser = argparse.ArgumentParser(description=desc, epilog="Version "+__version__)
    parser.add_argument("mpa_table",
            help="MetaPhlAn TSV table to plot.")
    parser.add_argument("-o", "--outfile-prefix", dest="outfile_prefix",
            default="mpa_heatmap",
            help="Outfile name [%(default)s]. "
                 "Will be appended with <taxonomic_level>_top<N>.{png,pdf}")
    parser.add_argument("-f", "--force", action="store_true",
            default=False,
            help="Overwrite output file if it already exists [%(default)s].")
    parser.add_argument("-l", "--level", 
            default="Species",
            choices=TAXLEVELS,
            help="Taxonomic level to summarize results for [%(default)s].")
    parser.add_argument("-t", "--topN", metavar="N",
            default=50,
            type=int,
            help="Only plot the top N taxa [%(default)s].")
    parser.add_argument("-p", "--pseudocount", metavar="P",
            default=-1,
            type=float,
            help="Use custom pseudocount, a negative value means to "
                 "autocompute a pseudocount as the median of the 0.01th "
                 "quantile across all samples [%(default)s].")
    parser.add_argument("-c", "--colormap",
            default="viridis",
            help="Matplotlib colormap to use [%(default)s].")
    parser.add_argument("-M", "--method",
            default="average",
            help="Linkage method to use, "
                 "see scipy.cluster.hierarchy.linkage docs [%(default)s].")
    parser.add_argument("-m", "--metric",
            default="euclidean",
            help="Distance metric to use, "
                 "see scipy.spatial.distance.pdist docs [%(default)s].")
    parser.add_argument("-L", "--loglevel", choices=["INFO", "DEBUG"],
            default="INFO",
            help="Set logging level [%(default)s].")

    if len(argv) < 2:
        parser.print_help()
        exit()

    return parser.parse_args()


def parse_mpa_table(mpa_tsv):
    """Read joined MetaPhlAn tables into a Pandas DataFrame.

    * Convert ranks from first column into hierarchical MultiIndex
    """
    with open(mpa_tsv) as f:
        for lineno, line in enumerate(f):
            if line.startswith("#"):
                continue
            elif line.startswith("clade_name"):
                skiprows = lineno
                dropcols = ["clade_name"]
                break
            elif not line.startswith("#"):
                logger.error(f"Don't know how to process table")
                exit(3)

    df = pd.read_csv(mpa_tsv, sep="\t", skiprows=skiprows)

    logger.debug(df.head())

    lineages = df[dropcols[0]].str.split("|", expand=True)
    levels_present = TAXLEVELS[:len(lineages.columns)]  # Some tables don't have strain or species assignments
    df[levels_present] = lineages\
            .rename(columns={key: level for key, level in zip(range(len(levels_present)), levels_present)})
    mpa_table = df.drop(columns=dropcols).set_index(levels_present)

    logger.debug(f"Parsed data dimensions: {mpa_table.shape}")
    logger.debug(mpa_table.sample(10))

    return mpa_table


def extract_specific_level(mpa_table, level):
    """Extract abundances for a specific taxonomic level."""

    level_pos = mpa_table.index.names.index(level)

    if level_pos+1 == len(mpa_table.index.names):
        level_only = ~mpa_table.index.get_level_values(level).isnull()
        mpa_level = mpa_table.loc[level_only]
    else:
        level_assigned = ~mpa_table.index.get_level_values(level).isnull()
        next_level_assigned = ~mpa_table.index.get_level_values(mpa_table.index.names[level_pos+1]).isnull()
        level_only = level_assigned & ~next_level_assigned  # AND NOT 
        mpa_level = mpa_table.loc[level_only]

    ranks = mpa_table.index.names.copy()
    ranks.remove(level)
    mpa_level.index = mpa_level.index.droplevel(ranks)
    logger.debug(f"Table dimensions after extracting {level}-level only: {mpa_level.shape}")
    return mpa_level


def plot_clustermap(mpa_table, topN, pseudocount, colormap, method, metric):
    """Plot Seaborn clustermap."""

    top_taxa = mpa_table.median(axis=1).nlargest(topN)
    mpa_topN = mpa_table.loc[mpa_table.index.isin(top_taxa.index)]
    logger.debug(f"Table dimensions after extracting top {topN} taxa: {mpa_topN.shape}")

    if pseudocount < 0:
        pseudocount = mpa_topN.quantile(0.05).median() / 10
        if pseudocount < 1e-10:
            logger.warning(f"Automatically generated pseudocount is very low: {pseudocount}! "
                            "Setting pseudocount to 1e-10.")
            pseudocount = 1e-10
        logger.debug(f"Automatically generated pseudocount is: {pseudocount}")

    figwidth = mpa_topN.shape[1]
    figheight = 10+topN/5

    sns.set("notebook")
    clustergrid = sns.clustermap(
            mpa_topN.apply(lambda x: np.log10(x+pseudocount)),
            figsize=(figwidth, figheight),
            method=method,
            metric=metric,
            cmap=colormap,
            cbar_kws={"label": "$log_{10}$(abundance)"},
    )
    return clustergrid


def main(mpa_table, outfile_prefix, overwrite, level, topN, pseudocount, colormap, method, metric):
    mpa_table = parse_mpa_table(mpa_table)
    mpa_level = extract_specific_level(mpa_table, level)
    clustermap = plot_clustermap(mpa_level, topN, pseudocount, colormap, method, metric)

    outfile_png = Path(f"{outfile_prefix}.{level}_top{topN}.png")
    outfile_pdf = Path(f"{outfile_prefix}.{level}_top{topN}.pdf")
    if (outfile_png.exists() or outfile_pdf.exists()) and not overwrite:
        logger.error(f"Output file {outfile_png} or {outfile_pdf} already exists and --force is not set.")
        exit(2)

    clustermap.ax_heatmap.set_xticklabels(clustermap.ax_heatmap.get_xticklabels(), rotation=90)

    clustermap.savefig(outfile_png)
    clustermap.savefig(outfile_pdf)


if __name__ == "__main__":
    args = parse_args()
    logger = logging.getLogger(__name__)
    loglevels = {"INFO": logging.INFO, "DEBUG": logging.DEBUG}
    logging.basicConfig(format='%(levelname)s: %(message)s', level=loglevels[args.loglevel])

    main(
        args.mpa_table, 
        args.outfile_prefix, 
        args.force,
        args.level,
        args.topN,
        args.pseudocount,
        args.colormap,
        args.method,
        args.metric,
    )
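
A hypothetical invocation of this heatmap script (file names invented for illustration), which would write mpa_heatmap.Genus_top30.png and mpa_heatmap.Genus_top30.pdf:

python plot_mpa_heatmap.py joined_metaphlan_tables.tsv -l Genus -t 30 -o mpa_heatmap
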
__author__ = "CTMR, Fredrik Boulund"
__date__ = "2020"
__version__ = "0.1"

from sys import argv, exit
from pathlib import Path
import argparse

import matplotlib as mpl
mpl.use("agg")
mpl.rcParams.update({'figure.autolayout': True})
import matplotlib.pyplot as plt

import pandas as pd

def parse_args():
    desc = f"{__doc__} Copyright (c) {__author__} {__date__}. Version v{__version__}"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("log_output", metavar="LOG", nargs="+",
            help="Kraken2 log output (txt).")
    parser.add_argument("-H", "--histogram", dest="histogram", metavar="FILE",
            default="histogram.pdf",
            help="Filename of output histogram plot [%(default)s].")
    parser.add_argument("-b", "--barplot", dest="barplot", metavar="FILE",
            default="barplot.pdf",
            help="Filename of output barplot [%(default)s].")
    parser.add_argument("-t", "--table", dest="table", metavar="FILE",
            default="proportions.tsv",
            help="Filename of histogram data in TSV format [%(default)s].")
    parser.add_argument("-u", "--unclassified", dest="unclassified", action="store_true",
            default=False,
            help="Plot proportion unclassified reads instead of classified reads [%(default)s].")

    if len(argv) < 2:
        parser.print_help()
        exit(1)

    return parser.parse_args()


def parse_kraken2_logs(logfiles, unclassified):
    search_string = "unclassified" if unclassified else " classified"
    for logfile in logfiles:
        with open(logfile) as f:
            sample_name = Path(logfile).stem.split(".")[0]
            for line in f:
                if search_string in line:
                    yield sample_name, float(line.split("(")[1].split(")")[0].strip("%"))


if __name__ == "__main__":
    options = parse_args()

    proportions = list(parse_kraken2_logs(options.log_output, options.unclassified))
    action = "unclassified" if options.unclassified else "classified"

    df = pd.DataFrame(proportions, columns=["Sample", "Proportion"]).set_index("Sample").rename(columns={"Proportion": f"% {action}"})
    print("Loaded {} proportions for {} samples.".format(df.shape[0], len(df.index.unique())))

    fig, ax = plt.subplots(figsize=(7, 5))
    df.plot(kind="hist", ax=ax, legend=None)
    ax.set_title(f"Proportion {action} reads")
    ax.set_xlabel(f"Proportion {action} reads")
    ax.set_ylabel("Frequency")
    fig.savefig(options.histogram, bbox_inches="tight")

    length_longest_sample_name = max([s for s in df.index.str.len()])
    fig2_width = max(5, length_longest_sample_name * 0.4)
    fig2_height = max(3, df.shape[0] * 0.25)

    fig2, ax2 = plt.subplots(figsize=(fig2_width, fig2_height))
    df.plot(kind="barh", ax=ax2, legend=None)
    ax2.set_title(f"Proportion {action} reads")
    ax2.set_xlabel(f"Proportion {action} reads")
    ax2.set_ylabel("Sample")
    fig2.savefig(options.barplot, bbox_inches="tight")

    df.to_csv(options.table, sep="\t")
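
A hypothetical invocation of this proportions plotter (script and file names invented for illustration); add -u to plot the proportion of unclassified rather than classified reads:

python plot_proportion_kraken2.py sampleA.kraken2.log sampleB.kraken2.log -H histogram.pdf -b barplot.pdf -t proportions.tsv
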
__author__ = "Fredrik Boulund"
__date__ = "2018"
__version__ = "0.2.0"

from sys import argv, exit
from pathlib import Path
import argparse

import matplotlib as mpl
mpl.use("agg")
mpl.rcParams.update({'figure.autolayout': True})

import pandas as pd
import seaborn as sns


def parse_args():
    desc = f"{__doc__} Version {__version__}. Copyright (c) {__author__} {__date__}."
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("alltoall", metavar="alltoall",
            help="Output table from BBMap's comparesketch.sh in format=3.")
    parser.add_argument("-o", "--outfile", dest="outfile", metavar="FILE",
            default="all_vs_all.pdf",
            help="Filename of heatmap plot [%(default)s].")
    parser.add_argument("-c", "--clustered", dest="clustered", metavar="FILE",
            default="all_vs_all.clustered.pdf",
            help="Filename of clustered heatmap plot [%(default)s].")
    if len(argv) < 2:
        parser.print_help()
        exit(1)

    return parser.parse_args()


if __name__ == "__main__":
    options = parse_args()

    df = pd.read_table(
            options.alltoall, 
            index_col=False)
    print("Loaded data for {} sample comparisons.".format(df.shape[0]))

    similarity_matrix = df.pivot(index="#Query", 
            columns="Ref", values="ANI").fillna(100)

    corr = similarity_matrix.corr().fillna(0)
    g = sns.heatmap(corr, annot=True, fmt="2.1f", annot_kws={"fontsize": 2})
    g.set_title("Sample similarity")
    #g.set_xticklabels(g.get_xticklabels(), fontsize=4)  #WIP
    #g.set_yticklabels(g.get_yticklabels(), rotation=0, fontsize=4) #WIP
    g.set_ylabel("")
    g.set_xlabel("")
    g.figure.savefig(str(Path(options.outfile)))

    g = sns.clustermap(corr, annot=True, fmt="2.1f", annot_kws={"fontsize": 2})
    g.fig.suptitle("Sample similarity (clustered)")
    g.savefig(str(Path(options.clustered)))
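
A hypothetical invocation of this similarity-heatmap script (script and file names invented for illustration), assuming the input is the format=3 output table from BBMap's comparesketch.sh:

python plot_sketch_comparison_heatmap.py comparesketch_alltoall.txt -o all_vs_all.pdf -c all_vs_all.clustered.pdf
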
__author__ = "CTMR, Fredrik Boulund"
__date__ = "2021-2023"
__version__ = "0.3"

from sys import argv, exit
from pathlib import Path
import json
import argparse

import matplotlib as mpl
mpl.use("agg")
mpl.rcParams.update({'figure.autolayout': True})
import matplotlib.pyplot as plt

import pandas as pd

def parse_args():
    desc = f"{__doc__} Copyright (c) {__author__} {__date__}. Version v{__version__}"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("--fastp", metavar="sample.json", nargs="+",
            help="fastp JSON output file.")
    parser.add_argument("--kraken2", metavar="sample.kraken2.log", nargs="+",
            help="Kraken2 log output.")
    parser.add_argument("--bowtie2", metavar="sample.samtools.fastq.log", nargs="+",
            help="Bowtie2 samtools fastq log output.")
    parser.add_argument("-o", "--output-table", metavar="TSV",
            default="read_processing_summary.txt",
            help="Filename of output table in tsv format [%(default)s].")
    parser.add_argument("-p", "--output-plot", metavar="PDF",
            default="",
            help="Filename of output table in PDF format [%(default)s].")

    if len(argv) < 2:
        parser.print_help()
        exit(1)

    return parser.parse_args()


def parse_bowtie2_samtools_fastq_logs(logfiles):
    for logfile in logfiles:
        with open(logfile) as f:
            sample_name = Path(logfile).stem.split(".")[0]
            for line in f:
                if not line.startswith("[M::bam2fq_mainloop]"):
                    raise ValueError
                if "bam2fq_mainloop] processed" in line:
                    yield {
                        "Sample": sample_name,
                        "after_bowtie2_host_removal": int(int(line.split()[2])/2),  # /2 because bowtie2 counts both pairs
                    }


def parse_kraken2_logs(logfiles):
    for logfile in logfiles:
        with open(logfile) as f:
            sample_name = Path(logfile).stem.split(".")[0]
            for line in f:
                if " unclassified" in line:
                    yield {
                        "Sample": sample_name,
                        "after_kraken2_host_removal": int(line.strip().split()[0]),
                    }


def parse_fastp_logs(logfiles):
    for logfile in logfiles:
        sample_name = Path(logfile).stem.split(".")[0]
        with open(logfile) as f:
            fastp_data = json.load(f)
            yield {
                "Sample": sample_name, 
                "before_fastp": int(fastp_data["summary"]["before_filtering"]["total_reads"]/2),  # /2 because fastp counts both pairs
                "after_fastp": int(fastp_data["summary"]["after_filtering"]["total_reads"]/2), 
                "duplication": float(fastp_data["duplication"]["rate"]),
            }


if __name__ == "__main__":
    args = parse_args()

    dfs = {
        "fastp": pd.DataFrame(),
        "kraken2": pd.DataFrame(),
        "bowtie2": pd.DataFrame(),
    }

    if args.fastp:
        data_fastp = list(parse_fastp_logs(args.fastp))
        dfs["fastp"] = pd.DataFrame(data_fastp).set_index("Sample")
    if args.kraken2:
        data_kraken2 = list(parse_kraken2_logs(args.kraken2))
        dfs["kraken2"] = pd.DataFrame(data_kraken2).set_index("Sample")
    if args.bowtie2:
        data_bowtie2 = list(parse_bowtie2_samtools_fastq_logs(args.bowtie2))
        dfs["bowtie2"] = pd.DataFrame(data_bowtie2).set_index("Sample")

    df = pd.concat(dfs.values(), axis="columns")

    column_order = [
        "duplication",
        "before_fastp",
        "after_fastp",
        "after_kraken2_host_removal",
        "after_bowtie2_host_removal",
    ]
    final_columns = [c for c in column_order if c in df.columns]
    df = df[final_columns]

    df.to_csv(args.output_table, sep="\t")

    if args.output_plot:
        fig, ax = plt.subplots(figsize=(6, 5))
        df[final_columns[1:]]\
            .transpose()\
            .plot(kind="line", style=".-", ax=ax)
        ax.set_title("Reads passing through QC and host removal")
        ax.set_xlabel("Stage")
        ax.set_ylabel("Reads")
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc="upper left", bbox_to_anchor=(0, -0.1))
        fig.savefig(args.output_plot, bbox_inches="tight")
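
A hypothetical invocation of this summary script (script and file names invented for illustration), combining logs from all three preprocessing steps into one table and an optional line plot:

python preprocessing_summary.py --fastp sampleA.fastp.json --kraken2 sampleA.kraken2.log --bowtie2 sampleA.samtools.fastq.log -o read_processing_summary.txt -p read_processing_summary.pdf
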
__author__ = "Johannes Köster"
__copyright__ = "Copyright 2016, Johannes Köster"
__email__ = "[email protected]"
__license__ = "MIT"


from snakemake.shell import shell

extra = snakemake.params.get("extra", "")
log = snakemake.log_fmt_shell(stdout=True, stderr=True)

n = len(snakemake.input.sample)
assert n == 1 or n == 2, "input->sample must have 1 (single-end) or 2 (paired-end) elements."

if n == 1:
    reads = "-U {}".format(*snakemake.input.sample)
else:
    reads = "-1 {} -2 {}".format(*snakemake.input.sample)

shell(
    "(bowtie2 --threads {snakemake.threads} {snakemake.params.extra} "
    "-x {snakemake.params.index} {reads} "
    "| samtools view -Sbh -o {snakemake.output[0]} -) {log}")