Integrated workflow for SV calling from single-cell Strand-seq data

Version: 2.2.1

MosaiCatcher

A Snakemake pipeline for structural variant calling from single-cell Strand-seq data.

Overview of this workflow

This workflow uses Snakemake to execute all steps of MosaiCatcher in order. The starting point is a set of single-cell BAM files from Strand-seq experiments, and the final output consists of SV predictions in tabular form as well as graphical representations. To get there, the workflow goes through the following steps (an example invocation is shown after the list):

  1. Binning of sequencing reads into 200 kb genomic windows via mosaic

  2. Strand state detection

  3. [Optional] Normalization of coverage with respect to a reference sample

  4. Multivariate segmentation of cells (mosaic)

  5. Haplotype resolution via StrandPhaseR

  6. Bayesian classification of segmentation to find SVs using MosaiClassifier

  7. Visualization of results using custom R plots
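
As a quick orientation, the sketch below shows one way to launch the workflow with Snakemake. It is a minimal, illustrative example only: the core count and input path are placeholders, and data_location is the input-folder parameter referred to in the roadmap below; see the documentation for the authoritative options.

    # Minimal sketch (illustrative values, not defaults):
    # run from the cloned mosaicatcher-pipeline directory, pointing the
    # workflow at a folder of single-cell Strand-seq BAM files
    snakemake \
        --cores 24 \
        --use-conda \
        --config data_location=/path/to/strandseq_sample_folder \
        --rerun-incomplete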

Summary: MosaiCatcher Snakemake pipeline and the companion ashleys-qc-pipeline.

📘 Documentation

📆 Roadmap

Technical features

  • [x] Zenodo automatic download of external files + indexes ( 1.2.1 )

  • [x] Multiple samples in the parent folder ( 1.2.2 )

  • [x] Automatic testing of BAM SM tag compared to sample folder name ( 1.2.3 )

  • [x] On-error/success e-mail ( 1.3 )

  • [x] HPC execution (slurm profile for the moment) ( 1.3 )

  • [x] Full singularity image with preinstalled conda envs ( 1.5.1 )

  • [x] Single BAM folder with side config file ( 1.6.1 )

  • [x] (EMBL) GeneCore mode of execution: allow selection and execution directly by specifying genecore run folder (2022-11-02-H372MAFX5 for instance) ( 1.8.2 )

  • [x] Version synchronisation between ashleys-qc-pipeline and mosaicatcher-pipeline ( 1.8.3 )

  • [x] Report captions update ( 1.8.5 )

  • [x] Clustering plot (heatmap) & SV calls plot update ( 1.8.6 )

  • [ ] Plotting options (enable/disable segmentation back colors)

Bioinformatics features

  • [x] Automatic handling of low-coverage cells ( 1.6.1 )

  • [x] FASTQ input handling via the upstream ashleys-qc-pipeline ( 1.6.1 )

  • [x] Support for changing the reference genome (previously GRCh38 only) ( 1.7.0 )

  • [x] Ploidy detection at the segment and the chromosome level: used to bypass StrandPhaseR if more than half of a chromosome is haploid ( 1.7.0 )

  • [x] input_bam_legacy mode (bam/selected folders) ( 1.8.4 )

  • [x] Blacklist regions files for T2T & hg19 ( 1.8.5 )

  • [x] ArbiGent integration: Strand-seq-based genotyper to study SVs containing at least 500 bp of uniquely mappable sequence ( 1.9.0 )

  • [x] scNOVA integration: Strand-Seq Single-Cell Nucleosome Occupancy and genetic Variation Analysis ( 1.9.2 )

  • [ ] Pooled samples

Small issues to fix

  • [ ] Move pysam / SM tag comparison script to snakemake rule

  • [x] Replace input_bam_location by data_location (harmonization with ashleys-qc-pipeline)

  • [x] List of commands available through the list_commands parameter ( 1.8.6 )

🛑 Troubleshooting & Current limitations

  • Do not change the structure of your input folder after running the pipeline: the first execution builds a configuration dataframe file ( OUTPUT_DIRECTORY/config/config.tsv ) that lists the cells and their associated paths

  • Do not change the list of chromosomes between executions (e.g., a first execution on chr17 followed by a second execution on all chromosomes)
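
If the folder layout does have to change, a hedged workaround (an assumption based on the behaviour described above, not a documented feature) is to remove the cached configuration file, or to start from a fresh output directory, so that the next execution rebuilds it from the current layout:

    # Illustrative workaround (assumption, not an official feature):
    # delete the cached cell/path table so the next run regenerates it
    rm OUTPUT_DIRECTORY/config/config.tsv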

💂‍♂️ Authors (alphabetical order)

  • Ashraf Hufsah

  • Cosenza Marco

  • Ebert Peter

  • Ghareghani Maryam

  • Grimes Karen

  • Gros Christina

  • Höps Wolfram

  • Jeong Hyobin

  • Kinanen Venla

  • Korbel Jan

  • Marschall Tobias

  • Meiers Sasha

  • Porubsky David

  • Rausch Tobias

  • Sanders Ashley

  • Van Vliet Alex

  • Weber Thomas (maintainer and current developer)

📕 References

Strand-seq publication: Falconer, E., Hills, M., Naumann, U. et al. DNA template strand sequencing of single-cells maps genomic rearrangements at high resolution. Nat Methods 9, 1107–1112 (2012). https://doi.org/10.1038/nmeth.2206

scTRIP/MosaiCatcher original publication: Sanders, A.D., Meiers, S., Ghareghani, M. et al. Single-cell analysis of structural variations and complex rearrangements with tri-channel processing. Nat Biotechnol 38, 343–354 (2020). https://doi.org/10.1038/s41587-019-0366-x

ArbiGent publication: Porubsky, David, Wolfram Höps, Hufsah Ashraf, PingHsun Hsieh, Bernardo Rodriguez-Martin, Feyza Yilmaz, Jana Ebler, et al. 2022. “Recurrent Inversion Polymorphisms in Humans Associate with Genetic Instability and Genomic Disorders.” Cell 185 (11): 1986-2005.e26. https://doi.org/10.1016/j.cell.2022.04.017.

scNOVA publication: Jeong, Hyobin, Karen Grimes, Kerstin K. Rauwolf, Peter-Martin Bruch, Tobias Rausch, Patrick Hasenfeld, Eva Benito, et al. 2022. “Functional Analysis of Structural Variants in Single Cells Using Strand-Seq.” Nature Biotechnology, November, 1–13. https://doi.org/10.1038/s41587-022-01551-4.

Code Snippets

17
18
script:
    "../scripts/arbigent_utils/create_hdf.py"
31
32
shell:
    "grep -E -- '{params.chromosomes}' {input} > {output}"
56
57
script:
    "../scripts/arbigent_utils/watson_crick.py"
68
69
shell:
    "sed 's/.sort.mdup//g' {input} > {output}"
80
81
82
83
shell:
    """
    awk '!seen[$1,$2,$3]++' {input.counts_file} > {output.msc}
    """
98
99
script:
    "../scripts/arbigent_utils/mosaiclassifier_scripts/mosaiClassifier.snakemake.R"
24
25
26
27
28
29
30
shell:
    """
    Rscript workflow/scripts/arbigent/regenotype.R \
                    -f {input.probabilities_table} \
                    -c {input.msc} \
                    -o {output.sv_calls_bulk_dir}/ > {log} 2>&1
    """
48
49
50
51
shell:
    """
    awk 'FNR==1 && NR!=1 {{ while (/^chrom/) getline; }} 1 {{print}}' {input.sv_calls_bulk} > {output}
    """
67
68
69
70
shell:
    """
    cp {input.all_txt} {output.all_txt_rephased}
    """
 94
 95
 96
 97
 98
 99
100
shell:
    """
    Rscript workflow/scripts/arbigent/table_to_vcfs.R \
                            -a {input.alltxt} \
                            -m {input.msc} \
                            -o {output.res_csv_dir}/ > {log} 2>&1
    """
118
119
120
121
122
123
124
shell:
    """
    Rscript workflow/scripts/arbigent/add_filter.R \
                        -i {input.res_csv} \
                        -n {params.names_gm_to_na} \
                        -o {output.verdicted_table} > {log} 2>&1
    """
140
141
142
143
144
145
shell:
    """
    Rscript workflow/scripts/arbigent/qc_res_verdicted.R \
                        -f {input.verdicted_table} \
                        -o {output.verdict_plot_dir}/ > {log} 2>&1
    """
19
20
script:
    "../scripts/utils/generate_exclude_file.py"
54
55
56
57
58
59
60
61
62
63
64
65
shell:
    """
    mosaicatcher count \
        --verbose \
        --do-not-blacklist-hmm \
        -o {output.counts} \
        -i {output.info} \
        -x {input.excl} \
        -w {params.window} \
        {input.bam} \
    > {log} 2>&1
    """
79
80
script:
    "../scripts/utils/populated_counts_for_qc_plot.py"
93
94
script:
    "../scripts/utils/handle_input_old_behavior.py"
105
106
shell:
    "echo 'cell\tprobability\tprediction' > {output}"
SnakeMake From line 105 of rules/count.smk
122
123
shell:
    "cp {input} {output}"
SnakeMake From line 122 of rules/count.smk
137
138
script:
    "../scripts/utils/symlink_selected_bam.py"
SnakeMake From line 137 of rules/count.smk
151
152
153
154
shell:
    """
    rm {input.bam} {input.bai}
    """
SnakeMake From line 151 of rules/count.smk
179
180
script:
    "../scripts/utils/filter_bad_cells.py"
SnakeMake From line 179 of rules/count.smk
201
202
203
204
shell:
    """
    workflow/scripts/normalization/merge-blacklist.py --merge_distance 500000 {input.norm} --whitelist {input.whitelist} --min_whitelist_interval_size {params.window} > {output.merged} 2>> {log}
    """
SnakeMake From line 201 of rules/count.smk
219
220
221
222
shell:
    """
    workflow/scripts/normalization/merge-blacklist.py --merge_distance 500000 {input.norm} > {output.merged} 2> {log}
    """
SnakeMake From line 219 of rules/count.smk
247
248
249
250
shell:
    """
    Rscript workflow/scripts/normalization/normalize.R {input.counts} {input.norm} {output} {params.normalisation_type} 2>&1 > {log}
    """
SnakeMake From line 247 of rules/count.smk
263
264
shell:
    "cp {input} {output}"
SnakeMake From line 263 of rules/count.smk
278
279
script:
    "../scripts/utils/sort_counts.py"
SnakeMake From line 278 of rules/count.smk
292
293
294
295
shell:
    """
    zcat {input.counts} | awk -v name={wildcards.cell} '(NR==1) || $5 == name' | gzip > {output}
    """
SnakeMake From line 292 of rules/count.smk
17
18
run:
    shell("unzip {input} -d .")
33
34
35
36
37
38
39
shell:
    """
    directory="workflow/data/ref_genomes/"
    mkdir -p "$directory"
    mv {input} workflow/data/ref_genomes/hg19.fa.gz
    gunzip workflow/data/ref_genomes/hg19.fa.gz
    """
54
55
56
57
58
59
60
shell:
    """
    directory="workflow/data/ref_genomes/"
    mkdir -p "$directory"
    mv {input} workflow/data/ref_genomes/hg38.fa.gz
    gunzip workflow/data/ref_genomes/hg38.fa.gz
    """
75
76
77
78
79
80
81
shell:
    """
    directory="workflow/data/ref_genomes/"
    mkdir -p "$directory"
    mv {input} workflow/data/ref_genomes/T2T.fa.gz
    gunzip workflow/data/ref_genomes/T2T.fa.gz
    """
 96
 97
 98
 99
100
101
102
shell:
    """
    directory="workflow/data/ref_genomes/"
    mkdir -p "$directory"
    mv {input} workflow/data/ref_genomes/mm10.fa.gz
    gunzip workflow/data/ref_genomes/mm10.fa.gz
    """
117
118
119
120
121
122
shell:
    """
    directory="workflow/data/ref_genomes/"
    mkdir -p "$directory"
    mv {input} workflow/data/ref_genomes/BSgenome.T2T.CHM13.V2_1.0.0.tar.gz
    """
135
136
137
138
139
140
shell:
    """
    directory="workflow/data/arbigent/"
    mkdir -p "$directory"
    mv {input} {output}
    """
176
177
178
179
180
181
182
shell:
    """
    directory="workflow/data/ref_genomes/"
    mkdir -p "$directory"
    mv {input} workflow/data/scNOVA_data_models.zip
    unzip workflow/data/scNOVA_data_models.zip -d workflow/data/
    """
21
22
script:
    "../scripts/GC/library_size_normalisation.R"
SnakeMake From line 21 of rules/gc.smk
53
54
script:
    "../scripts/GC/GC_correction.R"
SnakeMake From line 53 of rules/gc.smk
76
77
script:
    "../scripts/GC/variance_stabilizing_transformation.R"
SnakeMake From line 76 of rules/gc.smk
91
92
script:
    "../scripts/utils/populated_counts_for_qc_plot.py"
SnakeMake From line 91 of rules/gc.smk
103
104
script:
    "../scripts/utils/reformat_ms_norm.py"
SnakeMake From line 103 of rules/gc.smk
123
124
125
126
shell:
    """
    LC_CTYPE=C Rscript workflow/scripts/plotting/qc.R {input.counts} {input.info} {output} > {log} 2>&1
    """
SnakeMake From line 123 of rules/gc.smk
35
36
shell:
    "whatshap haplotag --skip-missing-contigs -o {output} -r {input.fasta} {input.vcf} {input.bam} > {log} 2>{log}  "
52
53
54
55
56
shell:
    """
    # Issue #1022 (https://bitbucket.org/snakemake/snakemake/issues/1022)
    awk -v s={params.window} -f workflow/scripts/haplotagging_scripts/create_haplotag_segment_bed.awk {input.segments} > {output.bed}
    """
73
74
script:
    "../scripts/haplotagging_scripts/haplotagTable.snakemake.R"
94
95
shell:
    "(head -n1 {input.tsvs[0]} && tail -q -n +2 {input.tsvs}) > {output.tsv}"
17
18
script:
    "../scripts/mosaiclassifier_scripts/mosaiClassifier.snakemake.R"
33
34
script:
    "../scripts/mosaiclassifier_scripts/haplotagProbs.snakemake.R"
57
58
script:
    "../scripts/mosaiclassifier_scripts/mosaiClassifier_call.snakemake.R"
72
73
script:
    "../scripts/mosaiclassifier_scripts/mosaiClassifier_call_biallelic.snakemake.R"
87
88
89
90
91
92
93
94
shell:
    """
    PYTHONPATH="" # Issue #1031 (https://bitbucket.org/snakemake/snakemake/issues/1031)
    workflow/scripts/mosaiclassifier_scripts/call-complex-regions.py \
    --merge_distance 5000000 \
    --ignore_haplotypes \
    --min_cell_count 2 {input.calls} > {output.complex_regions} 2>{log}
    """
20
21
22
23
24
25
26
27
28
29
30
31
shell:
    """
    python workflow/scripts/ploidy/ploidy_estimator.py --debug \
        --merge-bins-to {params.merge_window} \
        --shift-window-by {params.shift_step} \
        --max-ploidy {params.max_ploidy} \
        --boundary-alpha {params.boundary_alpha} \
        --jobs {threads} \
        --input {input.counts} \
        --output {output} \
        --log {log}
    """
43
44
script:
    "../scripts/ploidy/summarise_ploidy.py"
65
66
script:
    "../scripts/ploidy/ploidy_bcftools.py"
30
31
32
33
shell:
    """
    LC_CTYPE=C Rscript workflow/scripts/plotting/qc.R {input.counts} {input.info} {output} > {log} 2>&1
    """
53
54
script:
    "../scripts/plotting/dividing_pdf.py"
87
88
shell:
    "touch {output}"
123
124
script:
    "../scripts/plotting/sv_consistency_barplot.snakemake.R"
SnakeMake From line 123 of rules/plots.smk
148
149
script:
    "../scripts/plotting/plot-clustering.snakemake.R"
SnakeMake From line 148 of rules/plots.smk
173
174
script:
    "../scripts/plotting/plot_clustering_dev_clean.R"
SnakeMake From line 173 of rules/plots.smk
199
200
script:
    "../scripts/plotting/plot_clustering_scale_clean.py"
SnakeMake From line 199 of rules/plots.smk
230
231
232
233
234
235
236
237
238
239
240
241
242
shell:
    """
    Rscript workflow/scripts/plotting/plot-sv-calls.R \
        segments={input.segments} \
        singlecellsegments={input.scsegments} \
        strand={input.strand} \
        complex={input.complex_calls} \
        groups={input.grouptrack} \
        calls={input.calls} \
        {input.counts} \
        {wildcards.chrom} \
        {output} > {log} 2>&1
    """
SnakeMake From line 230 of rules/plots.smk
272
273
274
275
276
277
278
279
280
281
282
283
284
shell:
    """
    Rscript workflow/scripts/plotting/plot-sv-calls_dev.R \
        segments={input.segments} \
        singlecellsegments={input.scsegments} \
        strand={input.strand} \
        complex={input.complex_calls} \
        groups={input.grouptrack} \
        calls={input.calls} \
        {input.counts} \
        {wildcards.chrom} \
        {output} > {log} 2>&1
    """
SnakeMake From line 272 of rules/plots.smk
302
303
script:
    "../scripts/plotting/ploidy_plot.py"
SnakeMake From line 302 of rules/plots.smk
319
320
shell:
    "python workflow/scripts/plotting/ucsc_vizu.py {input.counts} {input.stringent_calls} {input.lenient_calls} {output} > {log}"
SnakeMake From line 319 of rules/plots.smk
16
17
18
19
20
21
shell:
    """
    export LC_CTYPE=en_US.UTF-8 
    export LC_ALL=en_US.UTF-8 
    workflow/scripts/postprocessing/filter_MosaiCatcher_calls.pl {input.calls} {params.segdups} > {output.calls}
    """
35
36
37
38
39
40
shell:
    """
    export LC_CTYPE=en_US.UTF-8 
    export LC_ALL=en_US.UTF-8 
    workflow/scripts/postprocessing/group_nearby_calls_of_same_AF_and_generate_output_table.pl {input.calls}  > {output.calls}
    """
54
55
56
57
58
shell:
    """
    PYTHONPATH="" # Issue #1031 (https://bitbucket.org/snakemake/snakemake/issues/1031)
    workflow/scripts/postprocessing/create-sv-group-track.py {input.calls}  > {output.grouptrack}
    """
73
74
75
76
77
shell:
    """
    PYTHONPATH="" # Issue #1031 (https://bitbucket.org/snakemake/snakemake/issues/1031)
    workflow/scripts/postprocessing/apply_filter.py {input.inputcalls} {input.mergedcalls} > {output.calls}
    """
19
20
shell:
    "samtools merge -@ {threads} {output} {input.bam} 2>&1 > {log}"
36
37
shell:
    "samtools sort -@ {threads} -o {output} {input} 2>&1 > {log}"
51
52
shell:
    "samtools index {input} > {log} 2>&1"
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
shell:
    """
    (freebayes \
        -f {input.fasta} \
        -r {wildcards.chrom} \
        -@ {input.sites} \
        --only-use-input-alleles {input.bam} \
        --genotype-qualities \
    | bcftools view \
        --exclude-uncalled \
        --types snps \
        --genotype het \
        --include "QUAL>=10" \
    > {output.vcf}) 2> {log}
    """
108
109
110
111
112
shell:
    """
    bcftools mpileup -r {wildcards.chrom} -f {input.fasta} {input.bam} \
    | bcftools call -mv --ploidy-file {input.ploidy} | bcftools view --genotype het --types snps > {output} 2> {log}
    """
10
11
script:
    "../scripts/scNOVA_scripts/filter_sv_calls.py"
25
26
shell:
    "touch {output}"
50
51
52
53
shell:
    """
    Rscript {params.generate_CN_for_CNN} {input.subclone} {input.sv_calls_all} {input.Deeptool_result_final} {input.CNN_features_annot} {output.sv_calls_all_print} > {log}
    """
77
78
79
80
shell:
    """
    Rscript {params.generate_CN_for_chromVAR} {input.TSS_matrix} {input.TES_matrix} {input.Genebody_matrix} {input.DHS_matrix_resize} {input.subclone} {input.sv_calls_all} {output.sv_calls_all_print}  > {log}
    """
 97
 98
 99
100
101
shell:
    """
    samtools view -H {input} > {output.bam_header} 
    samtools view -F 2304 {input.bam} | awk -f workflow/scripts/scNOVA_scripts/awk_1st.awk | cat {output.bam_header} - | samtools view -Sb - > {output.bam_pre}    
    """
118
119
120
121
shell:
    """
    samtools sort -@ {threads} -O BAM -o {output} {input}
    """
137
138
139
140
shell:
    """
    samtools index {input}
    """
157
158
159
160
shell:
    """
    bammarkduplicates markthreads=2 I={input.bam} O={output.bam_uniq} M={output.bam_metrix} index=1 rmdup=1
    """
SnakeMake From line 157 of rules/scNOVA.smk
176
177
178
179
shell:
    """
    samtools index {input}
    """
198
199
200
201
shell:
    """
    bedtools multicov -bams {input.bam}  -bed workflow/data/scNOVA/utils/bin_Genebody_all.bed > {output.tab}
    """
223
224
script:
    "../scripts/scNOVA_scripts/dev_aggr.py"
SnakeMake From line 223 of rules/scNOVA.smk
238
239
240
241
shell:
    """
    sort -k1,1 -k2,2n -k3,3n -t$'\t' {input} > {output}
    """
SnakeMake From line 238 of rules/scNOVA.smk
262
263
264
265
shell:
    """
    Rscript {params.count_sort_annotate_geneid} {input.count_table} {input.GB_matrix} {output}  
    """
SnakeMake From line 262 of rules/scNOVA.smk
279
280
script:
    "../scripts/scNOVA_scripts/filter_input_subclonality.py"
SnakeMake From line 279 of rules/scNOVA.smk
311
312
313
314
shell:
    """
    perl workflow/scripts/scNOVA_scripts/merge_bam_clones.pl {input.input_subclonality} {output.subclonality_colnames} {output.line}
    """
SnakeMake From line 311 of rules/scNOVA.smk
333
334
335
336
shell:
    """
    bedtools multicov -bams {input.bam}  -bed workflow/data/scNOVA/utils/bin_Genes_for_CNN_sort.txt.corrected > {output.tab}
    """
358
359
script:
    "../scripts/scNOVA_scripts/dev_aggr.py"
SnakeMake From line 358 of rules/scNOVA.smk
378
379
380
381
shell:
    """
    bedtools multicov -bams {input.bam}  -bed workflow/data/scNOVA/utils/bin_Genes_for_CNN_sort.txt.corrected > {output.tab}
    """
403
404
script:
    "../scripts/scNOVA_scripts/dev_aggr.py"
SnakeMake From line 403 of rules/scNOVA.smk
425
426
427
428
shell:
    """
    bedtools multicov -bams {input.bam} -bed workflow/data/scNOVA/utils/bin_chr_length.bed > {output.tab}
    """
450
451
script:
    "../scripts/scNOVA_scripts/dev_aggr.py"
SnakeMake From line 450 of rules/scNOVA.smk
470
471
472
473
shell:
    """
    bedtools multicov -bams {input.bam}  -bed workflow/data/scNOVA/utils/bin_chr_length.bed > {output.tab}
    """
495
496
script:
    "../scripts/scNOVA_scripts/dev_aggr.py"
SnakeMake From line 495 of rules/scNOVA.smk
510
511
512
513
shell:
    """
    sort -k1,1 -k2,2n -k3,3n -t$'\t' {input} > {output}
    """
SnakeMake From line 510 of rules/scNOVA.smk
533
534
535
536
shell:
    """
    Rscript {params.count_sort_label} {input.count_reads_sort} {input.Ref_bed} {output.count_reads_sort_label}
    """
SnakeMake From line 533 of rules/scNOVA.smk
550
551
552
553
shell:
    """
    sort -k4,4n -t$'\t' {input} > {output}
    """
SnakeMake From line 550 of rules/scNOVA.smk
581
582
583
584
shell:
    """
    Rscript {params.count_norm} {input.count_reads_chr_length} {input.count_reads_sort_label} {input.CNN_features_annot} {input.table_CpG} {input.table_GC} {input.table_size} {input.TSS_matrix} {input.FPKM} {input.CN_result_data1} {output.plot} {output.table_mononuc_norm_data1}
    """
SnakeMake From line 581 of rules/scNOVA.smk
598
599
600
601
shell:
    """
    sort -k1,1 -k2,2n -k3,3n -t$'\t' {input} > {output}
    """
SnakeMake From line 598 of rules/scNOVA.smk
621
622
623
624
shell:
    """
    Rscript {params.count_sort_label} {input.count_reads_sort} {input.Ref_bed} {output.count_reads_sort_label}
    """
SnakeMake From line 621 of rules/scNOVA.smk
638
639
640
641
shell:
    """
    sort -k4,4n -t$'\t' {input} > {output}
    """
SnakeMake From line 638 of rules/scNOVA.smk
669
670
671
672
shell:
    """
    Rscript {params.feature_sc_var} {input.subclone_list} {input.count_reads_sc_sort} {input.Ref_bed_annot} {input.TSS_matrix} {input.CNN_features_annot} {input.FPKM} {input.CN_result_data1} {output.plot} {output.table_mononuc_var_data1} > {log} 2>&1
   """
SnakeMake From line 669 of rules/scNOVA.smk
702
703
704
705
shell:
    """
    Rscript {params.combine_features} {input.TSS_matrix} {input.table_GC_imput} {input.table_CpG_imput} {input.table_RT} {input.table_mononuc_norm_data1} {input.CN_result_data1} {input.table_mononuc_var_data1} {input.FPKM} {output.features} {output.exp} {output.TSS_annot}
    """
SnakeMake From line 702 of rules/scNOVA.smk
722
723
script:
    "../scripts/scNOVA_scripts/Deeplearning_Nucleosome_predict_train_RPE.py"
SnakeMake From line 722 of rules/scNOVA.smk
744
745
script:
    "../scripts/scNOVA_scripts/gather_infer_expr_genes_split.py"
SnakeMake From line 744 of rules/scNOVA.smk
790
791
792
793
shell:
    """
    Rscript {params.annot_expressed} {input.TSS_annot} {input.train80} {input.train40} {input.train20} {input.train5} {output.train80_annot} {output.train40_annot} {output.train20_annot} {output.train5_annot}
    """
SnakeMake From line 790 of rules/scNOVA.smk
823
824
825
826
shell:
    """
    Rscript {params.infer_diff_gene_expression} {input.Genebody_NO} {input.clonality} {input.TSS_matrix} {input.GB_matrix} {input.CNN_result1} {input.CNN_result2} {input.input_matrix} {output.pdf} {output.final_result}
    """
SnakeMake From line 823 of rules/scNOVA.smk
857
858
859
860
shell:
    """
    Rscript {params.infer_diff_gene_expression_alt} {input.Genebody_NO} {input.clonality} {input.TSS_matrix} {input.GB_matrix} {input.CNN_result1} {input.CNN_result2} {input.input_matrix} {output.result_table} {output.result_plot} {input.final_result} 
    """
SnakeMake From line 857 of rules/scNOVA.smk
881
882
883
884
shell:
    """
    bedtools multicov -bams {input.bam}  -bed workflow/data/scNOVA/utils/regions_all_hg38_v2_resize_2kb_sort.bed > {output.tab}
    """
906
907
script:
    "../scripts/scNOVA_scripts/dev_aggr.py"
SnakeMake From line 906 of rules/scNOVA.smk
921
922
923
924
shell:
    """
    sort -k1,1 -k2,2n -k3,3n -t$'\t' {input} > {output}
    """
SnakeMake From line 921 of rules/scNOVA.smk
946
947
948
949
shell:
    """
    Rscript {params.count_sort_annotate_chrid_CREs} {input} {output} 
    """
SnakeMake From line 946 of rules/scNOVA.smk
963
964
965
966
shell:
    """
    sort -k1,1n -k2,2n -k3,3n -t$'\t' {input} > {output}
    """
SnakeMake From line 963 of rules/scNOVA.smk
986
987
988
989
990
991
992
993
shell:
    """
    samtools view -H {input} > {output.bam_header}
    samtools view -f 99 {input} | cat {output.bam_header} - | samtools view -Sb - > {output.bam_C1}
    samtools view -f 147 {input} | cat {output.bam_header} - | samtools view -Sb - > {output.bam_C2}
    samtools view -f 83 {input} | cat {output.bam_header} - | samtools view -Sb - > {output.bam_W1}
    samtools view -f 163 {input} | cat {output.bam_header} - | samtools view -Sb - > {output.bam_W2}
    """
1013
1014
1015
1016
1017
shell:
    """
    samtools merge {output.bam_C} {input.bam_C1} {input.bam_C2}
    samtools merge {output.bam_W} {input.bam_W1} {input.bam_W2}
    """
1035
1036
1037
1038
1039
shell:
    """
    samtools index {input.bam_C}
    samtools index {input.bam_W}
    """
1071
1072
1073
1074
shell:
    """
    perl workflow/scripts/scNOVA_scripts/perl_test_all_snake.pl {input.strandphaser_output} {output.nucleosome_sampleA} {output.nucleosome_sampleB} {output.strandphaser_output_copy}
    """
SnakeMake From line 1071 of rules/scNOVA.smk
1093
1094
1095
1096
shell:
    """
    bedtools multicov -bams {input.bam1} {input.bam2} -bed workflow/data/scNOVA/utils/regions_all_hg38_v2_resize_2kb_sort.bed > {output.tab}
    """
1110
1111
1112
1113
shell:
    """
    sort -k1,1 -k2,2n -k3,3n -t$'\t' {input} > {output}
    """
SnakeMake From line 1110 of rules/scNOVA.smk
1134
1135
1136
1137
shell:
    """
    bedtools multicov -bams {input.bam1} {input.bam2}  -bed workflow/data/scNOVA/utils/bin_Genebody_all.bed > {output.tab}
    """
1151
1152
1153
1154
shell:
    """
    sort -k1,1 -k2,2n -k3,3n -t$'\t' {input} > {output}
    """
SnakeMake From line 1151 of rules/scNOVA.smk
17
18
19
20
21
22
23
24
25
shell:
    """
    mosaicatcher segment \
    --remove-none \
    --forbid-small-segments {params.min_num_segs} \
    -M 50000000 \
    -o {output} \
    {input.counts} > {log} 2>&1
    """
41
42
43
44
45
shell:
    """
    # Issue #1022 (https://bitbucket.org/snakemake/snakemake/issues/1022)
    awk -v name={wildcards.sample} -v window={params.window} -f {params.script} {input} > {output}
    """
61
62
63
64
65
66
67
68
69
shell:
    """
    mosaicatcher segment \
    --remove-none \
    --forbid-small-segments {params.min_num_segs} \
    -M 50000000 \
    -o {output} \
    {input} > {log} 2>&1
    """
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
shell:
    """
    PYTHONPATH="" # Issue #1031 (https://bitbucket.org/snakemake/snakemake/issues/1031)
    python workflow/scripts/segmentation_scripts/detect_strand_states.py \
        --sce_min_distance {params.sce_min_distance} \
        --sce_add_cutoff {params.additional_sce_cutoff} \
        --min_diff_jointseg {params.min_diff_jointseg} \
        --min_diff_singleseg {params.min_diff_singleseg} \
        --output_jointseg {output.jointseg} \
        --output_singleseg {output.singleseg} \
        --output_strand_states {output.strand_states} \
        --samplename {wildcards.sample} \
        --cellnames {params.cellnames} \
        {input.info} \
        {input.counts} \
        {input.jointseg} \
        {input.singleseg} > {log} 2>&1
    """
30
31
script:
    "../scripts/utils/install_R_package.R"
50
51
script:
    "../scripts/utils/run_summary.py"
15
16
script:
    "../scripts/stats/summary_stats.py"
41
42
shell:
    "(head -n1 {input.tsv[0]} && (tail -n1 -q {input.tsv} | sort -k1) ) > {output}"
61
62
script:
    "../scripts/stats/transpose_table.py"
11
12
script:
    "../scripts/strandphaser_scripts/helper.convert_strandphaser_input.R"
24
25
script:
    "../scripts/utils/detect_single_paired_end.py"
38
39
script:
    "../scripts/strandphaser_scripts/prepare_strandphaser.py"
65
66
67
68
69
70
71
72
73
74
shell:
    """
    Rscript workflow/scripts/strandphaser_scripts/StrandPhaseR_pipeline.R \
            {params.input_bam} \
            {params.output} \
            {input.configfile} \
            {input.wcregions} \
            {input.snppositions} \
            $(pwd)/utils/R-packages/
    """
89
90
91
92
shell:
    """
    (bcftools concat -a {input.vcfs} | bcftools view -o {output.vcfgz} -O z --genotype het --types snps - ) > {log} 2>&1
    """
106
107
script:
    "../scripts/strandphaser_scripts/combine_strandphaser_output.py"
123
124
script:
    "../scripts/strandphaser_scripts/helper.convert_strandphaser_output.R"
138
139
shell:
    'grep -v -P "[WC]{{3,}}" {input} > {output}'
10
11
12
13
14
15
16
17
18
19
20
21
shell:
    """
    sample_name="{wildcards.sample}"
    sm_tag=$(samtools view -H {input} | grep '^@RG' | sed "s/.*SM:\([^\\t]*\).*/\\1/g")

    if [[ $sample_name == $sm_tag ]]; then 
        echo "{input}: $sm_tag $sample_name OK" > {output}
        echo "{input}: $sm_tag $sample_name OK" > {log}
    else
        echo "{input}: $sm_tag $sample_name MISMATCH" > {log}
    fi
    """
35
36
shell:
    "samtools index {input} > {log} 2>&1"
50
51
shell:
    "samtools index {input} > {log} 2>&1"
65
66
shell:
    "bgzip {input.vcf} > {log} 2>&1"
80
81
shell:
    "tabix -p vcf {input.vcf} > {log} 2>&1"
95
96
shell:
    "bgzip {input.vcf} > {log} 2>&1"
110
111
shell:
    "tabix -p vcf {input.vcf} > {log} 2>&1"
125
126
shell:
    "tabix -p vcf {input.vcf} > {log} 2>&1"
140
141
shell:
    "samtools faidx {input}"
Lines 14–107:
library(stringr)
library(dplyr)
library(pheatmap)
library(matrixStats)
library(reshape2)
library(optparse)
source("workflow/scripts/arbigent/postprocess_helpers.R")


# INPUT INSTRUCTIONS
option_list <- list(
  make_option(c("-i", "--table"),
    type = "character", default = NULL,
    help = "res.csv", metavar = "character"
  ),
  make_option(c("-n", "--normal_names"),
    type = "character", default = T,
    help = "Should samplenames be converted from GM to NA?", metavar = "character"
  ),
  make_option(c("-o", "--outfile"),
    type = "character", default = NULL,
    help = "Outfile: verdicted table", metavar = "character"
  )
)

# Parse input
opt_parser <- OptionParser(option_list = option_list)
opt <- parse_args(opt_parser)
callmatrix_link <- opt$table
normal_names <- opt$normal_names
outfile <- opt$outfile

# callmatrix_link = '~/s/g/korbel2/StrandSeq/Test_WH/pipeline_7may/pipeline/regenotyper_allsamples_bulk/arbigent_results/res.csv'
# normal_names = T

cm <- read.table(callmatrix_link, header = 1, sep = "\t", stringsAsFactors = F)


### GO ###

# some inventory. Which samples do we have here? And therefore how many 'other' cols?
print(colnames(cm))
print(tail(colnames(cm), 1))

# samples <- colnames(cm)[grep("^[HMNG].*", colnames(cm))]
samples <- tail(colnames(cm), 1)
print(samples)
# stop()
n_samples <- length(samples)
n_other_cols <- dim(cm)[2] - n_samples

# Rename samples if wanted
if (as.numeric(normal_names)) {
  colnames(cm)[(n_other_cols + 1):dim(cm)[2]] <-
    str_replace(colnames(cm)[(n_other_cols + 1):dim(cm)[2]], "GM", "NA")
  # samples <- colnames(cm)[grep("^[HMNG].*", colnames(cm))]
  samples <- tail(colnames(cm), 1)
}

print(samples)
# Factor char stuff
cm[] <- lapply(cm, as.character)

# stratify entries with 0 valid bins.
cm[cm$valid_bins == 0, samples] <- "noreads"


# Count hom, het, ref, noreads and complex
cm <- count_homhetrefetc(cm, n_samples)

# Calc mapability

cm$valid_bins <- as.numeric(cm$valid_bins)

# Mendel
cm <- add_mendelfails(cm)

# Filter
cm <- apply_filter_new(cm, samples)

# Clean
cm[, c("mendel1", "mendel2", "mendel3")] <- NULL
cm[] <- lapply(cm, as.character)

# Rename inv_dup genotypes
cm <- make_invdups_human_readable(cm, samples)
print(cm)
# stop()
# Sort columns
cols <- c(colnames(cm)[1:n_other_cols], "verdict", "nref", "nhet", "nhom", "ninvdup", "ncomplex", samples)
cm_return <- cm[, cols]

# Save
write.table(cm_return, file = outfile, col.names = T, row.names = F, sep = "\t", quote = F)
Lines 4–205:
library(stringr)
library(ggplot2)
library(optparse)
library(reshape2)
library(tibble)
library(matrixStats)
library(grid)
library(dplyr)

make_barplot_invwise <- function(rv_f, samples_f) {
  number <- 1
  ### Define events again ###
  ref_events <- c("0/0", "0|0")
  simple_events <- c("0|1", "0/1", "1|0", "1/0", "1|1", "1/1")
  simple_events_lowconf <- c(
    "0/0_lowconf", "0|0_lowconf", "0|1_lowconf",
    "0/1_lowconf", "1|0_lowconf", "1|0_lowconf", "1/0_lowconf", "1|1_lowconf", "1/1_lowconf"
  )
  verdicts <- c("simple", "simple_lowconf", "ref", "noreads", "complex")

  ### Replace GTs with 'event' - simple, ref, complex, noreads... ###
  res_verdict <- rv_f %>%
    mutate_all(funs(ifelse(. %in% simple_events, "simple", .))) %>%
    mutate_all(funs(ifelse(. %in% simple_events_lowconf, "simple_lowconf", .))) %>%
    mutate_all(funs(ifelse(. %in% ref_events, "ref", .))) %>%
    mutate_all(funs(ifelse(. %in% c("simple", "simple_lowconf", "ref", "noreads"), ., "complex")))

  ### We accidentally replaced chr, start, end too, so we want to recover them.
  res_verdict$chrom <- rv_f$chrom
  res_verdict$start <- rv_f$start
  res_verdict$end <- rv_f$end
  res_verdict$verdict <- rv_f$verdict

  # Now count which verdict is how frequent per inversion. We need this only for one
  # line (marked with !!!##!!!)
  for (verdict in verdicts) {
    res_verdict[[verdict]] <- rowCounts(as.matrix(res_verdict), value = verdict)
  }

  # Enumerate inversions
  res_verdict$invno <- row.names(res_verdict)

  res_verdict_verdictsorted <- res_verdict[order(res_verdict$verdict, res_verdict$simple), "invno"]
  # This is that line: !!!##!!!
  # Get the sorted verdicts (this is a vector of length n_inversions)
  verd <- res_verdict[res_verdict_verdictsorted, "verdict"]
  # Get the position at which we jump to the next verdict
  xpos <- c(0, cumsum(rle(verd)$lengths))
  # Get all verdictnames
  xnames <- rle(verd)$values

  # We go towards plotting. #
  baseheight <- length(samples_f)
  n_classes <- length(xnames)


  res_verdict_molten <- melt(res_verdict[, c("invno", "verdict", samples), drop = F], id.vars = c("invno", "verdict"))
  print("got so far")
  ### PLOT ###
  p2 <- ggplot() +
    geom_bar(data = res_verdict_molten, aes(x = as.character(invno), fill = value)) +
    scale_fill_manual(values = c("blue", "white", "darkgrey", "darkgreen", "green")) +
    xlim(res_verdict_verdictsorted)

  p3 <- p2 + geom_segment(
    aes(
      y = baseheight,
      yend = baseheight,
      x = xpos[1:length(xpos) - 1],
      xend = xpos[2:length(xpos)],
      color = xnames[1:length(xnames)]
    ),
    arrow = arrow(ends = "both"),
  ) +
    geom_text(
      aes(
        x = ((xpos[2:length(xpos)] - xpos[1:length(xpos) - 1]) / 2) + xpos[1:length(xpos) - 1],
        y = seq(from = baseheight + 2, to = baseheight + 2 + (n_classes * 1.3), length.out = n_classes),
        label = xnames[1:length(xnames)],
        color = xnames[1:length(xnames)]
      ),
      check_overlap = F,
      size = 3,
      fontface = "bold",
      angle = 0
    )

  return(p3)
}


make_barplot_groupwise <- function(idf_f) {
  xorder <- as.character((idf_f[order(idf_f$simpleGT, decreasing = T), ]$variable))
  idf_f <- within(idf_f, variable <- factor(variable, levels = xorder))

  idf_2 <- idf_f %>%
    group_by(simpleGT) %>%
    mutate(sum_events = sum(value)) %>%
    slice(1)
  p2 <- ggplot(idf_2) +
    geom_bar(aes(x = simpleGT, y = sum_events, fill = simpleGT), stat = "identity") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
    scale_x_discrete(rev(xorder))

  return(p2)
}

make_barplot_all <- function(idf_f) {
  xorder <- as.character((idf_f[order(idf_f$simpleGT, decreasing = T), ]$variable))

  idf_f <- within(idf_f, variable <- factor(variable, levels = xorder))

  p1 <- ggplot(idf_f) +
    geom_bar(aes(x = variable, y = value, fill = simpleGT), stat = "identity") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
    scale_x_discrete(rev(xorder))

  return(p1)
}


make_inventory <- function(rv_f, samples_f) {
  # invent = inventory. We flatten the matrix and count elements
  invent <- table(unlist(rv_f[, samples_f, drop = F]))
  idf <- melt(as.data.frame.matrix(rbind((invent))))

  # What to color how?
  idf$simpleGT <- "complex"
  ref_events <- c("0/0", "0|0")
  simple_events <- c("0|1", "0/1", "1|0", "1/0", "1|1", "1/1")
  simple_events_lowconf <- c(
    "0/0_lowconf", "0|0_lowconf", "0|1_lowconf",
    "0/1_lowconf", "1|0_lowconf", "1|0_lowconf", "1/0_lowconf", "1|1_lowconf", "1/1_lowconf"
  )

  idf[idf$variable %in% ref_events, "simpleGT"] <- "a) ref"
  idf[idf$variable %in% simple_events, "simpleGT"] <- "b) inv_simple"
  idf[idf$variable %in% simple_events_lowconf, "simpleGT"] <- "c) inv_simple_lowconf"
  idf[idf$variable == "noreads", "simpleGT"] <- "z) noreads"

  return(idf)
}




# INPUT INSTRUCTIONS
option_list <- list(
  make_option(c("-f", "--file"),
    type = "character", default = NULL,
    help = "res_verdicted", metavar = "character"
  ),
  make_option(c("-o", "--outdir"),
    type = "character", default = "./outputcorr/",
    help = "Outputdir for phased all.txt and other qcs", metavar = "character"
  )
)


opt_parser <- OptionParser(option_list = option_list)
opt <- parse_args(opt_parser)
res_verdicted_link <- opt$file
outdir <- opt$outdir

# res_verdicted_link = "~/s/g/korbel/hoeps/projects/huminvs/mosai_results/results_freeze4manual/regenotyper_allsamples_bulk/arbigent_results/res_verdicted.vcf"
# res_verdicted_link = "/home/hoeps/Desktop/hufsah_freeze4manual/Arbigent_gts.vcf"
# outdir = '~/Desktop/hufsah_freeze4manual/'

############# RUN CODE #################

# Load res_verdicted
rv <- read.table(res_verdicted_link, header = 1, sep = "\t", stringsAsFactors = F)

# Get the samplenames
# samples <- colnames(rv)[grep("^[HMNG].*", colnames(rv))]
samples <- tail(colnames(rv), 1)

print(rv)
print(samples)
# First, inventory.
idf <- make_inventory(rv, samples)

print(idf)
# Give numbers
gtclass_merge <- (idf %>% group_by(simpleGT) %>% mutate(sum_events = sum(value)) %>% slice(1))[, c("simpleGT", "sum_events")]
verdicts_invwise <- as.data.frame.matrix(rbind(table(rv$verdict)))

# Make plots
p1 <- make_barplot_all(idf)
p2 <- make_barplot_groupwise(idf)
p3 <- make_barplot_invwise(rv, samples)

### SAVE ###

# a) tables
write.table(gtclass_merge, file = paste0(outdir, "/simple-complex-numbers.txt"), sep = "\t", row.names = F, col.names = T, quote = F)
write.table(t(verdicts_invwise), file = paste0(outdir, "/verdicts-numbers.txt"), sep = "\t", row.names = T, col.names = F, quote = F)

# b) plots
ggsave(file = paste0(outdir, "/all_gts_overview.pdf"), plot = p1, width = 30, height = 10, units = "cm", device = "pdf")
ggsave(file = paste0(outdir, "/all_verdicts_overview.pdf"), plot = p2, width = 15, height = 10, units = "cm", device = "pdf")
ggsave(file = paste0(outdir, "/lineplot_gts.pdf"), plot = p3, width = 30, height = 15, units = "cm", device = "pdf")
Lines 13–403:
print("Initialising Regenotyper...")

# supposed to suppress warnings, but not working
oldw <- getOption("warn")
options(warn = -1)


# !/usr/bin/env Rscript
suppressMessages(library("optparse"))
suppressMessages(library("tidyr"))
suppressMessages(library("stringr"))


# INPUT INSTRUCTIONS
option_list <- list(
  make_option(c("-f", "--file"),
    type = "character", default = NULL,
    help = "probabilities.R, produced by mosaicatcher", metavar = "character"
  ),
  make_option(c("-b", "--bed"),
    type = "character", default = NULL,
    help = "a bed file specifiying labels/groups for the segments.", metavar = "character"
  ),
  make_option(c("-o", "--outdir"),
    type = "character", default = "./outputcorr/",
    help = "output dir name [default= %default]", metavar = "character"
  ),
  make_option(c("-c", "--cn_map"),
    type = "character", default = NULL,
    help = "average copy numbers and mapability for all segments in given bed file", metavar = "character"
  ),
  make_option(c("-p", "--path_to_self"),
    type = "character", default = NULL,
    help = "Path to regenotyper, if run from another directory.", metavar = "character"
  ),
  make_option(c("-m", "--mode"),
    type = "character", default = "",
    help = "Can be bulk or single-cell.", metavar = "character"
  )
  #  make_option(c("-s", "--sample_sex"), type='character', default=NULL,
  #              help="Needed for proper normalization of read counts in gonosomes", metavar="character")
)
opt_parser <- OptionParser(option_list = option_list)
opt <- parse_args(opt_parser)
path_to_regenotyper <- opt$path_to_self
if (is.null(path_to_regenotyper) == FALSE) {
  print(paste0("Going to path ", path_to_regenotyper))
  setwd(path_to_regenotyper)
} else {
  print(paste0("Staying in path", getwd()))
}

suppressMessages(source("workflow/scripts/arbigent/probability_helpers.R"))
suppressMessages(source("workflow/scripts/arbigent/regenotype_helpers.R"))
suppressMessages(source("workflow/scripts/arbigent/bulk_helpers.R"))








# Ok lets go! Calm the minds of impatient humans first of all.
print("Processing and summarizing information, making plots")
print("This can take a few minutes.")

p_link <- opt$file
labels_link <- NULL # opt$bed
outdir_raw <- opt$outdir
debug_file <- opt$cn_map
# sample_sex = opt$sample_sex
m <- opt$mode

if (m == "single-cell") {
  ### switch modules on/off
  make_bell_bulk <- F
  make_table_bulk <- F
  make_bee_bulk <- F

  make_bell_sc <- F
  make_table_sc <- F
  make_bee_sc <- F

  run_singlecell_mode <- T
} else {
  make_bell_bulk <- T
  make_table_bulk <- T
  make_bee_bulk <- F

  make_bell_sc <- F
  make_table_sc <- F
  make_bee_sc <- F

  run_singlecell_mode <- F
}
suppressMessages(dir.create(outdir_raw))

# is input file specified?
if (is.null(opt$file)) {
  print_help(opt_parser)
  stop("Please specify path to probabilities.R", call. = FALSE)
}

# Tell the user if we go for bed mode or single mode. At the occasion, also load it.
if (is.null(labels_link)) {
  print("No bed file specified. Examining everything together.")
  labels <- NULL
} else {
  print("Bed file provided. Will use it to stratify results w.r.t. groups")
  labels <- read.table(labels_link, stringsAsFactors = FALSE)
  colnames(labels) <- c("chrom", "start", "end", "group")

  # Bed files can be weird. For safety, we remove duplicate lines
  labels <- unique(labels)
}

# sample name could be interesting
# print(p_link)
sname <- tail(strsplit(sub("/arbigent_mosaiclassifier/sv_probabilities/probabilities.Rdata", "", p_link), "/")[[1]], 1)
# sname <- str_match(p_link, "([HG|NA|GM]+[0-9]{3,5})")[, 2]
print(sname)
# stop()

# load p to p_grouped
print("Loading probabilities table")
probs_raw <- load_and_prep_pdf(p_link)
print("Done loading")

print(unique(probs_raw$chrom))
# load file containing info on valid bins
CN <- read.table(debug_file, header = 1, stringsAsFactors = F)

# Cut CN down to important cols, then join it with probs so we know which inversion has how
# many valid bins
if (!is.null(CN)) {
  CNmerge <- as.data.frame(lapply(CN[, c("chrom", "start", "end", "valid_bins")], as.character))

  CNmerge <- as.tbl(CNmerge)
  CNmerge <- CNmerge %>% mutate(
    chrom = as.character(chrom),
    start = as.numeric(as.character(start)),
    end = as.numeric(as.character(end))
  ) # ,
  p2 <- full_join(probs_raw, CNmerge, by = c("chrom", "start", "end"))
}

# Remove invalid bins from probs_raw
probs_raw <- p2[!is.na(p2$sample), ]
# probs_raw = probs_raw[probs_raw$chrom == 'chr22',]

print("ASDSADSADSASDASA")
print(unique(probs_raw$chrom))
#############################################################
#############################################################
# ATTENTION LADIES AND GENTLEMEN, HERE IS THE NORMALIZATION #
# PLEASE PAY CLOSE ATTENTION TO THIS                        #
#############################################################

# Using the valid_bin info, we can reconstruct the length
# normalization factor here. It is between 0 and 1.
# Basically it is the mapability ratio (0 nothing, 1 perfect)
len_normalization <- as.numeric(as.character(probs_raw$valid_bins)) /
  ((probs_raw$end - probs_raw$start) / 100.)
len_normalization[len_normalization == 0] <- 1

# Depending on how manual segment counts were normalized before,
# and depending on which normalization you want to have, you
# have to choose different options here.
# If they have been length normalized, I am using both options
# A and B to arrive at the 'downscaling' solution ('2'). If A and
# B are disabled, no further normalization is done. Since W and C
# counts have been up-scaled by watson_crick_counts.py, this is
# also ok ('solution 1')
# By all means check the qc plots that arrive at the end of the
# snakemake!

# option A: if we multiply by len_normalization, we remove previous
# normalization, and return back to non-length-corrected counts.
probs_raw$W <- probs_raw$W * len_normalization
probs_raw$C <- probs_raw$C * len_normalization

# option B: we can downscale expectations
probs_raw$expected <- probs_raw$expected * len_normalization


##############################################################
##############################################################
##############################################################
# Additionally: adjust expectations based on biological sex ##
##############################################################

# if (sample_sex == 'male'){

# UPDATE: NO WE DONT HAVE TO NORMALIZE.
# In male samples, we expect half the number of reads on X ...
# probs_raw[probs_raw$chrom=='chrX',]$expected = (probs_raw[probs_raw$chrom=='chrX',]$expected) / 2
# And half in y
# probs_raw[probs_raw$chrom=='chrY',]$expected = (probs_raw[probs_raw$chrom=='chrY',]$expected) / 2

# Remove WC and CW cells in chrX and Y.
#  probs_raw = probs_raw[!((probs_raw$chrom == 'chrX') & (probs_raw$class %in% c('WC','CW'))),]
#  probs_raw = probs_raw[!((probs_raw$chrom == 'chrY') & (probs_raw$class %in% c('WC','CW'))),]
# } else {
#  probs_raw = probs_raw[!(probs_raw$chrom == 'chrY'),]
# }



# Adding group information to probs_raw.
if (is.null(labels)) {
  # If bed file was not provided, everyone is group 'all'
  probs_raw$group <- "all"
} else {
  # Else, the ones with a group get that one from the bed file
  probs_raw <- full_join(probs_raw, labels)
  # ... the remaining ones are called ungrouped.
  probs_raw$group[is.na(probs_raw$group)] <- "ungrouped"
}

# remove inf things [DIRTY SOLUTION! SHOULD BE DONE BETTER! Not needed apparently? Not sure. CHECK]
# probs_raw = probs_raw[probs_raw$logllh != 'Inf',]
# probs_raw = probs_raw[probs_raw$logllh != '-Inf',]
# probs_raw = probs_raw[!(probs_raw$cell %in% unique(probs_raw[probs_raw$logllh == '-Inf',]$cell)),]

###################################
##### OKAAYYY HERE WE GOOOOO ######
#### THIS IS THE MAIN WORKHORSE ###
###################################
print(unique(probs_raw$group))
# group = unique(probs_raw$group)[1] #for quick manual mode
for (group in unique(probs_raw$group)) {
  # Talk to human
  print(paste0("Running samples with group ", group))

  # Make the outfolder (maybe not necessary?)
  outdir <- gsub("\\.:", "_:", paste0(outdir_raw, group, "/"))
  suppressMessages(dir.create(outdir))

  # cut pg down to the desired inversions
  pg <- probs_raw[probs_raw$group == group, ]
  haps_to_consider <- na.omit(unique(pg$haplotype))


  if (make_bell_bulk) {
    ### [I]a) make dumbbell plot ###
    #### [I] BULK ###


    pg_bulk_list <- (bulkify_pg(haps_to_consider, pg))
    pg_bulk <- data.frame(pg_bulk_list[1]) %>% group_by(start, end, haplotype, class, group)
    pg_bulk_probs <- data.frame(pg_bulk_list[2]) %>% group_by(start, end, haplotype, class, group)
    # write.table(pg_bulk, file=paste0('/home/hoeps/Desktop/', 'counts_bulk.txt'), quote = F, row.names = F, col.names = T)

    # at least temporarily, I'm operating both with likelihoods and probabilities. Haven't decided yet which one I like more.
    # call_llhs_bulk = (make_condensed_sumlist(haps_to_consider, pg_bulk)) %>% mutate_all(funs(replace_na(.,-1000)))
    call_llhs_bulk <- (make_condensed_sumlist(haps_to_consider, pg_bulk)) %>% mutate_all(funs(replace_na(., -1000)))

    # call_probs_bulk = (make_condensed_sumlist_probs(haps_to_consider, pg_bulk_probs)) %>% mutate_all(funs(replace_na(.,-1000)))
    # write.table(mm, file=paste0('/home/hoeps/Desktop/', 'counts_bulk2.txt'), quote = F, row.names = F, col.names = T)



    g <- make_dumbbell(call_llhs_bulk, groupname = group, run_shiny = F)
    # p = suppressMessages(ggplotly(g))

    savepath <- paste0(outdir, "bellplot_bulk.html")
    # suppressMessages(htmlwidgets::saveWidget(as_widget(p), file.path(normalizePath(dirname(savepath)),basename(savepath))))
    ggsave(filename = paste0(outdir, sname, "_", group, "_bellplot_bulk.png"), width = 30, height = 12, units = "cm", device = "png")
    ggsave(filename = paste0(outdir, sname, "_", group, "_bellplot_bulk.pdf"), width = 30, height = 12, units = "cm", device = "pdf")

    # call_probs_bulk[,4:73] = -log(1-(call_probs_bulk[,4:73]))
    # call_probs_bulk[,4:73] = 10**(call_probs_bulk[,4:73])
    # g = make_dumbbell_probs(call_probs_bulk, groupname=group, run_shiny=F)
    # p = suppressMessages(ggplotly(g))

    savepath <- paste0(outdir, "bellplot_bulk_prob.html")
    # suppressMessages(htmlwidgets::saveWidget(as_widget(p), file.path(normalizePath(dirname(savepath)),basename(savepath))))
    ggsave(g, filename = paste0(outdir, sname, "_", group, "_bellplot_bulk_prob.png"), width = 30, height = 12, units = "cm", device = "png")
    ggsave(g, filename = paste0(outdir, sname, "_", group, "_bellplot_bulk_prob.pdf"), width = 30, height = 12, units = "cm", device = "pdf")
  }

  if (make_table_bulk) {
    #### [I]b) write table ####
    tab <- make_table_finaledition(call_llhs_bulk, group, sname)
    # write.table(tab, file=paste0('/home/hoeps/Desktop/', 'counts_bulk_labels.txt'), quote = F, row.names = F, col.names = T)
    # adding copy number and mapability information to the table
    tab2 <- left_join(tab, CN[, c("chrom", "start", "end", "valid_bins")])


    if (dim(tab2[tab2$valid_bins == 0, ])[1] > 0) {
      tab2[tab2$valid_bins == 0, ]$pred_hard <- "nomappability"
      tab2[tab2$valid_bins == 0, ]$pred_soft <- "nomappability"
      tab2[tab2$valid_bins == 0, ]$pred_nobias <- "nomappability"
      tab2[tab2$valid_bins == 0, ]$second_hard <- "nomappability"
      tab2[tab2$valid_bins == 0, ]$confidence_hard_over_second <- 100
    }
    tab <- tab2[, !(names(tab2) %in% c("valid_bins"))]
    # t2 = as.data.frame(lapply(tab, as.character))

    tab[tab$confidence_hard_over_second == 0, "pred_hard"] <- "0|0"


    # if (!is.null(CN)){
    # CNmerge = as.data.frame(lapply(CN[, c("chrom","start","end","CN","mapability")], as.character))
    #  CNmerge = as.data.frame(lapply(CN[, c("chrom","start","end","valid_bins")], as.character))

    #  tab <- left_join(t2, CNmerge, by = c("chrom","start","end"))
    # CN = read.table(CN_link, stringsAsFactors = F, header=1);
    # }

    tab$group <- group
    write.table(tab, file = paste0(outdir, "sv_calls_bulk.txt"), quote = F, row.names = F, col.names = T, sep = "\t")
  }

  if (make_bee_bulk) {
    ## [I]c) make beeswarm plots ##
    print("sup")
    save_beeswarms(pg_bulk %>% group_by(start), call_llhs_bulk, outdir, testrun = F, compositemode = T)
    print("over the hill")
  }

  ### [II] SINGLE CELL ###


  if (make_bell_sc) {
    call_llhs <- (make_condensed_sumlist(haps_to_consider, pg)) # %>% mutate_all(funs(replace_na(.,-1000)))

    #### [II]a) make dumbbell plot ####


    suppressMessages(source("workflow/scripts/arbigent/regenotype_helpers.R")) # for quick manual mode

    # create the ggplot plot
    g <- make_dumbbell(call_llhs, groupname = group, run_shiny = F)
    g
    # convert it to plotly
    # p = suppressMessages(ggplotly(g))

    # save. htmlwidgets does not work with relative paths, so we need a little workaround
    # (info and code taken from https://stackoverflow.com/questions/41399795/savewidget-from-htmlwidget-in-r-cannot-save-html-file-in-another-folder)
    savepath <- paste0(outdir, "bellplot.html")
    suppressMessages(htmlwidgets::saveWidget(as_widget(p), file.path(normalizePath(dirname(savepath)), basename(savepath))))
    ggsave(filename = paste0(outdir, sname, "_", group, "_bellplot.png"), width = 30, height = 12, units = "cm", device = "png")
    ggsave(filename = paste0(outdir, sname, "_", group, "_bellplot.pdf"), width = 30, height = 12, units = "cm", device = "pdf")
  }

  if (make_table_sc) {
    #### [II]b) write table ####
    # tab = make_table(call_llhs, group, sname)
    tab <- make_table_finaledition(call_llhs, group, sname)

    t2 <- as.data.frame(lapply(tab, as.character))
    CNmerge <- as.data.frame(lapply(CN[, c("chrom", "start", "end", "CN", "mapability")], as.character))
    tab2 <- left_join(t2, CNmerge, by = c("chrom", "start", "end"))


    # adding copy number and mapability information to the table
    write.table(tab2, file = paste0(outdir, "sv_calls.txt"), quote = F, row.names = F, col.names = T)
  }

  if (make_bee_sc) {
    #### [II]c) save beewarm plots ###
    save_beeswarms(pg, call_llhs, outdir, testrun = F)
    print(paste0("Group ", group, " done."))
  }


  if (run_singlecell_mode) {
    # rectify logllh Nan. It will be overwritten anyway. Just want to avoid NA errors
    if (dim(pg[is.na(pg$logllh), ])[1] > 0) {
      pg[is.na(pg$logllh), ]$logllh <- -1
    }

    # re-calculate logllh based on mapping parameters
    pg2 <- calc_new_logllhs_singlecell(pg)

    # make an output table

    tab <- make_table_sc_separated(pg2)
    cols_to_return <- c("cell", "chrom", "start", "end", "class", "expected", "W", "C", "top_pred", "second_pred", "llr_1st_to_2nd", "llr_1st_to_ref")
    tab <- make_table_sc_separated(pg2)
    write.table(tab[, cols_to_return], file = paste0(outdir_raw, "sv_calls.txt"), quote = F, row.names = F, col.names = T, sep = "\t")
    write.table(tab, file = paste0(outdir_raw, "sv_calls_detailed.txt"), quote = F, row.names = F, col.names = T, sep = "\t")
  }
}

print("### ALL DONE. Happy discoveries. ###")

# Does not seem to work.
options(warn = oldw)
Lines 6–325:
library(ggplot2)
library(reshape)
library(dplyr)
library(tibble)
library(optparse)

source("workflow/scripts/arbigent/clean_genotype_defs.R")
source("workflow/scripts/arbigent/vcf_defs.R")

# INPUT INSTRUCTIONS
option_list <- list(
  make_option(c("-a", "--alltxt"),
    type = "character", default = NULL,
    help = "ArbiGent output (traditionally 'all.txt') to be turned into vcf", metavar = "character"
  ),
  make_option(c("-m", "--msc"),
    type = "character", default = NULL,
    help = "count debug file from mosaicatcher main run", metavar = "character"
  ),
  make_option(c("-o", "--outdir"),
    type = "character", default = "./outputcorr/",
    help = "Outputdir", metavar = "character"
  ),
  make_option(c("-c", "--copynumberann"),
    type = "logical", default = F,
    help = "Use the copynumberannotation?", metavar = "character"
  )
)
# Parse input
opt_parser <- OptionParser(option_list = option_list)
opt <- parse_args(opt_parser)
alltxt_file <- opt$alltxt
msc_file <- opt$msc
outdir <- opt$outdir
use_cntrack <- opt$copynumberann
# alltxt_file = "/home/hoeps/PhD/projects/huminvs/mosaicatcher/analysis/results/U32_freezemerge/sv_probabilities/all.txt"
# msc_file = "/home/hoeps/PhD/projects/huminvs/mosaicatcher/analysis/results/U32_freezemerge/msc.debug"
# outdir = "/home/hoeps/Desktop/arbitrash"


### PARAMETERS ###

save <- T
# Cutoff below which a genotype call is labelled 'lowconf'
cutoff <- 3
# Complex vs simple calls: a complex genotype is only reported if its LLH is
# double that of the best simple genotype and, additionally, larger by at
# least this additive margin.
bias_add_factor <- 5
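# A minimal, hypothetical sketch (NOT the actual add_gts_revisited* helpers defined
# in clean_genotype_defs.R) of how these two thresholds are meant to interact when
# choosing between the best simple ('hard') and the best complex genotype:
#
#   pick_gt_sketch <- function(gt_simple, gt_complex, llh_simple, llh_complex, llh_over_second) {
#     gt <- gt_simple
#     # accept the complex call only if its LLH is double the simple one
#     # AND exceeds it by at least bias_add_factor
#     if (llh_complex >= 2 * llh_simple && llh_complex >= llh_simple + bias_add_factor) {
#       gt <- gt_complex
#     }
#     # calls whose margin over the runner-up is below the cutoff get the 'lowconf' label
#     if (llh_over_second < cutoff) {
#       gt <- paste0(gt, "_lowconf")
#     }
#     gt
#   }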


### FUNCTIONS ###
load_tab <- function(alltxt_file_f) {
  # Tiny function to load the input file
  tab <- read.table(alltxt_file_f, header = T, stringsAsFactors = F)
  tab <- tab %>% mutate(ID = paste0(chrom, "-", start + 1, "-INV-", (end - start) + 1))

  return(tab)
}
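# Example (descriptive comment): a BED-like row "chr1  10000  15000" yields
# ID "chr1-10001-INV-5001" (1-based start, length = end - start + 1).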

simplify_countmatrix_idup <- function(countm) {
  # Strongly simplify things

  countm[countm == "noreads"] <- "./."
  countm[countm == "1101"] <- "./1"
  countm[countm == "0100"] <- "1|."
  countm[countm == "2101"] <- "./1"
  countm[countm == "1110"] <- "./0"
  countm[countm == "0010"] <- ".|0"
  countm[countm == "0103"] <- "1|1"
  countm[countm == "1201"] <- "./1"
  countm[countm == "0001"] <- "./1"
  countm[countm == "1030"] <- "0|0"
  countm[countm == "0301"] <- "1|1"
  countm[countm == "3010"] <- "0|0"
  countm[countm == "0120"] <- "1|0"
  countm[countm == "3001"] <- "./1"
  countm[countm == "2020"] <- "0|0"
  countm[countm == "0202"] <- "1|1"
  countm[countm == "1000"] <- "0|."

  countm[countm == "1111"] <- "idup_1|1"
  countm[countm == "2200"] <- "idup_1|1"
  countm[countm == "0022"] <- "idup_1|1"

  countm[countm == "1011"] <- "idup_0|1"
  countm[countm == "1110"] <- "idup_1|0"

  simple_calls <- c(
    "0|0", "0|1", "1|0", "1|1",
    "0/0", "0/1", "1/0", "1/1",
    "./1", "1/.", "0/.", "./0",
    ".|1", "1|.", "0|.", ".|0"
  )

  # For idups we want a presence/absence representation, so all simple genotypes
  # are collapsed to "0|0" and only the idup_* calls remain informative
  for (row in seq_len(nrow(countm))) {
    for (col in seq_len(ncol(countm))) {
      if (countm[row, col] %in% simple_calls) {
        countm[row, col] <- "0|0"
      }
    }
  }
  good_calls <- c(
    "0|0", "0|1", "1|0", "1|1",
    "0/0", "0/1", "1/0", "1/1",
    "./1", "1/.", "0/.", "./0",
    ".|1", "1|.", "0|.", ".|0",
    "idup_0|1", "idup_1|0", "idup_1|1"
  )

  # Everything that is not a recognised call is set to missing ("./.")
  for (row in seq_len(nrow(countm))) {
    for (col in seq_len(ncol(countm))) {
      if (!(countm[row, col] %in% good_calls)) {
        countm[row, col] <- "./."
      }
    }
  }
  return(countm)
}

simplify_countmatrix <- function(countm) {
  # Strongly simplify things

  countm[countm == "noreads"] <- "./."
  countm[countm == "1101"] <- "./1"
  countm[countm == "0100"] <- "1|."
  countm[countm == "2101"] <- "./1"
  countm[countm == "1110"] <- "./0"
  countm[countm == "0010"] <- ".|0"
  countm[countm == "0103"] <- "1|1"
  countm[countm == "1201"] <- "./1"
  countm[countm == "0001"] <- "./1"
  countm[countm == "1030"] <- "0|0"
  countm[countm == "0301"] <- "1|1"
  countm[countm == "3010"] <- "0|0"
  countm[countm == "0120"] <- "1|0"
  countm[countm == "3001"] <- "./1"
  countm[countm == "2020"] <- "0|0"
  countm[countm == "0202"] <- "1|1"
  countm[countm == "1000"] <- "0|."

  simple_calls <- c(
    "0|0", "0|1", "1|0", "1|1",
    "0/0", "0/1", "1/0", "1/1",
    "./1", "1/.", "0/.", "./0",
    ".|1", "1|.", "0|.", ".|0"
  )
  # Everything that is not a recognised simple call is set to missing ("./.")
  for (row in seq_len(nrow(countm))) {
    for (col in seq_len(ncol(countm))) {
      if (!(countm[row, col] %in% simple_calls)) {
        countm[row, col] <- "./."
      }
    }
  }
  return(countm)
}

load_and_prep_CN <- function(msc_file_f) {
  CN <- read.table(msc_file_f, header = TRUE, stringsAsFactors = F)
  CNmerge <- as.data.frame(lapply(CN[, c("chrom", "start", "end", "valid_bins")], as.character))

  CNmerge <- tibble::as_tibble(CNmerge)
  CNmerge <- CNmerge %>% mutate(
    chrom = as.character(chrom),
    start = as.numeric(as.character(start)),
    end = as.numeric(as.character(end)),
    valid_bins = as.numeric(as.character(valid_bins))
  )
  return(CNmerge)
}


################
### RUN CODE ###
################

# Load input file
tab <- load_tab(alltxt_file)
# Load CN file. It will be used to include 'valid bins' information, which is good to have in the output files.
CNmerge <- load_and_prep_CN(msc_file)

# First criterion ('double'): per-locus confidence of the best 'hard' call over the second-best call
bias_factor <- tab$confidence_hard_over_second


######### GET 'REPORTED' GENOTYPES ##############
# tabp = [tab]le [p]rocessed: the table with an added 'GT' column holding the genotype
# of choice. Which genotype is chosen depends on bias_factor/bias_add_factor, cutoff,
# and whether or not the 'lowconf' label should be included.

if (use_cntrack) {
  tab$pred_hard <- paste(tab$pred_hard, tab$illumina_CN, sep = ":")
  tab$pred_nobias <- paste(tab$pred_nobias, tab$illumina_CN, sep = ":")
}
# tabp = Complex calls allowed, lowconf label added
tabp <- add_gts_revisited_lowconf(tab, bias_factor, bias_add_factor, cutoff)
print(tabp)
# tabp2 =  Complex calls allowed, lowconf label nope, LLHs printed
tabp2 <- add_long_gts_revisited(tab, bias_factor, bias_add_factor, cutoff)
print(tabp2)

# tabp3 = Complex calls allowed, lowconf label nope
tabp3 <- add_gts_revisited(tab, bias_factor, bias_add_factor, cutoff)


# Merge valid bin information into tabp's
tabp <- full_join(tabp, CNmerge, by = c("chrom", "start", "end"))
tabp2 <- full_join(tabp2, CNmerge, by = c("chrom", "start", "end"))
tabp3 <- full_join(tabp3, CNmerge, by = c("chrom", "start", "end"))
print(tabp3)
print(CNmerge)

####################################################################################

# Cast table tabp3 into vcf-like matrix
callmatrix <- cast(unique(tabp3), chrom + start + end + ID + len + valid_bins ~ sample, value = "GT")

####################################################################################

# Name shortening, a bit of reformatting
cms <- callmatrix
print(colnames(cms))
samplenames <- colnames(cms)[7:length(colnames(cms))]
print(samplenames)

cms_gts <- cms[, samplenames, drop = F]
print(cms_gts)

# print(head(lapply(cms_gts, as.character)))
cms_gts[] <- lapply(cms_gts, as.character)
# Complex calls to simple ones
print(cms_gts)

countm <- simplify_countmatrix(cms_gts)
countm_idup <- simplify_countmatrix_idup(cms_gts)
# Bind description and GTs back together
cms_full <- cbind(cms[, 1:6], cms_gts)

# Bind simplified countmatrix to description
cms_simple <- cbind(cms[, 1:6], countm)
cms_simple_idup <- cbind(cms[, 1:6], countm_idup)
####################################################################################

# Tabp: this is the 'normal' one. With lowconf label
callmatrix <- cast(unique(tabp), chrom + start + end + ID + len + valid_bins ~ sample, value = "GT")
callmatrix <- callmatrix[, colSums(is.na(callmatrix)) < dim(callmatrix)[2]]

# Tabp2: this is the one with LLH info included
callmatrix_detail <- cast(tabp2, chrom + start + end + ID + len + valid_bins ~ sample, value = "GTL")
callmatrix_detail <- callmatrix_detail[, colSums(is.na(callmatrix_detail)) < dim(callmatrix_detail)[2]]
# callmatrix_detail = callmatrix_detail[ , colSums(is.na(callmatrix_detail)) == 0]

# Sidequest: find hom invs. A bit messy but we keep it for now.
callmatrix_hom_lab <- callmatrix
callmatrix_hom_lab$nhom <- rowSums(callmatrix_hom_lab == "1|1")
cm_hom <- callmatrix_hom_lab[callmatrix_hom_lab$nhom > (dim(callmatrix_hom_lab)[2] - 5) * 0.8, ]
hom_ids <- cm_hom$ID
cm_detail_hom <- callmatrix_detail[callmatrix_detail$ID %in% hom_ids, ]

################## MAKE VCFS
# The detailed one from tabp2, including LLH info.
vcf <- vcfify_callmatrix_detail(callmatrix_detail)
# Like above, but filtered for misos
vcf_miso <- vcfify_callmatrix_detail(cm_detail_hom)
# And this one is based on simplified tabp3.
vcf_limix <- vcfify_callmatrix_simple_for_limix(cms_simple)
# And here we save the one that contains idups
vcf_limix_plus_idups <- vcfify_callmatrix_simple_for_limix(cms_simple_idup)
################## Make at least one plot TODO
print(head(cms_simple_idup))
# ggplot(tabp3) + geom_point(aes(x=log1p(confidence_hard_over_second), y=log1p(confidence_nobias_over_hard), col=GT))


################# SAVE ALL THESE DIFFERENT THINGS.
if (save == T) {
  # Prep directory
  dir.create(outdir)

  # Paths, paths, paths.
  callmatrix_file <- "res.csv"
  callmatrix_file_detail <- "res_detail.csv"
  vcffile_all <- "res_all.vcf"
  vcffile_miso <- "res_miso.vcf"
  vcffile_limix <- "res_verysimple.vcf"
  vcffile_limix_plus_idups <- "res_verysimple_idups.vcf"

  # Save simple callmatrix
  write.table(callmatrix, file = file.path(outdir, callmatrix_file), quote = F, col.names = T, row.names = F, sep = "\t")
  # Same detailed callmatrix
  write.table(callmatrix_detail, file = file.path(outdir, callmatrix_file_detail), quote = F, col.names = T, row.names = F, sep = "\t")
  # Save vcf_all
  outvcffile <- file.path(outdir, vcffile_all)
  writeLines(vcf[[1]], file(outvcffile))
  write.table(vcf[[2]], file = outvcffile, quote = F, col.names = F, row.names = F, sep = "\t", append = T)
  system(paste0("cat ", outvcffile, ' | awk \'$1 ~ /^#/ {print $0;next} {print $0 | "sort -k1,1 -k2,2n"}\' > ', outvcffile, "_sorted"))

  # Save vcf_miso
  outvcffile_miso <- file.path(outdir, vcffile_miso)
  writeLines(vcf_miso[[1]], file(outvcffile_miso))
  write.table(vcf_miso[[2]], file = outvcffile_miso, quote = F, col.names = F, row.names = F, sep = "\t", append = T)
  system(paste0("cat ", outvcffile_miso, ' | awk \'$1 ~ /^#/ {print $0;next} {print $0 | "sort -k1,1 -k2,2n"}\' > ', outvcffile_miso, "_sorted"))

  # Save vcf_limix
  outvcffile_limix <- file.path(outdir, vcffile_limix)
  writeLines(vcf_limix[[1]], file(outvcffile_limix))
  write.table(vcf_limix[[2]], file = outvcffile_limix, quote = F, col.names = F, row.names = F, sep = "\t", append = T)
  system(paste0("cat ", outvcffile_limix, ' | awk \'$1 ~ /^#/ {print $0;next} {print $0 | "sort -k1,1 -k2,2n"}\' > ', outvcffile_limix, "_sorted"))

  # Save vcf_limix idups
  outvcffile_limix_idups <- file.path(outdir, vcffile_limix_plus_idups)
  writeLines(vcf_limix_plus_idups[[1]], file(outvcffile_limix_idups))
  write.table(vcf_limix_plus_idups[[2]], file = outvcffile_limix_idups, quote = F, col.names = F, row.names = F, sep = "\t", append = T)
  system(paste0("cat ", outvcffile_limix_idups, ' | awk \'$1 ~ /^#/ {print $0;next} {print $0 | "sort -k1,1 -k2,2n"}\' > ', outvcffile_limix_idups, "_sorted"))
}
import os

# import pdb
# import argparse
import sys
from pathlib import Path

# import glob
import multiprocessing as mp

# import collections as col
# from collections import defaultdict

import pandas as pd
import xopen
import numpy as np
import pysam as pysam


# def parse_command_line():
#     parser = argparse.ArgumentParser(description=__doc__)
#     parser.add_argument(
#         "--mapping-counts",
#         "-mc",
#         help="Raw mapping counts as fixed bin, regular spaced BED-like file. Will be automatically converted to HDF file for future use.",
#         dest="map_counts",
#         type=str,
#         default="",
#     )
#     parser.add_argument(
#         "--chromosome",
#         "-c",
#         default="genome",
#         type=str,
#         choices=["genome"],  # safeguard against accidental parallelization by chromosome
#         help='Restrict counting to this chromosome. Default "genome" will process everything.',
#     )

#     args = parser.parse_args()
#     return args


def convert_mapping_counts(raw_counts, process_chrom):
    """
    Convert textual BED-like raw mapping counts into
    binary representation stored as HDF for faster access.
    """
    num_splits = 1
    if any([raw_counts.endswith(x) for x in [".gz", ".zip", ".bz2", ".xz"]]):
        num_splits = 2
    basename = raw_counts.rsplit(".", num_splits)[0]
    hdf_file = basename + ".h5"
    if os.path.isfile(hdf_file):
        return hdf_file
    else:
        with pd.HDFStore(hdf_file, "w") as hdf:
            pass

    process_genome = process_chrom == "genome"

    with xopen.xopen(raw_counts, mode="rt") as bedfile:
        columns = bedfile.readline().strip().split()
        if not int(columns[1]) == 0:
            raise ValueError("Mapping counts track does not start at beginning of chromosome: {}".format("\t".join(columns)))
        bin_size = int(columns[2]) - int(columns[1])
        assert bin_size > 99, "Bin size {} detected, this code is only optimized for bin sizes >= 100".format(bin_size)
        bedfile.seek(0)

        last_chrom = columns[0]
        correct_counts = []
        incorrect_counts = []
        chroms_seen = set()

        for line in bedfile:
            chrom, start, _, correct_reads, incorrect_reads = line.split()
            if not process_genome:
                if process_chrom != chrom:
                    continue

            if chrom != last_chrom:
                with pd.HDFStore(hdf_file, "a") as hdf:
                    hdf.put(os.path.join(last_chrom, "correct"), pd.Series(correct_counts, dtype=np.int8), format="fixed")
                    hdf.put(os.path.join(last_chrom, "incorrect"), pd.Series(incorrect_counts, dtype=np.int8), format="fixed")

                correct_counts = []
                incorrect_counts = []

                if chrom in chroms_seen:
                    raise ValueError("Mapping counts track file is not sorted - encountered twice: {}".format(chrom))

                if last_chrom == process_chrom:
                    # can stop iteration, processed the one single chromosome that was requested
                    break

                last_chrom = chrom
                chroms_seen.add(chrom)

            correct_reads = int(correct_reads)
            incorrect_reads = int(incorrect_reads)
            assert correct_reads < 127, "Count of correct reads too large (must be < 127): {}".format(correct_reads)
            correct_counts.append(correct_reads)
            incorrect_counts.append(incorrect_reads)

    # dump last
    if correct_counts:
        with pd.HDFStore(hdf_file, "a") as hdf:
            hdf.put(os.path.join(last_chrom, "correct"), pd.Series(correct_counts, dtype=np.int8), format="fixed")
            hdf.put(os.path.join(last_chrom, "incorrect"), pd.Series(incorrect_counts, dtype=np.int8), format="fixed")

    return hdf_file
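# Example of how the per-chromosome series written above are read back later in the
# pipeline (this mirrors the access pattern used in the ArbiGent normalisation script);
# the file name is a placeholder:
#
#     with pd.HDFStore("mapping_counts_allchrs_hg38.h5", "r") as hdf:
#         correct_counts = hdf[os.path.join("chr1", "correct")].values
#         incorrect_counts = hdf[os.path.join("chr1", "incorrect")].values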


def main():
    # args = parse_command_line()

    map_counts_file = convert_mapping_counts(snakemake.input.mapping_track, "genome")
    # map_counts_file = convert_mapping_counts(
    #     args.map_counts,
    #     args.chromosome
    #     )

    return 0


if __name__ == "__main__":
    main()
log <- file(snakemake@log[[1]], open='wt')
sink(file=log, type='message')
sink(file=log, type='output')

library(data.table)
library(assertthat)
source("workflow/scripts/arbigent_utils/mosaiclassifier_scripts/mosaiClassifier/mosaiClassifier.R")

# The following function converts the bed format to the segs format, which is later used in mosaicatcher
convert_bed_to_segs_format <- function(bed.table, bin.size){
	colnames(bed.table) <- c("chrom", "start", "end")
	segs <- bed.table
	segs[, `:=`(s=ceiling(start/bin.size)-1, e=ceiling(end/bin.size)-1), by=chrom]
	segs[, `:=`(start=NULL, end=NULL)]

	segs <- reshape(segs, direction = "long", varying = c("s", "e"), v.names = "bps", timevar = NULL)
	segs[, id:=NULL]
	setkey(segs, chrom, bps)

	# remove repetitive rows
	segs <- unique(segs)

	segs[, k:=.N, by=chrom]
	setcolorder(segs, c("k", "chrom", "bps"))

	return(segs)
}
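# A toy example (hypothetical input, descriptive comment only): with bin.size = 100000,
# a single BED segment
#     chrom   start     end
#     chr1   200000  600000
# becomes one row per segment boundary, expressed as bin indices
# (ceiling(coordinate / bin.size) - 1), with k = number of breakpoints per chromosome:
#     k chrom bps
#     2  chr1   1
#     2  chr1   5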


# Currently read files from the Snakemake pipeline
counts = fread(paste("zcat",snakemake@input[["counts"]]))
info   = fread(snakemake@input[["info"]])
strand = fread(snakemake@input[["states"]])
segs   = fread(snakemake@input[["bp"]])

chroms <- snakemake@config[["chromosomes"]]

counts <- counts[counts$chrom %in% chroms, ]
strand <- strand[strand$chrom %in% chroms, ]
segs <- segs[segs$chrom %in% chroms, ]

# As the binomial model is discrete, we need integer counts
segs$C = round(segs$C)
segs$W = round(segs$W)


# DEPRECATED: this version of normalization is no longer used
# is there a normalization file given?
if ("norm" %in% names(snakemake@input) && length(snakemake@input[["norm"]])>0) {
  message("[MosaiClassifier] Read normalization from ", snakemake@input[["norm"]])
  normalization = fread(snakemake@input[["norm"]])
  message("[Warning] Normalization file specified, but this option is no longer available")
} else {
  normalization = NULL
}

# haplotypeMode?
if ("CW" %in% strand$class) {
  haplotypeMode = T
} else {
  haplotypeMode = F
}

print("mosaiClassifierPrepare...")
d = mosaiClassifierPrepare(counts, info, strand, segs, manual.segs=as.logical(snakemake@config[["arbigent"]]))
print("mosaiClassifierCalcProbs...")
e = mosaiClassifierCalcProbs(d, maximumCN = 4, haplotypeMode = haplotypeMode, manual.segs=as.logical(snakemake@config[["arbigent"]]))

saveRDS(e, file = snakemake@output[[1]])
import os

# import pdb
import argparse
import sys
from pathlib import Path
import glob
import multiprocessing as mp

# import collections as col
from collections import defaultdict

import pandas as pd
import xopen
import numpy as np
import pysam as pysam


def parse_command_line():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-d", "--debug", action="store_true", default=False)
    parser.add_argument("-j", "--jobs", help="Number of CPU cores to use", default=1, type=int, dest="jobs")
    parser.add_argument("-s", "--sample", help="The sample name", required=True)
    parser.add_argument("-i", "--input_bam", help="The input bam file", required=True)
    parser.add_argument("-b", "--input_bed", type=str, help="The bed file with segments", required=True)
    parser.add_argument(
        "-n",
        "--norm_count_output",
        type=str,
        help="The output file with normalised watson and crick counts for downstream procesing",
        required=True,
    )
    parser.add_argument(
        "-p", "--norm_plot_output", type=str, help="The output file with normalised watson and crick counts for plots", required=True
    )
    parser.add_argument(
        "--mapping-counts",
        "-mc",
        help="Raw mapping counts as fixed bin, regular spaced BED-like file. Will be automatically converted to HDF file for future use.",
        dest="map_counts",
        type=str,
        default="",
    )
    parser.add_argument(
        "-l", "--lengthcorr_bool", action="store_true", help="Whether or not to perform length normalization", default=False
    )
    # since the binary HDF file will be created in the background,
    # and used implicitly if present, parallelization can be
    # dangerous if several processes (e.g., as executed in a Snakemake
    # run) try to write to the same output HDF file during conversion
    parser.add_argument(
        "--chromosome",
        "-c",
        default="genome",
        type=str,
        choices=["genome"],  # safeguard against accidental parallelization by chromosome
        help='Restrict counting to this chromosome. Default "genome" will process everything.',
    )
    parser.add_argument(
        "--bin-size",
        "-bs",
        type=int,
        default=100,
        dest="bin_size",
        help='Bin size of "mapping counts track". Will be inferred in case of raw mapping counts',
    )
    parser.add_argument(
        "--min-mappability",
        "-mm",
        type=int,
        default=75,
        dest="min_mapp",
        help="Minimum count of correctly mapped reads for a bin to be considered. Default: 75",
    )

    args = parser.parse_args()
    return args


def determine_boundaries(coordinate, bin_size, which_end):

    dm_div, dm_mod = divmod(coordinate, bin_size)
    if which_end == "low":
        if dm_mod == 0:
            # sitting right on the boundary
            coord = dm_div * bin_size
            coord_bin = dm_div
        else:
            coord = dm_div * bin_size + bin_size
            coord_bin = coord // bin_size
    else:
        if dm_mod == 0:
            coord = dm_div * bin_size
            coord_bin = dm_div
        else:
            coord = dm_div * bin_size
            coord_bin = coord // bin_size
    return coord, coord_bin
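# Worked example (hypothetical numbers): with bin_size = 100, a segment start of 250
# is rounded UP to the next bin boundary, determine_boundaries(250, 100, "low") -> (300, 3),
# whereas a segment end of 250 is rounded DOWN, determine_boundaries(250, 100, "high") -> (200, 2).
# Coordinates that sit exactly on a bin boundary are returned unchanged.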


def filter_reads(align_read):

    if any([align_read.is_read2, align_read.is_qcfail, align_read.is_secondary, align_read.is_duplicate, align_read.mapq < 10]):
        return None, 0
    else:
        start_pos = align_read.reference_start
        if align_read.is_reverse:
            return "watson", start_pos
        else:
            return "crick", start_pos


def aggregate_segment_read_counts(process_args):

    chrom, sample, segment_file, bam_folder, mappability_track, bin_size, min_correct_reads, lengthcorr_bool = process_args
    # print(segment_file)
    segments = pd.read_csv(segment_file, sep="\t", header=None, names=["chrom", "start", "end"])
    # print(segments)
    segments = segments.loc[segments["chrom"] == chrom, :].copy()
    # print(segments)
    if segments.empty:
        return chrom, None
    segments_low_bound = (segments["start"].apply(determine_boundaries, args=(bin_size, "low"))).tolist()
    segments_high_bound = (segments["end"].apply(determine_boundaries, args=(bin_size, "high"))).tolist()
    segments = pd.concat(
        [
            segments,
            pd.DataFrame.from_records(segments_low_bound, columns=["start_boundary", "start_bin"], index=segments.index),
            pd.DataFrame.from_records(segments_high_bound, columns=["end_boundary", "end_bin"], index=segments.index),
        ],
        axis=1,
        ignore_index=False,
    )
    segments["start_bin"] = segments["start_bin"].astype(np.int64)
    segments["end_bin"] = segments["end_bin"].astype(np.int64)

    with pd.HDFStore(mappability_track, "r") as hdf:
        correct_counts = hdf[os.path.join(chrom, "correct")].values
        incorrect_counts = hdf[os.path.join(chrom, "incorrect")].values

    path = Path(bam_folder)
    glob_path = path.glob("*.bam")

    segment_index = []
    segment_counts = []
    # Iterate over each bam file
    for bam_file in glob_path:
        assert os.path.isfile(str(bam_file) + ".bai"), "No BAM index file detected for {}".format(bam_file)
        cell = os.path.basename(bam_file).rsplit(".", 1)[0]

        with pysam.AlignmentFile(bam_file, mode="rb") as bam:
            for idx, row in segments.iterrows():
                counts = pd.DataFrame(
                    np.zeros((4, row["end_bin"] - row["start_bin"]), dtype=np.float64),
                    index=["watson", "crick", "correct", "incorrect"],
                    columns=list(range(row["start_bin"], row["end_bin"])),
                )
                counts.loc["correct", :] = correct_counts[row["start_bin"] : row["end_bin"]]
                counts.loc["incorrect", :] = incorrect_counts[row["start_bin"] : row["end_bin"]]

                reads = [filter_reads(r) for r in bam.fetch(row["chrom"], row["start_boundary"], row["end_boundary"])]
                for orientation, start_pos in reads:
                    if orientation is None:
                        continue
                    try:
                        counts.loc[orientation, start_pos // bin_size] += 1
                    except KeyError:
                        # happens if start pos is outside (lower than) start boundary but overlaps segment
                        continue

                # select only bins where the number of correct reads (simulation data) is above threshold
                # select_correct_threshold = np.array(counts.loc["correct", :] >= min_correct_reads, dtype=np.bool)
                select_correct_threshold = np.array(counts.loc["correct", :] >= min_correct_reads, dtype=bool)

                # select only bins where the number of incorrect reads (simulation data) is lower than 10% relative to correct reads
                # select_low_incorrect = ~np.array(counts.loc["incorrect", :] >= (0.1 * counts.loc["correct", :]), dtype=np.bool)
                select_low_incorrect = ~np.array(counts.loc["incorrect", :] >= (0.1 * counts.loc["correct", :]), dtype=bool)

                # combine selection: only bins for which both of the above is true
                select_bins = select_correct_threshold & select_low_incorrect
                # select_has_watson = np.array(counts.loc["watson", :] > 0, dtype=np.bool)
                select_has_watson = np.array(counts.loc["watson", :] > 0, dtype=bool)
                # select_has_crick = np.array(counts.loc["crick", :] > 0, dtype=np.bool)
                select_has_crick = np.array(counts.loc["crick", :] > 0, dtype=bool)

                valid_bins = select_bins.sum()

                watson_count_valid = counts.loc["watson", select_bins].sum()
                crick_count_valid = counts.loc["crick", select_bins].sum()

                # normalize Watson counts, reset everything else to 0
                counts.loc["watson", select_bins & select_has_watson] *= 100 / counts.loc["correct", select_bins & select_has_watson]
                counts.loc["watson", ~(select_bins & select_has_watson)] = 0

                # normalize Crick counts, reset everything else to 0
                counts.loc["crick", select_bins & select_has_crick] *= 100 / counts.loc["correct", select_bins & select_has_crick]
                counts.loc["crick", ~(select_bins & select_has_crick)] = 0

                total_watson_norm = counts.loc["watson", :].sum()
                total_crick_norm = counts.loc["crick", :].sum()

                # Compute the length normalization factor.
                # NOTE: length correction is currently always applied here; the
                # lengthcorr_bool flag is parsed but not evaluated.
                length_norm = 0
                if valid_bins > 0:
                    length_norm = (row["end"] - row["start"]) / (valid_bins * bin_size)
                total_watson_norm *= length_norm
                total_crick_norm *= length_norm

                segment_counts.append(
                    (
                        row["chrom"],
                        row["start"],
                        row["end"],
                        sample,
                        cell,
                        total_crick_norm,
                        total_watson_norm,
                        valid_bins,
                        length_norm,
                        crick_count_valid,
                        watson_count_valid,
                        row["start_boundary"],
                        row["end_boundary"],
                    )
                )

    df = pd.DataFrame(
        segment_counts,
        columns=[
            "chrom",
            "start",
            "end",
            "sample",
            "cell",
            "C",
            "W",
            "valid_bins",
            "length_norm_factor",
            "Crick_count_valid",
            "Watson_count_valid",
            "start_boundary",
            "end_boundary",
        ],
    )

    return chrom, df
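# Descriptive summary of the per-segment normalisation above: within each segment,
# Watson/Crick counts of a bin are scaled by 100 / correct_counts[bin] for bins that
# pass the mappability filters (>= min_correct_reads correct reads and < 10% incorrect
# reads); all other bins are zeroed. The per-segment sums are then multiplied by the
# length factor (end - start) / (valid_bins * bin_size) so that segments with different
# amounts of mappable sequence remain comparable.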


def counts(sample, input_bam, input_bed, norm_count_output, mapping_counts, norm_plot_output):
    dictionary = defaultdict(lambda: defaultdict(tuple))
    mapping_counts_file = open(mapping_counts, "r")
    # Store the whole mapability track
    for lines in mapping_counts_file:
        line = lines.strip().split("\t")
        # print(line)
        chrom = line[0]
        interval_start = int(line[1])
        interval_end = int(line[2])
        reads_originated = int(line[3])
        reads_mapped = int(line[4])
        dictionary[chrom][(interval_start, interval_end)] = (reads_originated, reads_mapped)

    print("Mapping_counts over")

    watson_count = 0
    crick_count = 0
    norm_counts_file = open(norm_count_output, "w")
    # Write header to output file
    norm_counts_file.write("chrom" + "\t" + "start" + "\t" + "end" + "\t" + "sample" + "\t" + "cell" + "\t" + "C" + "\t" + "W" + "\n")

    norm_plots_file = open(norm_plot_output, "w")
    norm_plots_file.write("chrom" + "\t" + "start" + "\t" + "end" + "\t" + "sample" + "\t" + "cell" + "\t" + "C" + "\t" + "W" + "\n")

    # Get all bam file paths
    path = Path(input_bam)
    glob_path = path.glob("*.bam")

    print(glob_path)
    print("Iterate over cells")

    # Iterate over each cell / bam file
    for file in glob_path:
        print(file)
        # Load the according file
        file_name = str(file).strip().split("/")[-1]
        cell = file_name.strip().split(".bam")[0]
        print(cell)
        bam_file = pysam.AlignmentFile(file, "rb")
        # Also get the bed_file
        norm_dictionary = defaultdict(lambda: defaultdict(tuple))
        with open(input_bed, "r") as bed_file:
            next(bed_file)  # skipping the header in bed file
            # Iterate over each manual segment aka each bed file line
            for line in bed_file:
                if line.startswith("#"):
                    continue
                segment_bins = []  # list for binwise counts for each segment
                line_r = line.strip().split("\t")
                chromosome = line_r[0]
                sub_dictionary = dictionary[chromosome]
                seg_start = int(line_r[1])
                seg_end = int(line_r[2])
                seg_start_bin = seg_start
                interval = int(seg_end - seg_start)
                bin_size = 100
                sc_TRIP_bin = 100000
                origin_count = 0
                mapped_count = 0
                norm_crick_counts = 0.0
                norm_watson_counts = 0.0
                whole_crick = 0.0
                whole_watson = 0.0
                start_check = 0
                bins_used = 0
                # do bin wise normalization, i.e multiply bin norm_factor individually to each bin count instead of doing it for the whole segment collectively.
                for m in sub_dictionary:
                    mapped_count = sub_dictionary[m][1]
                    # if segment start is towards the right of this bin_end
                    if seg_start_bin >= m[1]:
                        continue
                    else:
                        # if segment ends before the current bin_ends then we are done with this interval
                        if seg_end < m[1]:
                            break
                        # if the segment starts somewhere inside this bin, skip it
                        elif seg_start_bin > m[0]:
                            continue
                        # otherwise start counting
                        elif seg_start_bin <= m[0] and seg_end >= m[1] and mapped_count > 90:
                            seg_start_bin = m[0]
                            seg_end_bin = m[1]

                            normalizing_factor_counts = 100 / mapped_count

                            # Fetch all bam entries that fall in this bin.
                            for read in bam_file.fetch(chromosome, seg_start_bin, seg_end_bin):
                                # Exclude reads that match any of these failing criteria:
                                # read2, QC fail, secondary alignment, duplicate,
                                # mapping quality < 10, or a start position outside the current bin.
                                if (
                                    read.is_read2
                                    or read.is_qcfail
                                    or read.is_secondary
                                    or read.is_duplicate
                                    or read.mapq < 10
                                    or read.pos < seg_start_bin
                                    or read.pos >= seg_end_bin
                                ):
                                    pass
                                else:
                                    if read.is_reverse:
                                        watson_count += 1
                                        print("=== watson read")
                                        print("norm factor ", normalizing_factor_counts)
                                        print(read.query_name)
                                        print("seg start ", seg_start_bin)
                                        print("seg end ", seg_end_bin)
                                        print("pos ", read.pos)
                                    elif not read.is_reverse:
                                        print("=== crick read")
                                        print(read.query_name)
                                        print("seg start ", seg_start_bin)
                                        print("seg end ", seg_end_bin)
                                        print("pos ", read.pos)
                                        crick_count += 1
                                    else:
                                        pass
                            if crick_count > 0:
                                print("norm factor ", normalizing_factor_counts)
                            # normalising both watson and crick counts to make the heights comparable
                            norm_crick_counts_bin = float(crick_count * normalizing_factor_counts)
                            norm_watson_counts_bin = float(watson_count * normalizing_factor_counts)
                            segment_bins.append((norm_crick_counts_bin, norm_watson_counts_bin, crick_count, watson_count))

                            # remember how many valid bins we have used
                            bins_used += 1
                            # Reset count for next bin of this segment
                            watson_count = 0
                            crick_count = 0

                            # move to the next bin
                            seg_start_bin = seg_end_bin + 1

                # now add up the normalized Watson/Crick counts per bin
                for seg_count in segment_bins:
                    if seg_count[0]:
                        print("crick segment ", seg_count[0])
                    if seg_count[1]:
                        print("watson segment ", seg_count[1])
                    norm_crick_counts += float(seg_count[0])
                    norm_watson_counts += float(seg_count[1])

                # Normalize for combined bin length
                if bins_used > 0:
                    len_norm = interval / (bins_used * 100.0)
                else:
                    len_norm = 0
                print("valid bins ", bins_used)
                print("L-norm ", len_norm)

                print("crick norm ", norm_crick_counts)
                print("watson norm ", norm_watson_counts)

                norm_crick_counts *= len_norm
                norm_watson_counts *= len_norm

                print("crick l-norm ", norm_crick_counts)
                print("watson l-norm ", norm_watson_counts)

                norm_counts_file.write(
                    str(chromosome)
                    + "\t"
                    + str(seg_start)
                    + "\t"
                    + str(seg_end)
                    + "\t"
                    + str(sample)
                    + "\t"
                    + str(cell)
                    + "\t"
                    + str(norm_crick_counts)
                    + "\t"
                    + str(norm_watson_counts)
                    + "\n"
                )
                norm_plots_file.write(
                    str(chromosome)
                    + "\t"
                    + str(seg_start)
                    + "\t"
                    + str(seg_end)
                    + "\t"
                    + str(sample)
                    + "\t"
                    + str(cell)
                    + "\t"
                    + str(norm_crick_counts)
                    + "\t"
                    + str(norm_watson_counts)
                    + "\n"
                )
                # norm_plots_file.write(str(chromosome)+  "\t"+ str(seg_start)+ "\t" + str(seg_end) + "\t" + str(sample) + "\t" + str(cell) +"\t" +str(norm_crick_counts)+ "\t" + str(norm_watson_counts)+ "\n")


def convert_mapping_counts(raw_counts, process_chrom):
    """
    Convert textual BED-like raw mapping counts into
    binary representation stored as HDF for faster access.
    """
    print("DEBUG")
    num_splits = 1
    if any([raw_counts.endswith(x) for x in [".gz", ".zip", ".bz2", ".xz"]]):
        num_splits = 2
    basename = raw_counts.rsplit(".", num_splits)[0]
    print(basename)
    hdf_file = basename + ".h5"
    if os.path.isfile(hdf_file):
        return hdf_file
    else:
        with pd.HDFStore(hdf_file, "w") as hdf:
            pass

    process_genome = process_chrom == "genome"
    print(raw_counts)
    with xopen.xopen(raw_counts, mode="rt") as bedfile:
        columns = bedfile.readline().strip().split()
        print(columns)
        if not int(columns[1]) == 0:
            raise ValueError("Mapping counts track does not start at beginning of chromosome: {}".format("\t".join(columns)))
        bin_size = int(columns[2]) - int(columns[1])
        assert bin_size > 99, "Bin size {} detected, this code is only optimized for bin sizes >= 100".format(bin_size)
        bedfile.seek(0)

        last_chrom = columns[0]
        correct_counts = []
        incorrect_counts = []
        chroms_seen = set()

        for line in bedfile:
            chrom, start, _, correct_reads, incorrect_reads = line.split()
            if not process_genome:
                if process_chrom != chrom:
                    continue

            if chrom != last_chrom:
                with pd.HDFStore(hdf_file, "a") as hdf:
                    hdf.put(os.path.join(last_chrom, "correct"), pd.Series(correct_counts, dtype=np.int8), format="fixed")
                    hdf.put(os.path.join(last_chrom, "incorrect"), pd.Series(incorrect_counts, dtype=np.int8), format="fixed")

                correct_counts = []
                incorrect_counts = []

                if chrom in chroms_seen:
                    raise ValueError("Mapping counts track file is not sorted - encountered twice: {}".format(chrom))

                if last_chrom == process_chrom:
                    # can stop iteration, processed the one single chromosome that was requested
                    break

                last_chrom = chrom
                chroms_seen.add(chrom)

            correct_reads = int(correct_reads)
            incorrect_reads = int(incorrect_reads)
            assert correct_reads < 127, "Count of correct reads too large (must be < 127): {}".format(correct_reads)
            correct_counts.append(correct_reads)
            incorrect_counts.append(incorrect_reads)

    # dump last
    if correct_counts:
        with pd.HDFStore(hdf_file, "a") as hdf:
            hdf.put(os.path.join(last_chrom, "correct"), pd.Series(correct_counts, dtype=np.int8), format="fixed")
            hdf.put(os.path.join(last_chrom, "incorrect"), pd.Series(incorrect_counts, dtype=np.int8), format="fixed")

    return hdf_file


def main():

    # args = parse_command_line()
    debug = False
    chromosome = snakemake.params.genome_chromosome_param
    # print(chromosome)
    # chromosome = "genome"
    bin_size = 100
    min_mapp = 75
    lengthcorr_bool = False
    # jobs = 1
    jobs = snakemake.threads

    # sample = "RPE1-WT"
    # input_bam = "/scratch/tweber/DATA/MC_DATA/PAPER_ARBIGENT/RPE1-WT/selected"
    # input_bed = "workflow/data/arbigent/scTRIP_segmentation.bed"
    # norm_count_output = "TEST_arbigent_manual_segments.txt.raw"
    # norm_plot_output = "TEST_arbigent_blub.txt"
    # debug_output = "TEST_arbigent.txt.debug"
    # map_counts = "workflow/data/arbigent/mapping_counts_allchrs_hg38.txt"

    sample = snakemake.wildcards.sample
    input_bam = snakemake.params.bam_folder
    input_bed = snakemake.input.bed
    norm_count_output = snakemake.output.processing_counts
    debug_output = snakemake.output.debug
    map_counts = snakemake.input.mapping
    norm_plot_output = snakemake.output.norm_plot_output

    if debug:
        print("=== Hufsah original ===")
        counts(
            sample,
            input_bam,
            input_bed,
            norm_count_output,
            map_counts,
            norm_plot_output,
        )
        # counts(
        #     args.sample,
        #     args.input_bam,
        #     args.input_bed,
        #     args.norm_count_output,
        #     args.map_counts,
        #     args.norm_plot_output
        # )
        return 1

    map_counts_file = convert_mapping_counts(map_counts, chromosome)
    # print(map_counts_file)

    chroms_to_process = []
    if chromosome != "genome":
        # chroms_to_process = [chromosome]
        chroms_to_process = chromosome.split(",")
    else:
        with pd.HDFStore(map_counts_file, "r") as hdf:
            chroms_to_process = set([os.path.dirname(c).strip("/") for c in hdf.keys()])
            print(chroms_to_process)

    param_list = [(c, sample, input_bed, input_bam, map_counts_file, bin_size, min_mapp, lengthcorr_bool) for c in chroms_to_process]
    # print(param_list)
    merge_list = []
    with mp.Pool(min(len(chroms_to_process), jobs)) as pool:
        res_iter = pool.imap_unordered(aggregate_segment_read_counts, param_list)
        for chrom, result in res_iter:
            if result is None:
                # no segments / inversion on that chromosome
                continue
            merge_list.append(result)

    output = pd.concat(merge_list, axis=0, ignore_index=False)
    output.sort_values(["chrom", "start", "end", "cell"], inplace=True)
    reduced_output = ["chrom", "start", "end", "sample", "cell", "C", "W"]

    output[reduced_output].to_csv(norm_count_output, index=False, header=True, sep="\t")

    output[reduced_output].to_csv(norm_plot_output, index=False, header=True, sep="\t")

    output.to_csv(debug_output, index=False, header=True, sep="\t")

    return 0


if __name__ == "__main__":
    main()
input_path <- snakemake@input[["counts_scaled"]]
gc_path <- snakemake@params[["gc_matrix"]]
save_path <- snakemake@output[["counts_scaled_gc"]]
plot <- TRUE
min_reads <- snakemake@params[["gc_min_reads"]] # <- 5
n_subsample <- snakemake@params[["gc_n_subsample"]] # <- 1000

print(gc_path)

# open files
counts <- data.table::fread(input_path, header = T)
GC_matrix <- data.table::fread(gc_path, header = T)

# reformat GC_matrix
# find column containing GC counts and rename to 'GC%'

idx <- which(grepl("GC", colnames(GC_matrix), fixed = TRUE))
colnames(GC_matrix)[[idx]] <- "GC%"

# check GC plots
if (plot) {
  # import libraries
  library(ggplot2)
  library(ggpubr)
}


# check files
if (!all(c("cell", "chrom", "start", "end", "w", "c") %in% colnames(counts))) {
  message("count file does not contain required columns: 'cell', 'chrom', 'start', 'end', 'w', 'c'")
  message("Usage: Rscript GC_correction.R count-file.txt.gz gc-matrix.txt output.txt.gz")
  stop()
}
if (!all(c("chrom", "start", "end", "GC%") %in% colnames(GC_matrix))) {
  message("GC_matrix file does not contain required columns: 'chrom', 'start', 'end', 'GC%'")
  message("Usage: Rscript GC_correction.R count-file.txt.gz gc-matrix.txt output.txt.gz")
  stop()
}
if (!(all(unique(counts$chrom) %in% unique(GC_matrix$chrom)) &
  all(unique(counts$start) %in% unique(GC_matrix$start)) &
  all(unique(counts$end) %in% unique(GC_matrix$end)))) {
  message("bin features ('crhom', 'start', 'end') do not match between count file and GC matrix")
  message("make sure to choose files with identical bin sizes")
}


# green light message
# message(paste("\ncount file:", args[1]))
# message(paste("GC matrix file:", args[2]))
# message(paste("savepath:", args[3]))
message("preprocessing...\n")


#################
# Preprocessing #
#################

# force cell column to factor
counts$cell <- as.factor(counts$cell)

# convert strandseq count file to count matrix
counts$tot_count <- counts$c + counts$w

######################
# GC bias correction #
######################


counts <- merge(counts, GC_matrix[, c("chrom", "start", "GC%")], by = c("chrom", "start"), all.x = T)

# filter data for subsampling
c <- counts[counts$tot_count >= min_reads]
if (dim(c)[[1]] == 0) {
  stop(paste("there are no bins with more than", min_reads, "reads"))
}
c$`GC%` <- as.numeric(c$`GC%`)
c$log_count_norm <- log(c$tot_count) - log(median(c$tot_count))
not.na <- !is.na(c$`GC%`)
s <- c[not.na]

# subsample from quantiles
s$GC_bin <- cut(s$`GC%`, breaks = c(quantile(s$`GC%`, probs = seq(0, 1, by = 1 / 10))), labels = seq(1, 10, by = 1), include.lowest = TRUE)

subsample <- data.frame()
for (i in seq(10)) {
  sbin <- s[s$GC_bin == i]
  m <- min(dim(sbin)[1], n_subsample)
  sa <- sbin[sample(nrow(sbin), size = m), ]
  subsample <- rbind(subsample, sa)
}

#############################
# lowess fit and correction #
#############################
# lowess fit
z <- lowess(subsample$`GC%`, subsample$log_count_norm)

# ################
# # SAVING PLOTS #
# ################

# adjust tot count to closest predicted GC value
idxs <- sapply(as.numeric(counts$`GC%`), FUN = function(a) {
  which.min(abs(z$x - a))
})
idxs[lapply(idxs, length) == 0] <- NA
counts$pred <- z$y[unlist(idxs)]
counts$tot_count_gc <- log(counts$tot_count / median(counts$tot_count)) - counts$pred
counts$tot_count_gc <- exp(counts$tot_count_gc) * median(counts$tot_count)
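# Descriptive note: for each bin, `pred` is the lowess-estimated log-deviation of a bin
# with that GC content from the median count, so
#   tot_count_gc = exp(log(tot_count / median(tot_count)) - pred) * median(tot_count)
#                = tot_count / exp(pred)
# i.e. the observed count divided by its GC-dependent expectation.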

if (plot) {
  sidxs <- sapply(as.numeric(subsample$`GC%`), FUN = function(a) {
    which.min(abs(z$x - a))
  })
  sidxs[lapply(sidxs, length) == 0] <- NA


  subsample$pred <- z$y[unlist(sidxs)]
  subsample$tot_count_gc <- log(subsample$tot_count / median(subsample$tot_count)) - subsample$pred
  subsample$tot_count_gc <- exp(subsample$tot_count_gc) * median(subsample$tot_count)

  z$y2 <- exp(z$y) * median(subsample$tot_count)
  ymin <- min(cbind(subsample$tot_count, subsample$tot_count_gc))
  ymax <- max(cbind(subsample$tot_count, subsample$tot_count_gc))

  p1 <- ggplot(subsample, aes(`GC%`, tot_count)) +
    geom_point(size = 1, alpha = .2) +
    ggtitle("raw") +
    ylim(ymin, ymax) +
    xlab("GC_content") +
    ylab("read count") +
    geom_line(data = as.data.frame(z), aes(x, y2), color = "red")

  p2 <- ggplot(subsample, aes(`GC%`, tot_count_gc)) +
    geom_point(size = 1, alpha = .2) +
    ggtitle("gc corrected") +
    ylim(ymin, ymax) +
    xlab("GC content") +
    ylab("read count")

  corr_plot <- ggarrange(p1, p2)


  # save plots
  ggsave(snakemake@output[["plot"]], corr_plot, width = 12, height = 6)
}

# adjust w, c and fill NAs
counts$w <- (counts$w * counts$tot_count / counts$tot_count_gc)
counts$c <- (counts$c * counts$tot_count / counts$tot_count_gc)
counts$w[is.na(counts$w)] <- 0
counts$c[is.na(counts$c)] <- 0
counts$tot_count[is.na(counts$tot_count)] <- 0

output <- counts[, c("chrom", "start", "end", "sample", "cell", "w", "c", "tot_count", "class")]

data.table::fwrite(output, save_path)
counts <- data.table::fread(snakemake@input[["counts"]], header = T)
save_path <- snakemake@output[["counts_scaled"]]
# save_path <- args[3]
# min_reads <- snakemake@params[["gc_min_reads"]]

info_raw <- data.table::fread(snakemake@input[["info_raw"]], skip = 13, header = T, sep = "\t")
# info_raw <- data.table::fread(args[2])

min_reads <- min(info_raw[info_raw$pass1 == 1, ]$good) - 1

#################
# Preprocessing #
#################

# force cell column to factor
counts$cell <- as.factor(counts$cell)

# add tot counts
counts$tot_count <- counts$c + counts$w


# filter cells with too low counts
counts_bycell <- as.data.frame(aggregate(counts$tot_count, by = list(Category = counts$cell), FUN = sum))
sel_cells <- counts_bycell[counts_bycell$x >= as.integer(min_reads), "Category"]
if (length(sel_cells) == 0) {
    stop(paste("there are no cells with more than", min_reads, "total reads"))
}
counts <- counts[counts$cell %in% sel_cells, ]

# convert to bin matrix
count_matrix_raw <- reshape2::dcast(counts, chrom + start + end ~ cell, value.var = "tot_count")
count_matrix <- as.data.frame(count_matrix_raw)

#################
# Normalization #
#################

message(paste("library size normalization for", snakemake@input[["counts"]]))
# take the log of counts
count_matrix[4:ncol(count_matrix)] <- log(count_matrix[4:ncol(count_matrix)])

# calculate mean of the log counts per bin
count_matrix$mean_log_count <- apply(count_matrix[4:ncol(count_matrix)], MARGIN = 1, FUN = mean, na.rm = FALSE)

# filter out infinity
count_matrix <- count_matrix[!is.infinite(count_matrix$mean_log_count), ]
if (dim(count_matrix)[[1]] == 0) {
    stop("there are no common non-zero bins available across all cells.")
}

# log of counts over mean per bin
count_matrix[4:ncol(count_matrix)] <- count_matrix[4:ncol(count_matrix)] - count_matrix$mean_log_count

# the per-cell scaling factor is the median, over bins, of log(count) minus the per-bin mean log count
scaling_factors <- apply(count_matrix[4:ncol(count_matrix)], MARGIN = 2, FUN = median)

# raise e to scaling factor
scaling_factors <- exp(scaling_factors)
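# Tiny worked example (hypothetical counts, descriptive comment only): with two cells
# and two bins,
#   bin1: cellA = 2, cellB = 8      bin2: cellA = 4, cellB = 16
# the per-bin mean log counts are ~1.386 and ~2.079, the per-cell medians of
# log(count) - mean_log_count are -0.693 (cellA) and +0.693 (cellB), so the scaling
# factors are exp(-0.693) = 0.5 and exp(0.693) = 2. Dividing each cell's counts by its
# factor (below) rescales both cells to the same effective library size (4 and 8 per bin).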

# rescale libraries
scaled_matrix <- data.table::data.table(count_matrix_raw)

for (i in colnames(scaled_matrix)[4:ncol(scaled_matrix)]) {
    scaled_matrix[[i]] <- scaled_matrix[[i]] / scaling_factors[i]
}

##################
# Postprocessing #
##################

# wide to long
norm_tot_counts <- reshape2::melt(scaled_matrix,
    id.vars = c("chrom", "start", "end"),
    measure.vars = colnames(scaled_matrix)[4:ncol(scaled_matrix)],
    value.name = "norm_tot_count", variable.name = "cell"
)

# merge with counts
cols <- colnames(counts)
counts <- merge(counts, norm_tot_counts, by = c("chrom", "start", "end", "cell"), all.x = TRUE)

# adjust W and C counts to normalized counts
counts$ratio <- counts$norm_tot_count / counts$tot_count
counts$ratio[is.na(counts$ratio)] <- 0
counts$w <- counts$w * counts$ratio
counts$c <- counts$c * counts$ratio
counts$tot_count <- counts$norm_tot_count
# fill na
counts$w[is.na(counts$w)] <- 0
counts$c[is.na(counts$c)] <- 0
counts$tot_count[is.na(counts$tot_count)] <- 0


# saving
message("saving...\n")
data.table::fwrite(counts[, ..cols], save_path)
input_path <- snakemake@input[["counts_scaled_gc"]]
save_path <- snakemake@output[["counts_scaled_gc_vst"]]
chosen_transform <- "anscombe"
plot <- TRUE
rescale <- TRUE

# PRE-PROCESSING DATA
# open count file
counts_raw <- data.table::fread(input_path)

# force cell column to factor
counts_raw$cell <- as.factor(counts_raw$cell)

# fuse bin coordinates
counts_raw$bin <- paste(counts_raw$chrom, counts_raw$start, counts_raw$end, sep = "_")

# add tot counts
counts_raw$tot_count <- counts_raw$c + counts_raw$w

counts <- counts_raw

# convert to bin count matrix
to_matrix <- function(counts) {
  # fuse bin coordinates
  counts$bin <- paste(counts$chrom, counts$start, counts$end, sep = "_")
  mat_tot <- reshape2::dcast(counts, bin ~ cell, value.var = "tot_count")
  rownames(mat_tot) <- mat_tot$bin
  mat_tot <- mat_tot[, 2:ncol(mat_tot)]

  return(mat_tot)
}

# VARIANCE STABILIZING TRANSFORMATION

# TRANSFORMS
# for negative binomial distributions
# Anscombe, 1948
# Laubscher, 1961
anscombe_transform <- function(x, phi) {
  a <- x + (3 / 8)
  b <- (1 / phi) - (3 / 4)
  c <- sqrt(a / b)
  y <- asinh(c)
  return(y)
}
laubscher_transform <- function(x, phi) {
  a <- sqrt(phi)
  b <- asinh(sqrt(x / phi))
  c <- sqrt(phi - 1)
  d <- anscombe_transform(x, phi)
  y <- a * b + c * d
  return(y)
}
transform_data <- function(counts, transform, phi) {
  cols <- colnames(counts)
  counts$tot_count_corr <- transform(counts$tot_count, phi)
  counts$f <- counts$tot_count_corr / counts$tot_count
  counts$w <- counts$w * counts$f
  counts$c <- counts$c * counts$f
  counts$tot_count <- counts$tot_count_corr
  counts <- counts[, ..cols]

  return(counts)
}
transform_list <- list("anscombe" = anscombe_transform, "laubscher" = laubscher_transform)
transform <- transform_list[[chosen_transform]]


disp_score <- function(counts, transform, phi, design = NULL) {
  counts$tot_count <- transform(counts$tot_count, phi)
  mat <- to_matrix(counts)
  # if multiple samples are present design matrix can be used
  if (is.null(design)) {
    design <- matrix(1, ncol = 1, nrow = ncol(mat))
  }
  res <- as.matrix(mat) %*% MASS::Null(design)
  rsd <- sqrt(rowMeans(res * res))
  score <- sd(rsd) / mean(rsd)
  return(score)
}

message(paste("Transforming data with", chosen_transform, "VST"))

# estimate dispersion by residual variance

opt <- optimize(disp_score, counts = counts, transform = transform, interval = c(0.00001, 1))
phi <- opt$minimum
message(paste("Estimated dispersion - phi: ", phi))

# correction
corr_counts <- transform_data(counts, transform, phi)
corr_counts <- data.table::data.table(corr_counts[, c("chrom", "start", "end", "sample", "cell", "w", "c", "class", "tot_count")])

rescale_data <- function(counts_original, counts_transformed) {
  rescaled <- counts_transformed
  rescaled_med <- aggregate(rescaled$tot_count, list(rescaled$cell), FUN = median)
  original_med <- aggregate(counts_original$tot_count, list(counts_original$cell), FUN = median)
  m <- merge(x = original_med, y = rescaled_med, by = "Group.1", suffixes = c("_raw", "_norm"))
  m[["f"]] <- m[["x_raw"]] / m[["x_norm"]]

  rescaled <- merge(rescaled, m[c("Group.1", "f")], by.x = "cell", by.y = "Group.1")

  rescaled$tot_count <- rescaled$tot_count * rescaled$f
  rescaled$w <- rescaled$w * rescaled$f
  rescaled$c <- rescaled$c * rescaled$f
  return(rescaled)

}

if (rescale == TRUE) {
  corr_counts <- rescale_data(counts_raw, corr_counts)
} 


message("saving...")
data.table::fwrite(corr_counts, save_path)

if (plot) {
  library(ggplot2)
  library(ggpubr)

  merge_bins <- function(df, bin_size = 3e6) {
    df <- df[with(df, order(cell, chrom, start))]

    df$bin_group <- df$start %/% bin_size


    w <- aggregate(df$w, by = list(df$cell, df$chrom, df$bin_group), FUN = sum)
    c <- aggregate(df$c, by = list(df$cell, df$chrom, df$bin_group), FUN = sum)
    s <- aggregate(df$start, by = list(df$cell, df$chrom, df$bin_group), FUN = function(x) x[[1]])
    e <- aggregate(df$end, by = list(df$cell, df$chrom, df$bin_group), FUN = function(x) x[[length(x)]])
    cl <- aggregate(df$class, by = list(df$cell, df$chrom, df$bin_group), FUN = function(x) names(sort(table(x), decreasing = TRUE))[[1]])

    m <- merge(w, c, by = c("Group.1", "Group.2", "Group.3"))
    m <- merge(m, s, by = c("Group.1", "Group.2", "Group.3"))
    m <- merge(m, e, by = c("Group.1", "Group.2", "Group.3"))
    m <- merge(m, cl, by = c("Group.1", "Group.2", "Group.3"))
    colnames(m) <- c("cell", "chrom", "bin_group", "w", "c", "start", "end", "class")
    m$sample <- unique(df$sample)[[1]]
    m <- data.table::as.data.table(m)
    m <- m[with(m, order(cell, chrom, start))]

    m$tot_count <- m$w + m$c
    return(m)
  }

  wf_plot <- function(m) {
    m$wf <- m$w / m$tot_count

    p1 <- ggplot(m, aes(x = tot_count, y = wf)) +
      geom_point(size = 1, alpha = .1, shape = 16) +
      xlab("tot count") +
      ylab("watson fraction") +
      ylim(0, 1)
    return(p1)
  }

  p1 <- ggplot(counts_raw, aes(x = tot_count)) +
    geom_histogram(bins = 256) +
    ggtitle("raw") +
    xlab("read count") +
    ylab("bin count")

  p2 <- ggplot(corr_counts, aes(x = tot_count)) +
    geom_histogram(bins = 256) +
    ggtitle(paste(chosen_transform, "VST")) +
    xlab("read count") +
    ylab("bin count")

  m <- merge_bins(counts_raw)
  p3 <- wf_plot(m) + ggtitle("raw")

  n <- merge_bins(corr_counts)
  p4 <- wf_plot(n) + ggtitle(paste(chosen_transform, "VST"))


  corr_plot <- ggarrange(p1, p2, p3, p4)

  ggsave(snakemake@output[["plot"]], corr_plot, width = 12, height = 6)
}
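For reference, the two variance-stabilizing transforms implemented above for negative-binomial counts $x$ with dispersion $\phi$ are

$$
y_{\text{Anscombe}}(x) = \operatorname{asinh}\sqrt{\frac{x + 3/8}{1/\phi - 3/4}},
\qquad
y_{\text{Laubscher}}(x) = \sqrt{\phi}\,\operatorname{asinh}\sqrt{\frac{x}{\phi}} + \sqrt{\phi - 1}\;y_{\text{Anscombe}}(x).
$$

The dispersion $\phi$ is not supplied by the user; it is estimated by `optimize` over `disp_score`, i.e. by minimizing the coefficient of variation of the per-bin residual standard deviations after the transform.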
log <- file(snakemake@log[[1]], open = "wt")
sink(file = log, type = "message")
sink(file = log, type = "output")

system("LC_MEASUREMENT=C")

source("workflow/scripts/haplotagging_scripts/haplotagTable.R")

paired_end <- sub("\n", "", readChar(snakemake@input[["paired_end"]], file.info(snakemake@input[["paired_end"]])$size))

tab <- getHaplotagTable2(bedFile = snakemake@input[["bed"]], bam.file = snakemake@input[["bam"]], file.destination = snakemake@output[["tsv"]], paired_end = paired_end)
log <- file(snakemake@log[[1]], open = "wt")
sink(file = log, type = "message")
sink(file = log, type = "output")

source("workflow/scripts/mosaiclassifier_scripts/haplotagProbs.R")

haplotagCounts <- fread(snakemake@input[["haplotag_table"]])
probs <- readRDS(snakemake@input[["sv_probs_table"]])

# FIXME: temporary fix for the "Segments must cover all bins" error, which happens for small scaffolds
chroms <- snakemake@config[["chromosomes"]]

haplotagCounts <- haplotagCounts[haplotagCounts$chrom %in% chroms, ]


# FIXME: quick-and-dirty fix for off-by-one start coordinates of segments
haplotagCounts[, start := start - 1]

probs <- addHaploCountProbs(probs, haplotagCounts, alpha = 0.05)

saveRDS(probs, file = snakemake@output[[1]])
library(data.table)
source("workflow/scripts/mosaiclassifier_scripts/mosaiClassifier/makeSVcalls.R")

probs <- readRDS(snakemake@input[["probs"]])
llr <- as.numeric(snakemake@wildcards[["llr"]])
bin_size <- as.numeric(snakemake@wildcards[["window"]])

probs <- mosaiClassifierPostProcessing(probs)
probs <- forceBiallelic(probs)
tab <- makeSVCallSimple(probs, llr_thr = llr, bin.size = bin_size)

write.table(tab, file = snakemake@output[[1]], sep = "\t", quote = F, row.names = F, col.names = T)
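As we read `makeSVCallSimple`, the `llr` value acts as a log-likelihood-ratio threshold: a segment in a cell is only reported as an SV when the ratio of the best non-reference state to the reference state (the `llr_to_ref` column of the output table) reaches at least this value, i.e.

$$
\log \frac{P(\text{data} \mid \text{best SV state})}{P(\text{data} \mid \text{reference state})} \;\ge\; \texttt{llr\_thr}.
$$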
log <- file(snakemake@log[[1]], open = "wt")
sink(file = log, type = "message")
sink(file = log, type = "output")

library(data.table)
source("workflow/scripts/mosaiclassifier_scripts/mosaiClassifier/makeSVcalls.R")

# probs <- readRDS(snakemake@input[["probs"]])
# llr <- as.numeric(snakemake@wildcards[["llr"]])
# use.pop.priors <- eval(parse(text = snakemake@wildcards[["pop_priors"]]))
# use.haplotags <- eval(parse(text = snakemake@wildcards[["use_haplotags"]]))
# regularizationFactor <- 10^(-as.numeric(snakemake@wildcards[["regfactor"]]))
# genotype.cutoff <- as.numeric(snakemake@wildcards[["gtcutoff"]])
# minFrac.used.bins <- as.numeric(snakemake@params[["minFrac_used_bins"]])
# bin.size <- as.numeric(snakemake@params[["window"]])


probs <- readRDS(snakemake@input[["probs"]])
llr <- as.numeric(snakemake@params[["llr"]])
use.pop.priors <- eval(parse(text = snakemake@params[["pop_priors"]]))
use.haplotags <- eval(parse(text = snakemake@params[["use_haplotags"]]))
regularizationFactor <- 10^(-as.numeric(snakemake@params[["regfactor"]]))
genotype.cutoff <- as.numeric(snakemake@params[["gtcutoff"]])
minFrac.used.bins <- as.numeric(snakemake@params[["minFrac_used_bins"]])
bin.size <- as.numeric(snakemake@params[["window"]])

# print(probs)
# print(llr)
# print(use.pop.priors)
# print(regularizationFactor)
# print(genotype.cutoff)
# print(minFrac.used.bins)
# print(bin.size)
# stop()

probs <- mosaiClassifierPostProcessing(probs, regularizationFactor = regularizationFactor)

# print(probs)

tab <- makeSVCallSimple(probs, llr_thr = llr, use.pop.priors = use.pop.priors, use.haplotags = use.haplotags, genotype.cutoff = genotype.cutoff, bin.size = bin.size, minFrac.used.bins = minFrac.used.bins)

# print(tab)

write.table(tab, file = snakemake@output[[1]], sep = "\t", quote = F, row.names = F, col.names = T)
sink(snakemake@log[[1]])
library(data.table)
library(assertthat)
source("workflow/scripts/mosaiclassifier_scripts/mosaiClassifier/mosaiClassifier.R")
system("LC_CTYPE=C")

# Read input files provided by the Snakemake pipeline
counts <- fread(paste("zcat", snakemake@input[["counts"]]))
info <- fread(snakemake@input[["info"]])
strand <- fread(snakemake@input[["states"]])
segs <- fread(snakemake@input[["bp"]])


# FIXME: temporary fix for the "Segments must cover all bins" error, which happens for small scaffolds
chroms <- snakemake@config[["chromosomes"]]

counts <- counts[counts$chrom %in% chroms, ]
strand <- strand[strand$chrom %in% chroms, ]
segs <- segs[segs$chrom %in% chroms, ]

# haplotypeMode?
if ("CW" %in% strand$class) {
  haplotypeMode <- T
} else {
  haplotypeMode <- F
}

d <- mosaiClassifierPrepare(counts, info, strand, segs)
e <- mosaiClassifierCalcProbs(d, maximumCN = 4, haplotypeMode = haplotypeMode)

saveRDS(e, file = snakemake@output[[1]])
import sys
from argparse import ArgumentParser
import pandas as pd


def main():
    parser = ArgumentParser(prog="merge-blacklist.py", description=__doc__)
    parser.add_argument(
        "--merge_distance",
        default=500000,
        type=int,
        help="If the distance between two blacklisted intervals is below this threshold, they are merged.",
    )
    parser.add_argument(
        "--whitelist", default=None, help="TSV file with intervals to be removed from the blacklist (columns: chrom, start, end)."
    )
    parser.add_argument("--min_whitelist_interval_size", default=400000, type=int, help="Ignore whitelisted intervals below this size.")

    parser.add_argument("normalization", metavar="NORM", help="File (tsv) with normalization and blacklist data")

    args = parser.parse_args()

    print("Reading", args.normalization, file=sys.stderr)
    norm_table = pd.read_csv(args.normalization, sep="\t")

    assert set(norm_table.columns) == set(["chrom", "start", "end", "scalar", "class"])

    whitelist = None
    if args.whitelist is not None:
        whitelist = pd.read_csv(args.whitelist, sep="\t")
        assert set(whitelist.columns) == set(["chrom", "start", "end"])
        print("Read", len(whitelist), "whitelisted intervals from", args.whitelist, file=sys.stderr)
        whitelist = whitelist[whitelist.end - whitelist.start >= args.min_whitelist_interval_size]
        print("  -->", len(whitelist), "remained after removing intervals below", args.min_whitelist_interval_size, "bp", file=sys.stderr)

    additional_blacklist = 0
    prev_blacklist_index = None
    prev_blacklist_chrom = None
    prev_blacklist_end = None
    for i in range(len(norm_table)):
        row = norm_table.iloc[i]
        # print('Processing row', i, ' -->', tuple(row), file=sys.stderr)
        # is row blacklisted?
        if row["class"] == "None":
            # print(' --> is black', file=sys.stderr)
            if (prev_blacklist_chrom == row["chrom"]) and (row["start"] - prev_blacklist_end <= args.merge_distance):
                # print(' --> black listing', prev_blacklist_index+1, 'to', i, file=sys.stderr)
                for j in range(prev_blacklist_index + 1, i):
                    norm_table.loc[[j], "class"] = "None"
                    row_j = norm_table.iloc[j]
                    additional_blacklist += row_j.end - row_j.start
            prev_blacklist_index = i
            prev_blacklist_chrom = row["chrom"]
            prev_blacklist_end = row["end"]

    print("Additionally blacklisted", additional_blacklist, "bp of sequence", file=sys.stderr)

    additional_whitelist = 0
    if whitelist is not None:
        for i in range(len(norm_table)):
            row = norm_table.iloc[i]
            if row["class"] == "None":
                if len(whitelist[(whitelist.chrom == row.chrom) & (row.start < whitelist.end) & (whitelist.start < row.end)]) > 0:
                    norm_table.loc[[i], "class"] = "good"
                    additional_whitelist += row.end - row.start

    print("White listing: Removed", additional_whitelist, "bp of sequence for blacklist", file=sys.stderr)

    norm_table.to_csv(sys.stdout, index=False, sep="\t")

    ## Identify "complex" intervals
    # segments = calls.groupby(by=['chrom','start','end']).sv_call_name.agg({'is_complex':partial(is_complex, ignore_haplotypes=args.ignore_haplotypes, min_cell_count=args.min_cell_count)}).reset_index().sort_values(['chrom','start','end'])

    ## merge complex segments if closer than args.merge_distance
    # complex_segments = pd.DataFrame(columns=['chrom','start','end'])
    # cur_chrom, cur_start, cur_end = None, None, None
    # for chrom, start, end in segments[segments.is_complex][['chrom','start','end']].values:
    # if cur_chrom is None:
    # cur_chrom, cur_start, cur_end = chrom, start, end
    # elif (cur_chrom == chrom) and (start - cur_end < args.merge_distance):
    # cur_end = end
    # else:
    # complex_segments = complex_segments.append({'chrom': cur_chrom, 'start': cur_start,'end': cur_end}, ignore_index=True)
    # cur_chrom, cur_start, cur_end = chrom, start, end
    # if cur_chrom is not None:
    # complex_segments = complex_segments.append({'chrom': cur_chrom, 'start': cur_start,'end': cur_end}, ignore_index=True)

    # print(complex_segments, file=sys.stderr)
    # total_complex = sum(complex_segments.end - complex_segments.start)

    # print('Total amount of complex sequence: {}Mbp'.format(total_complex/1000000), file=sys.stderr)
    # complex_segments[['chrom','start','end']].to_csv(sys.stdout, index=False, sep='\t')
    ##print(complex_segments, file=sys.stderr)


if __name__ == "__main__":
    main()
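The merging rule implemented in `main()` can be summarised as follows: if two blacklisted bins (`class == "None"`) on the same chromosome are separated by at most `--merge_distance`, every bin in between is blacklisted as well. A toy illustration of that rule (hypothetical values, not pipeline output):

```python
import pandas as pd

# Three consecutive 200 kb bins on chr1; the two outer bins are blacklisted.
norm = pd.DataFrame({
    "chrom":  ["chr1", "chr1", "chr1"],
    "start":  [0, 200_000, 400_000],
    "end":    [200_000, 400_000, 600_000],
    "scalar": [1.0, 1.0, 1.0],
    "class":  ["None", "good", "None"],
})

merge_distance = 500_000
prev_idx, prev_chrom, prev_end = None, None, None
for i, row in norm.iterrows():
    if row["class"] == "None":
        if prev_chrom == row["chrom"] and row["start"] - prev_end <= merge_distance:
            # blacklist every bin lying between the two nearby blacklisted bins
            norm.loc[prev_idx + 1 : i - 1, "class"] = "None"
        prev_idx, prev_chrom, prev_end = i, row["chrom"], row["end"]

print(norm)  # the middle bin is now flagged "None" as well
```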
suppressMessages(library(dplyr))
suppressMessages(library(data.table))
suppressMessages(library(assertthat))


args <- commandArgs(trailingOnly = T)
if (length(args) != 4) {
  print("Usage: Rscript scale.R <count table> <norm factors> <out>")
  print("")
  print("       Normalize Strand-seq read counts. Divide the counts of all bins")
  print("       by a scaling factor (norm$scalar) and further black-list bins")
  print("       if requested in the normalizatio file (norm$class).")
  options(show.error.messages = F)
  stop()
}

# Read counts
message(" * Reading counts from ", args[1])
counts <- fread(paste("zcat", args[1]))
assert_that(
  is.data.table(counts),
  "chrom" %in% colnames(counts),
  "start" %in% colnames(counts),
  "end" %in% colnames(counts),
  "class" %in% colnames(counts),
  "sample" %in% colnames(counts),
  "cell" %in% colnames(counts),
  "w" %in% colnames(counts),
  "c" %in% colnames(counts)
) %>% invisible()
setkey(counts, chrom, start, end)

# Check that all cells have the same bins
bins <- unique(counts[, .(chrom, start, end)])
counts[,
  assert_that(all(.SD == bins), msg = "Not the same bins in all cells"),
  by = .(sample, cell),
  .SDcols = c("chrom", "start", "end")
] %>% invisible()


# remove bad cells
bad_cells <- counts[class == "None", .N, by = .(sample, cell)][N == nrow(bins)]
if (nrow(bad_cells) > 0) {
  message(" * Removing ", nrow(bad_cells), " cells because thery were black-listed.")
  counts <- counts[!bad_cells, on = c("sample", "cell")]
}

# Check that the "None" bins are all the same across cells
none_bins <- unique(counts[!bad_cells, on = c("sample", "cell")][class == "None", .(chrom, start, end)])
if (nrow(none_bins) > 0) {
  counts[!bad_cells, on = c("sample", "cell")][class == "None",
    assert_that(all(.SD == none_bins), msg = "None bins are not the same in all cells (excl. bad cells)"),
    by = .(sample, cell),
    .SDcols = c("chrom", "start", "end")
  ] %>% invisible()
}


# Read normalization factors
message(" * Reading norm file from ", args[2])
norm <- fread(args[2])
assert_that(
  is.data.table(norm),
  "chrom" %in% colnames(norm),
  "start" %in% colnames(norm),
  "end" %in% colnames(norm),
  "scalar" %in% colnames(norm)
) %>% invisible()
if ("class" %in% colnames(norm)) {
  norm <- norm[, .(chrom, start, end, scalar, norm_class = class)]
} else {
  norm <- norm[, .(chrom, start, end, scalar, norm_class = "good")]
}
setkey(norm, chrom, start, end)

# Check normalization type from snakemake
if (args[4] == "False") {
  norm$scalar <- 1
}

# Blacklist bins whose scaling factor is (near) zero:
norm[scalar < 0.01, norm_class := "None"]


# annotate counts with scaling factor
counts <- merge(counts,
  norm,
  by = c("chrom", "start", "end"),
  all.x = T
)

if (any(is.na(counts$scalar))) {
  message(
    " * Assign scalars: Could not match ",
    unique(counts[, .(chrom, start, end, scalar)])[is.na(scalar), .N],
    " bins (out of ",
    unique(counts[, .(chrom, start, end)])[, .N],
    ") -> set those to 1"
  )
}

# Fill gaps in the norm file
counts[is.na(scalar), `:=`(scalar = 1, norm_class = "good")]

# Black-listing bins
test <- counts[!bad_cells, on = c("sample", "cell")][cell == unique(cell)[1]]
test <- test[, .(
  count_None = sum(class == "None"),
  norm_None = sum(norm_class == "None"),
  final_None = sum(class == "None" | norm_class == "None")
)]
message(" * ", test$count_None, " bins were already black-listed; ", test$norm_None, " are blacklisted via the normalization, leading to a total of ", test$final_None)
counts[norm_class == "None", class := "None"]



# Apply normalization factor
counts[, `:=`(c = as.numeric(c), w = as.numeric(w))]

counts[class != "None", `:=`(
  c = c * scalar,
  w = w * scalar
)]

# # Check normalization type from snakemake
# if (args[4] == "True") {
#   counts[class != "None", `:=`(
#     c = c * scalar,
#     w = w * scalar
#   )]
# } else {
#   scalar <- 1

# }



counts[class == "None", `:=`(c = 0.0, w = 0.0)]

message(
  " * Applying normalization: min = ",
  round(min(counts[class != "None", scalar]), 3),
  ", max = ",
  round(max(counts[class != "None", scalar]), 3),
  ", median = ",
  median(unique(counts[, .(chrom, start, end, class, scalar)][class != "None", scalar]))
)


# Remove column
counts[, norm_class := NULL]
counts[, scalar := NULL]


# Write down table
message(" * Write data to ", args[3])
gz1 <- gzfile(args[3], "w")
write.table(counts, gz1, sep = "\t", quote = F, col.names = T, row.names = F)
close(gz1)
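In short, this script multiplies the Watson and Crick counts of every non-blacklisted bin by the per-bin factor from the normalization file and blacklists bins whose factor is essentially zero:

$$
w' = s \cdot w, \qquad c' = s \cdot c, \qquad \text{class} := \text{None} \;\text{ if }\; s < 0.01,
$$

where $s$ is the `scalar` column. When the fourth argument is `False`, $s$ is forced to 1 and the counts are left unchanged.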
import pandas as pd
import numpy as np

ploidy_detailed = pd.read_csv(snakemake.input[0], sep="\t")
ploidy_detailed = ploidy_detailed.loc[ploidy_detailed["#chrom"] != "genome"]

m_f = "M"

x_mean = ploidy_detailed.loc[ploidy_detailed["#chrom"] == "chrX", "ploidy_estimate"].mean()
if x_mean > 0:
    m_f = "F" if x_mean >= 2 else "M"

ploidy_detailed["sex"] = m_f
ploidy_detailed["start"] = ploidy_detailed["start"] + 1
ploidy_detailed.loc[ploidy_detailed["ploidy_estimate"] > 2, "ploidy_estimate"] = 2
ploidy_detailed = ploidy_detailed[["#chrom", "start", "end", "sex", "ploidy_estimate"]]
ploidy_detailed.to_csv(snakemake.output[0], sep="\t", index=False, header=False)
__author__ = "Tobias Marschall"
__maintainer__ = "Peter Ebert"
__credits__ = ["Tobias Marschall", "Peter Ebert", "Tania Christiansen"]
__status__ = "Prototype"

import os as os
import sys as sys
import collections as col
import argparse as argp
import traceback as trb
import logging as log
import warnings as warn
import multiprocessing as mp

import intervaltree as ivt
import scipy.stats as stats
import numpy as np
import pandas as pd


logger = log.getLogger(__name__)


def parse_command_line():
    """
    :return:
    """
    parser = argp.ArgumentParser(prog="ploidy-estimator.py", description=__doc__)
    parser.add_argument("--debug", "-d", action="store_true", default=False, help="Print debugging messages to stderr.")
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        dest="input",
        required=True,
        help="Gzipped, tab-separated table with Watson/Crick read " "counts in fixed bins. A header line is required.",
    )
    parser.add_argument("--output", "-o", type=str, dest="output", required=True, help="Full path to output text file.")
    parser.add_argument("--log", type=str, dest="log", required=True, help="Full path to log text file.")

    parser.add_argument(
        "--dump-table",
        "-tab",
        type=str,
        dest="table",
        default="",
        help="Specify a path to a file to dump the table of "
        "Watson fractions. Note that this happens before "
        "any potential filtering of blacklist regions. "
        "Note that the header line is prefixed with an "
        '"#", and the fields are tab-separated; in other words, '
        "the output format is BED-like."
        "Default: <none>",
    )
    parser.add_argument(
        "--blacklist-regions",
        "-b",
        type=str,
        dest="blacklist",
        default="",
        help="Specify file with regions to be blacklisted. Only " "chrom - start - end will be read from the file. " "Default: <none>",
    )
    parser.add_argument("--max-ploidy", default=4, type=int, dest="max_ploidy", help="Maximum ploidy that is considered. " " Default: 4")
    parser.add_argument(
        "--boundary-alpha",
        "-a",
        type=float,
        default=0.05,
        dest="alpha",
        help="Adjust means of Gaussians at the boundaries (0, 1) by alpha "
        "to account for a some noise in the data (imperfect ratios). "
        "Default: 0.05",
    )
    parser.add_argument(
        "--merge-bins-to",
        "-m",
        type=int,
        default=1000000,
        dest="window",
        help="Merge the input bins to windows of this size. " "Default: 1000000",
    )
    parser.add_argument(
        "--shift-window-by", "-s", type=int, default=500000, dest="step", help="Shift merge window by this step size. " "Default: 500000"
    )
    parser.add_argument(
        "--uniform-background",
        "-ubg",
        action="store_true",
        default=False,
        dest="background",
        help="Add an additional component to the mixture "
        "model (uniform distribution) as a "
        "noise/background component. "
        "Default: False",
    )
    parser.add_argument(
        "--sort-input",
        "-si",
        action="store_true",
        default=False,
        dest="sort",
        help="Set this option if the input data is NOT sorted by: " "cell > chrom > start > end " "Default: False",
    )
    parser.add_argument("--jobs", "-j", default=1, type=int, dest="jobs", help="Specify number of CPU cores to use. " "Default: 1")
    args = parser.parse_args()
    return args


class Mixture:
    def __init__(self, means, weights, background):
        assert means.size == weights.size
        self.means = means
        self.weights = weights
        self.stddevs = np.repeat([0.5], means.size)
        if background:
            self.wbg = 0.5 * weights.min()
            s = self.weights.sum() + self.wbg
            self.wbg /= s
            self.weights /= s
            assert np.isclose(self.weights.sum() + self.wbg, 1, atol=1e-6)
        else:
            self.wbg = -1

    def fit_stddevs_meanprop(self, fractions):
        """Fit standard deviations so that they are proportional to the means"""
        n = 0
        v = 1.0
        # make shape of input array compatible
        # for vectorized operations
        matrix = np.tile(fractions, self.means.size).reshape(self.means.size, fractions.size).transpose()
        posteriors = np.zeros_like(matrix)
        while True:
            for idx, (mean, stddev, weight) in enumerate(zip(self.means, self.stddevs, self.weights)):
                posteriors[:, idx] = weight * stats.norm.pdf(fractions, mean, stddev)
            posteriors /= posteriors.sum(axis=1, keepdims=True)
            new_v = posteriors * np.abs(matrix - self.means) / self.weights
            new_v = new_v.sum() / fractions.size
            assert not np.isnan(new_v), "new_v is NaN in iteration: {}".format(n)
            self.stddevs = self.weights * new_v
            if np.isclose(v, new_v, atol=1e-10):
                break
            v = new_v
            n += 1
            logger.debug(n)
            if n > 1000:
                raise RuntimeError("Fitting process does not converge - last v estimate: {}".format(new_v))
        return

    def log_likelihood(self, fractions):
        probs = np.zeros((fractions.size, self.means.size), dtype=np.float64)
        for idx, (m, s, w) in enumerate(zip(self.means, self.stddevs, self.weights)):
            probs[:, idx] = w * stats.norm.pdf(fractions, m, s)
        loglik = np.log(probs.sum(axis=1)).sum()
        return loglik, self.stddevs


###################################################
# Following: functions writing output
###################################################


def dump_fraction_table(dataset, out_path):
    """
    This function dumps the table containing all
    Watson fractions (all genomic bins, all cells).
    This is an intermediate result and may just be
    dumped to use the data in other tools

    :param dataset:
    :param out_path:
    :return:
    """
    out_path = os.path.abspath(out_path)
    logger.debug("Dumping table of Watson fractions at path: {}".format(out_path))
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    dataset.sort_values(["chrom", "start", "end"], inplace=True)
    with open(out_path, "w") as dump:
        _ = dump.write("#")  # write a BED-like file
        dataset.to_csv(dump, sep="\t", header=True, index=False)
    logger.debug("Dump complete")
    return


def write_ploidy_estimation_table(output_table, output_path):
    """
    :param output_table:
    :param output_path:
    :return:
    """
    out_path = os.path.abspath(output_path)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    logger.debug("Writing ploidy estimates to path: {}".format(out_path))
    with open(out_path, "w") as out:
        _ = out.write("#")  # write a BED-like file
        output_table.to_csv(out, sep="\t", header=True, index=False)
    logger.debug("Output table saved to disk")
    return


###################################################
# Following: function parsing input data
# plus utility functions (validation, merging etc.)
###################################################


def compute_watson_fractions(input_file, window_size, shift_size, not_sorted, jobs):
    """
    This function parses the input table, calls several sub-ordinate utility
    functions, and generates the data table containing Watson fraction of reads
    for all genomic bins and all cells. This table is then the input for the
    mixture model step.

    :param input_file:
    :param window_size:
    :param shift_size:
    :param not_sorted:
    :param jobs:
    :return:
    """
    logger.debug("Reading input data from file: {}".format(input_file))
    df = pd.read_csv(input_file, sep="\t", header=0, index_col=False, usecols=["chrom", "start", "end", "cell", "c", "w"])
    cells_in_data = set(df["cell"].unique())
    chrom_in_data = set(df["chrom"].unique())
    if not_sorted:
        logger.debug("Sorting input data...")
        df.sort_values(["cell", "chrom", "start", "end"], ascending=True, inplace=True)
        df.reset_index(drop=True, inplace=True)
    logger.debug("Read input dataset with {} rows".format(df.shape[0]))
    logger.debug("Number of individual cells in data: {}".format(len(cells_in_data)))
    logger.debug("Number of chromosomes in input data: {}".format(len(chrom_in_data)))

    select_range, shift_step = check_bin_window_compatibility(df, window_size, shift_size)
    logger.debug("Sanity checks on input completed")

    subsets_to_process = [
        (select_range, shift_step, window_size, subset_id, subset_counts) for subset_id, subset_counts in df.groupby(["cell", "chrom"])
    ]
    n_subsets = len(subsets_to_process)
    logger.debug("Generated {} subsets of input data to process".format(n_subsets))

    fraction_dataset = col.defaultdict(list)
    sanity_checks = col.defaultdict(set)
    with mp.Pool(min(n_subsets, jobs)) as pool:
        logger.debug("Worker pool initialized - processing input subsets...")
        resiter = pool.imap_unordered(process_count_subset, subsets_to_process)
        for cell, chrom, fractions in resiter:
            fraction_dataset[chrom].append(fractions)
            sanity_checks[cell].add(chrom)
    logger.debug("Input processing complete")

    assert len(sanity_checks.keys()) == len(cells_in_data), "Missing data for cell(s): {}".format(
        sorted(cells_in_data - set(sanity_checks.keys()))
    )
    for cell, chroms in sanity_checks.items():
        print(cell, chroms, len(chroms), chrom_in_data, len(chrom_in_data))
        assert len(chroms) == len(chrom_in_data), "Missing chromosome(s) for cell {}: {}".format(cell, chrom_in_data - chroms)
    logger.debug("Begin data merging...")
    fraction_dataset = merge_chromosome_subsets(fraction_dataset)
    return fraction_dataset


def check_bin_window_compatibility(count_data, window_size, shift_size):
    """
    Check that input data follows assumptions:
    - non-overlapping bins
    - fixed bin size
    - merge window size is (integer) multiple of input bin width
    - shift window size is (integer) multiple of input bin width

    Note that this is only checked for the top two entries,
    otherwise garbage in - garbage out applies.

    :param count_data:
    :param window_size:
    :param shift_size:
    :return:
    """
    bin1 = count_data.at[0, "end"] - count_data.at[0, "start"]
    bin2 = count_data.at[1, "end"] - count_data.at[1, "start"]
    if not bin1 == bin2:
        raise ValueError("Unequal bin sizes in dataset detected: {} vs {}".format(bin1, bin2))
    if not (bin1 > 0 and bin2 > 0):
        raise ValueError("Bin size is not greater zero: {} or {}".format(bin1, bin2))
    if not count_data.at[0, "end"] <= count_data.at[1, "start"]:
        raise ValueError("Input bins are overlapping: " "(0) {} > {} (1)".format(count_data.at[0, "end"], count_data.at[1, "start"]))

    if not window_size % bin1 == 0:
        raise ValueError(
            "User-specified merge window size is not a " "multiple of input bin size: {} mod {} != 0".format(window_size, bin1)
        )
    # select_range: when iterating the input
    # bins, this value indicates how many
    # bins to merge (the slice to select)
    select_range = window_size // bin1

    if not shift_size % bin1 == 0:
        raise ValueError("User-specified window shift step is not a " "multiple of input bin size: {} mod {} != 0".format(shift_size, bin1))
    # shift_range: when iterating the input
    # bins, this value determines the step
    # size to make as in:
    # range(start, end, shift_range)
    shift_step = shift_size // bin1
    return select_range, shift_step


def process_count_subset(params):
    """
    This function is supposed to be the map/apply function
    for the child/worker processes. It receives a subset
    of the input data, i.e., one combination of cell and
    chromosome (e.g., chr1 for cellA), and computes the
    Watson fraction of reads. It skips all windows smaller
    than the user-specified merge window size. Typically,
    this would skip the last window of a chromosome.
    By construction, the output of this function must not
    contain NaN or otherwise invalid values.

    :param params: parameters passed as tuple for simplicity
        together with multiprocessing.pool.imap_unordered
    :return:
    """
    select_range, shift_step, window_size, subset_id, count_data = params
    cell, chrom = subset_id
    w_fractions = []
    window_labels = []

    with warn.catch_warnings():
        warn.simplefilter("error")
        for begin in range(0, count_data.shape[0], shift_step):
            # NB: iloc is important as grouping may throw off
            # DF.Index and - moreover - label-based lookup is
            # inclusive in Pandas
            merge_range = count_data.iloc[begin : begin + select_range, :]
            start = merge_range["start"].min()
            end = merge_range["end"].max()
            if (end - start) != window_size:
                # skip last incomplete window
                break
            w_fraction = -1
            try:
                w_fraction = merge_range["w"].sum() / (merge_range["w"].sum() + merge_range["c"].sum())
            except RuntimeWarning:
                if w_fraction == -1 or np.isnan(w_fraction):
                    w_fraction = 0.0
                else:
                    raise
            assert not np.isnan(w_fraction), "NaN w_fraction in {}: {} / {}".format(subset_id, begin, begin + select_range)
            assert 0 <= w_fraction <= 1, "Out-of-bounds w_fraction in {}: {} / {} / {}".format(
                subset_id, w_fraction, begin, begin + select_range
            )
            w_fractions.append(w_fraction)
            window_labels.append("{}_{}_{}".format(chrom, start, end))

    # This function returns a
    # subset (i.e., one chromosome)
    # of one column of the final
    # data table containing
    # Watson fraction of reads
    # [the data table is
    # "genomic bins" X "cells"]
    sub_column = pd.Series(w_fractions, index=window_labels, dtype=np.float64)
    sub_column.name = cell
    return cell, chrom, sub_column


def merge_chromosome_subsets(fractions_by_chrom):
    """
    Since the input data is first split into
    (cell, chromosome) partitions to be processed
    in parallel, this function performs the merging
    in two stages: first, merge all Watson fractions
    per chromosome (aggregate over cells), then merge
    all chromosomes into the final data table.
    The final data table is augmented with genomic
    coordinates as "chrom" "start" "end" (extracted
    from the index of the individual partitions)

    :param fractions_by_chrom:
    :return:
    """
    chrom_subsets = []
    for chrom, chrom_data in fractions_by_chrom.items():
        tmp = pd.concat(chrom_data, axis=1, ignore_index=False)
        # this sorts by cell names
        tmp.sort_index(axis=1, inplace=True)
        chrom_subsets.append(tmp)
    logger.debug("Merging data by chromosome complete")

    chrom_subsets = pd.concat(chrom_subsets, axis=0, ignore_index=False)
    logger.debug("Merged chromosome subsets into final table of size: {} x {}".format(*chrom_subsets.shape))

    logger.debug("Adding genomic coordinates to final dataset")
    coordinates = chrom_subsets.index.str.extract("([a-zA-Z0-9]+)_([0-9]+)_([0-9]+)", expand=True)
    coordinates.index = chrom_subsets.index
    coordinates.columns = ["chrom", "start", "end"]

    chrom_subsets = pd.concat([coordinates, chrom_subsets], axis=1, ignore_index=False)
    chrom_subsets["start"] = chrom_subsets["start"].astype(np.int32)
    chrom_subsets["end"] = chrom_subsets["end"].astype(np.int32)
    chrom_subsets.reset_index(drop=True, inplace=True)
    logger.debug("Data merging complete")
    return chrom_subsets


###################################################
# Following: function marking blacklist regions
###################################################


def mark_blacklist_regions(frac_data, blacklist_file):
    """
    Mark blacklisted regions by setting all data values
    to -1. Note that blacklisted regions are not treated
    any different than regular regions to simplify further
    processing of the dataset.

    :return:
    """
    filter_trees = col.defaultdict(ivt.IntervalTree)
    bl_count = 0
    logger.debug("Reading blacklist regions from file {}".format(blacklist_file))
    with open(blacklist_file, "r") as skip:
        for line in skip:
            if line.startswith("#") or not line.strip() or "chrom" in line:
                continue
            parts = line.strip().split()
            chrom, start, end = parts[:3]
            filter_trees[chrom].addi(int(start), int(end))
            bl_count += 1
    logger.debug("Read {} blacklisted intervals".format(bl_count))

    blacklist_indices = []
    removed = 0
    for row in frac_data.itertuples():
        if filter_trees[row.chrom].overlap(row.start, row.end):
            removed += 1
            blacklist_indices.append(row.Index)

    cell_columns = [c for c in frac_data.columns if c not in ["chrom", "start", "end"]]
    frac_data.loc[blacklist_indices, cell_columns] = -1
    logger.debug("Marked {} regions as blacklisted".format(removed))
    return frac_data


###################################################
# Following: functions for actual ploidy / CN
# estimation (done in parallel)
###################################################


def run_ploidy_estimation(dataset, max_ploidy, alpha, background, jobs):
    """
    This function uses a child/worker pool to compute the ploidy
    estimates in parallel. The ploidy estimates are augmented
    with a genome-wide info about the relative occurrences of the
    individual ploidy / CN states, and adds the most common ploidy
    state as the last column of that row.

    :param dataset:
    :param max_ploidy:
    :param alpha:
    :param background:
    :param jobs:
    :return:
    """
    logger.debug("Running ploidy estimation")
    logger.debug("Assume highest ploidy is: {}".format(max_ploidy))
    logger.debug("Correct boundary means by alpha of: {}".format(alpha))
    logger.debug("Add uniform background/noise component: {}".format(background))

    process_params = [(max_ploidy, alpha, background, row) for _, row in dataset.iterrows()]
    n_params = len(process_params)
    logger.debug("Created parameter list of size {} to process".format(n_params))

    ploidy_counter = col.Counter()
    output_table = []
    logger.debug("Start model fitting")
    with mp.Pool(min(n_params, jobs)) as pool:
        resiter = pool.imap_unordered(process_segment, process_params)
        for row in resiter:
            ploidy_counter[row[-1]] += 1
            output_table.append(row)
    logger.debug("Model fitting complete")

    table_header = ["chrom", "start", "end"]
    table_header.extend(["logLH-ploidy-{}".format(p) for p in range(1, max_ploidy + 1)])
    table_header.append("ploidy_estimate")

    output_table = pd.DataFrame(output_table, columns=table_header)
    output_table.sort_values(["chrom", "start"], inplace=True)
    assert not pd.isnull(output_table).any(axis=1).any(), "LogLH table contains NULL values"

    logger.debug("Adding genome information to ploidy estimation table")
    genome_size = dataset.groupby("chrom")["end"].max().sum()
    logger.debug("Total size of genome: {}".format(genome_size))
    gw_row = ["genome", 0, genome_size]
    for p in range(1, max_ploidy + 1):
        # NB: n_params = number of genomic bins
        # => total number of ploidy estimates
        # (includes blacklisted regions)
        f = ploidy_counter[p] / n_params
        gw_row.append(f)

    gw_ploidy = ploidy_counter.most_common(1)[0][0]
    logger.debug("Most common ploidy genome-wide: {}".format(gw_ploidy))
    gw_row.append(gw_ploidy)
    gw_row = pd.DataFrame([gw_row], columns=output_table.columns)

    output_table = pd.concat([output_table, gw_row], axis=0, ignore_index=False)
    output_table.reset_index(drop=True, inplace=True)

    logger.debug("Ploidy estimation table finalized")

    return output_table


def process_segment(parameters):
    """
    :param parameters:
    :return:
    """
    max_ploidy, epsilon, background, data_row = parameters
    chrom, start, end = data_row[:3]
    assert isinstance(chrom, str), "Invalid chromosome name: {}".format(chrom)
    assert isinstance(end, int), "Invalid end coordinate: {}".format(end)
    frac_values = np.array(data_row[3:], dtype=np.float64)
    if frac_values[0] < 0:
        # region is blacklisted
        return [chrom, start, end] + [-1] * (max_ploidy + 1)

    out_row = [chrom, start, end]
    loglik = []
    for ploidy in np.arange(1, max_ploidy + 1, step=1):
        binom_dist = stats.binom(ploidy, 0.5)
        means = np.arange(ploidy + 1) / ploidy
        means[0] = epsilon
        means[-1] = 1 - epsilon

        weights = np.array([binom_dist.pmf(i) for i in range(ploidy + 1)], dtype=np.float64)
        mixture = Mixture(means=means, weights=weights, background=background)
        try:
            mixture.fit_stddevs_meanprop(frac_values)
        except AssertionError as ae:
            ae.args += ("segment_id", chrom, start, end)
            raise
        likelihood, out_stdev = mixture.log_likelihood(frac_values)
        loglik.append((likelihood, ploidy))
        out_row.append(likelihood)

    # sort list from small loglik to large
    # select last element (= largest loglik)
    # select ploidy of that element
    est_ploidy = sorted(loglik)[-1][1]
    out_row.append(est_ploidy)
    return out_row


###################################################
# Done
###################################################


def main():
    """
    Set and format logging output/behavior,
    nothing else happening here...

    :return:
    """

    args = parse_command_line()

    log.basicConfig(
        level=log.DEBUG,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=args.log,
        filemode="w",
    )

    if args.debug:
        log.basicConfig(
            **{
                "level": log.DEBUG,
                "stream": sys.stderr,
                "format": "[%(levelname)s] %(asctime)s [%(funcName)s]: %(message)s",
                "datefmt": "%Y-%m-%d %H:%M:%S",
            }
        )
    else:
        log.basicConfig(
            **{
                "level": log.WARNING,
                "stream": sys.stderr,
                "format": "[%(levelname)s] %(asctime)s [%(funcName)s]: %(message)s",
                "datefmt": "%Y-%m-%d %H:%M:%S",
            }
        )

    logger.debug("Ploidy estimator start")
    frac_dataset = compute_watson_fractions(args.input, args.window, args.step, args.sort, args.jobs)

    if args.table:
        dump_fraction_table(frac_dataset, args.table)

    if args.blacklist:
        frac_dataset = mark_blacklist_regions(frac_dataset, args.blacklist)

    ploidy_estimates = run_ploidy_estimation(frac_dataset, args.max_ploidy, args.alpha, args.background, args.jobs)

    write_ploidy_estimation_table(ploidy_estimates, args.output)

    logger.debug("Ploidy estimator finish")

    return


if __name__ == "__main__":
    try:
        main()
    except Exception as err:
        trb.print_exc(file=sys.stderr)
        rc = 1
    else:
        rc = 0
    sys.exit(rc)
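The estimator above fits, for every merged window and every candidate ploidy $p$ up to `--max-ploidy`, a Gaussian mixture over the per-cell Watson fractions $f_j$ whose component means and weights follow a binomial strand-inheritance model:

$$
\mu_i = \frac{i}{p}, \qquad w_i = \binom{p}{i}\left(\frac{1}{2}\right)^{p}, \qquad i = 0, \dots, p,
$$

with the boundary means replaced by $\alpha$ and $1 - \alpha$ (`--boundary-alpha`) and the component standard deviations fitted proportionally to the weights (`fit_stddevs_meanprop`). The ploidy whose mixture attains the highest log-likelihood $\sum_j \log \sum_i w_i\,\mathcal{N}(f_j;\mu_i,\sigma_i)$ is reported for that window.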
import pandas as pd

df = pd.read_csv(snakemake.input.ploidy, sep="\t")
df = df.groupby("#chrom")["ploidy_estimate"].describe()
print(df)
df.to_csv(snakemake.output.summary, sep="\t")
import os
from PyPDF2 import PdfFileWriter, PdfFileReader

# print(snakemake.input[0], snakemake.output[0], snakemake.wildcards.sample, snakemake.wildcards.cell, snakemake.wildcards.i)

inputpdf = PdfFileReader(snakemake.input[0], "rb")

# cell_name = tmp_dict[snakemake.wildcards.sample][int(snakemake.wildcards.i)]

output = PdfFileWriter()
output.addPage(inputpdf.getPage(int(snakemake.wildcards.i)))

# tmp_output_path = os.path.dirname(snakemake.input[0]) + "/{}.{}.pdf".format(snakemake.wildcards.cell, snakemake.wildcards.i)

# with open(tmp_output_path, "wb") as outputStream:
with open(snakemake.output[0], "wb") as outputStream:
    output.write(outputStream)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator

# SETTINGS
plt.rcParams["axes.edgecolor"] = "black"
plt.rcParams["axes.linewidth"] = 2.50


# LOADING DATAFRAME
df = pd.read_csv(snakemake.input.ploidy_detailled, sep="\t")
df = df.loc[df["#chrom"] != "genome"]

# GETTING CHR LIST & ORDER CATEGORICALLY
chroms = ["chr" + str(c) for c in list(range(1, 23))] + ["chrX", "chrY"]
df["#chrom"] = pd.Categorical(df["#chrom"], categories=chroms, ordered=True)
df = df.sort_values(by=["#chrom", "start"])

# STARTING SUBPLOTS
col_nb = len(df["#chrom"].unique())
col_nb = col_nb + 1 if col_nb == 1 else col_nb
f, ax = plt.subplots(ncols=col_nb, figsize=(3 * col_nb, 35))


# ITERATE OVER CHROM
for i, chrom in enumerate(df["#chrom"].unique().tolist()):
    # PLOTTING MAIN FIGURE
    ax[i].plot(df.loc[df["#chrom"] == chrom].ploidy_estimate, df.loc[df["#chrom"] == chrom].start, lw=4, color="black")

    # CUSTOMISATION
    ax[i].set_xlabel("{}".format(chrom), fontsize=30)
    ax[i].set_ylim(0, df.start.max())
    ax[i].set_xlim(0, 6)

    # ADDING VERTICAL RED DASHED LINE
    ax[i].axvline(2, ymax=df.loc[df["#chrom"] == chrom].start.max() / df.start.max(), ls="--", lw=2, color="red")

    # GRID CUSTOMISATION
    ax[i].tick_params(axis="x", which="major", labelsize=20)
    major_ticks = np.arange(0, 7, 1)
    ax[i].set_xticks(major_ticks)
    ax[i].grid(which="both")
    for axe in ["top", "bottom", "left", "right"]:
        ax[i].spines[axe].set_linewidth(2)
        ax[i].spines[axe].set_color("black")
    ax[i].grid(axis="both", which="major")
    if i == 0:
        ax[i].tick_params(axis="y", which="major", labelsize=20)
        ax[i].set_ylabel("Position (Mbp)", fontsize=30)
    else:
        ax[i].get_yaxis().set_visible(False)
f.suptitle("Sample: {}".format(snakemake.wildcards.sample), fontsize=50)
plt.tight_layout()
plt.subplots_adjust(top=0.95)
plt.savefig(snakemake.output[0], dpi=300)
log <- file(snakemake@log[[1]], open = "wt")
sink(file = log, type = "message")
sink(file = log, type = "output")

# IMPORTS

library(ComplexHeatmap)
library(RColorBrewer)
library(dplyr)
library(tidyr)



# pdf("TEST_R_dev.pdf", width = 20, height = 10)
pdf(snakemake@output[["pdf"]], width = 20, height = 10)

# Chromosome order
if (snakemake@config[["reference"]] != "mm10") {
    chrOrder <-
        c(paste("chr", 1:22, sep = ""), "chrX", "chrY")
} else {
    chrOrder <-
        c(paste("chr", 1:19, sep = ""), "chrX", "chrY")
}
# Load SV data

# data_file = "../stringent_filterTRUE.tsv"
data_file <- snakemake@input[["sv_calls"]]
# data1 <- read.table("../lenient_filterFALSE.tsv",
data1 <- read.table(data_file,
    sep = "\t",
    header = T,
    comment.char = ""
)
# head(data1)

# Create Dataframe for chromosomes missing SVs

chrom <- as.vector(setdiff(chrOrder, data1$chrom))
start <- rep(0, length(setdiff(chrOrder, data1$chrom)))
end <- rep(0, length(setdiff(chrOrder, data1$chrom)))
sample <- rep(data1$sample[1][1], length(setdiff(chrOrder, data1$chrom)))
cell <- rep(data1$cell[1][1], length(setdiff(chrOrder, data1$chrom)))
class <- rep("NA", length(setdiff(chrOrder, data1$chrom)))
scalar <- rep(0, length(setdiff(chrOrder, data1$chrom)))
num_bins <- rep(0, length(setdiff(chrOrder, data1$chrom)))
sv_call_name <- rep("none", length(setdiff(chrOrder, data1$chrom)))
sv_call_haplotype <- rep(0, length(setdiff(chrOrder, data1$chrom)))
sv_call_name_2nd <- rep("NA", length(setdiff(chrOrder, data1$chrom)))
sv_call_haplotype_2nd <- rep(0, length(setdiff(chrOrder, data1$chrom)))
llr_to_ref <- rep(0, length(setdiff(chrOrder, data1$chrom)))
llr_to_2nd <- rep(0, length(setdiff(chrOrder, data1$chrom)))
af <- rep(0, length(setdiff(chrOrder, data1$chrom)))

data1_missing <- data.frame(
    chrom,
    start,
    end,
    sample,
    cell,
    class,
    scalar,
    num_bins,
    sv_call_name,
    sv_call_haplotype,
    sv_call_name_2nd,
    sv_call_haplotype_2nd,
    llr_to_ref,
    llr_to_2nd,
    af
)


# Bind existing dataframe and new one

data1 <- rbind(data1, data1_missing)

data1$chrom <-
    factor(data1$chrom, levels = chrOrder)


data1 <- data1[order(data1$chrom), ]
data1$pos <- paste0(data1$chrom, "_", data1$start, "_", data1$end)

# Select subset of the dataframe
lite_data <- select(data1, c("pos", "cell", "sv_call_name"))

# Get colors / chrom
set.seed(2)
n <- length(unique(data1$chrom))
qual_col_pals <- brewer.pal.info[brewer.pal.info$category == "qual", ]
col_vector <- unlist(mapply(brewer.pal, qual_col_pals$maxcolors, rownames(qual_col_pals)))
# pie(rep(1, n), col = sample(col_vector, n))

chrom <- unique(data1$chrom)
colors_chroms <- sample(col_vector, length(chrom))
# Initialise the data1$color column
data1$color <- "NA"
# Iterate over chrom to attribute color
for (i in 1:length(chrom)) {
    chrom_index_list <- which(data1$chrom == chrom[i])
    data1[chrom_index_list, "color"] <- colors_chroms[i]
}
dd <- unique(select(data1, c("chrom", "color")))

# range01 <- function(x) {
#     100 * ((x - min(x)) / (max(x) - min(x)))
# }

## LLR & CLUSTERING

# Create subset for clustering

lite_data_clustering <- select(data1, c("pos", "cell", "llr_to_ref"))
lite_data_clustering[c("llr_to_ref")][sapply(lite_data_clustering[c("llr_to_ref")], is.infinite)] <- max(lite_data_clustering$llr_to_ref[is.finite(lite_data_clustering$llr_to_ref)])
lite_data_clustering[is.na(lite_data_clustering)] <- 0
# lite_data_clustering$llr_to_ref <- range01(lite_data_clustering$llr_to_ref)

# Pivot dataframe into matrix
lite_data_pivot_clustering <- lite_data_clustering %>%
    pivot_wider(
        names_from = "cell",
        values_from = "llr_to_ref"
    )
# Transpose
t_lite_data_pivot_clustering <- t(lite_data_pivot_clustering)

colnames(t_lite_data_pivot_clustering) <- t_lite_data_pivot_clustering[1, ]
t_lite_data_pivot_clustering <- t_lite_data_pivot_clustering[-1, ]
t_lite_data_pivot_clustering[is.na(t_lite_data_pivot_clustering)] <- 0

# Turn into numeric matrix
t_lite_data_pivot_clustering_num <- matrix(as.double(t_lite_data_pivot_clustering), ncol = ncol(t_lite_data_pivot_clustering))
rownames(t_lite_data_pivot_clustering_num) <- rownames(t_lite_data_pivot_clustering)
colnames(t_lite_data_pivot_clustering_num) <- colnames(t_lite_data_pivot_clustering)

# Plot options
options(repr.plot.width = 20, repr.plot.height = 12)

anno_colors <- list(Chroms = unique(data1$chrom))

col_annotation <- sapply(strsplit(lite_data_pivot_clustering$pos, "_"), `[`, 1)

col_test <- factor(sapply(strsplit(colnames(t_lite_data_pivot_clustering_num), "_"), `[`, 1), levels = unique(sapply(strsplit(colnames(t_lite_data_pivot_clustering_num), "_"), `[`, 1)))

cl_h <- Heatmap(as.matrix(t_lite_data_pivot_clustering_num),
    name = "LLR", col = RColorBrewer::brewer.pal(name = "Reds", n = 9),
    # column_title = "a discrete numeric matrix",
    rect_gp = gpar(col = "white", lwd = 1.5),
    top_annotation = ComplexHeatmap::HeatmapAnnotation(
        foo = anno_block(gp = gpar(fill = 2:24))
    ),
    column_split = col_test,
    width = unit(32, "cm"), height = unit(20, "cm"),
    row_names_gp = gpar(fontsize = 5),
    column_names_gp = gpar(fontsize = 4),
    column_title_gp = gpar(fontsize = 10),
    cluster_columns = FALSE,
    column_gap = unit(2, "mm"),
    cluster_column_slices = FALSE,
    column_title_rot = 90
)
ht_opt$TITLE_PADDING <- unit(c(8.5, 8.5), "points")
draw(cl_h,
    # row_title = "Three heatmaps, row title", row_title_gp = gpar(col = "red"),
    column_title = paste0("Chromosome size unscaled LLR heatmap (Sample: ", snakemake@wildcards[["sample"]], ", Method used: ", snakemake@wildcards[["method"]], ", Filter used: ", snakemake@wildcards[["filter"]], ")"), column_title_gp = gpar(fontsize = 16)
)

## CATEGORICAL

# Turn data into a matrix
lite_data_pivot <- lite_data %>%
    pivot_wider(
        names_from = "cell",
        values_from = "sv_call_name"
    )

# Transpose
t_lite_data_pivot <- t(lite_data_pivot)
colnames(t_lite_data_pivot) <- t_lite_data_pivot[1, ]
t_lite_data_pivot <- t_lite_data_pivot[-1, ]

# SV list
sv_list <-
    c(
        "none",
        "del_h1",
        "del_h2",
        "del_hom",
        "dup_h1",
        "dup_h2",
        "dup_hom",
        "inv_h1",
        "inv_h2",
        "inv_hom",
        "idup_h1",
        "idup_h2",
        "complex"
    )

# SV type colors
colors <-
    structure(
        c(
            "grey",
            "#77AADD",
            "#4477AA",
            "#114477",
            "#CC99BB",
            "#AA4488",
            "#771155",
            "#DDDD77",
            "#AAAA44",
            "#777711",
            "#DDAA77",
            "#AA7744",
            "#774411"
        ),
        names = sv_list
    )

# Fill NA with none
t_lite_data_pivot[is.na(t_lite_data_pivot)] <- "none"

anno_colors <- list(Chroms = unique(data1$chrom))
col_annotation <- as.data.frame(sapply(strsplit(lite_data_pivot$pos, "_"), `[`, 1))
colnames(col_annotation) <- "Chroms"
col_test <- factor(sapply(strsplit(colnames(t_lite_data_pivot), "_"), `[`, 1), levels = unique(sapply(strsplit(colnames(t_lite_data_pivot), "_"), `[`, 1)))


cat_h <- Heatmap(as.matrix(t_lite_data_pivot),
    name = "SV type", col = colors,
    # column_title = "a discrete numeric matrix",
    rect_gp = gpar(col = "white", lwd = 1.5),
    top_annotation = ComplexHeatmap::HeatmapAnnotation(
        foo = anno_block(gp = gpar(fill = 2:24))
    ),
    column_split = col_test,
    width = unit(32, "cm"), height = unit(20, "cm"),
    row_names_gp = gpar(fontsize = 5),
    column_names_gp = gpar(fontsize = 4),
    column_title_gp = gpar(fontsize = 10),
    column_gap = unit(2, "mm"),
    # column_order = order(as.numeric(sapply(strsplit(gsub("chr", "", colnames(t_lite_data_pivot)), "_"), `[`, 1))),
    row_order = row_order(cl_h),
    column_title_rot = 90
    # use_raster = TRUE, raster_by_magick = TRUE, raster_quality=10
)
ht_opt$TITLE_PADDING <- unit(c(8.5, 8.5), "points")
draw(cat_h,
    # row_title = "Three heatmaps, row title", row_title_gp = gpar(col = "red"),
    column_title = paste0("Chromosome size unscaled categorical heatmap (Sample: ", snakemake@wildcards[["sample"]], ", Method used: ", snakemake@wildcards[["method"]], ", Filter used: ", snakemake@wildcards[["filter"]], ")"), column_title_gp = gpar(fontsize = 16)
)


# Export clustered row order to output in order to use it in python script
row_order <- row_order(cl_h)
cell <- rownames(t_lite_data_pivot)[row_order]
index <- seq(1, length(cell))
cluster_order_df <- data.frame(index, row_order, cell)
# write.table(cluster_order_df, file = "test.tsv", sep = "\t", row.names = FALSE, quote = FALSE)
write.table(cluster_order_df, file = snakemake@output[["cluster_order_df"]], sep = "\t", row.names = FALSE, quote = FALSE)
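
# The exported TSV has three columns (index, row_order, cell) and records the cell
# ordering produced by the ComplexHeatmap clustering above; the downstream Python
# heatmap script re-applies this ordering via snakemake.input.cluster_order_df.
# Hypothetical first data row, for illustration only:
#   index  row_order  cell
#   1      12         sampleA_cell01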
import sys

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.backends.backend_pdf import PdfPages

log = open(snakemake.log[0], "w")
sys.stderr = sys.stdout = log

# Categorical mapping
# d = {
#     "none": 0,
#     "del_h1": 1,
#     "del_h2": 2,
#     "del_hom": 3,
#     "dup_h1": 4,
#     "dup_h2": 5,
#     "dup_hom": 6,
#     "inv_h1": 7,
#     "inv_h2": 8,
#     "inv_hom": 9,
#     "idup_h1": 10,
#     "idup_h2": 11,
#     "complex": 12,
# }

d = {
    "none": 1,
    "del_h1": 2,
    "del_h2": 3,
    "del_hom": 4,
    "dup_h1": 5,
    "dup_h2": 6,
    "dup_hom": 7,
    "inv_h1": 8,
    "inv_h2": 9,
    "inv_hom": 10,
    "idup_h1": 11,
    "idup_h2": 12,
    "complex": 13,
}
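
# Note: categories are mapped to 1-13; bins without any call are later filled with 0
# (via fillna(0)), which keeps truly empty bins distinguishable from explicit "none" calls.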

# Colors
colors = {
    "none": "#F8F8F8",
    "del_h1": "#77AADD",
    "del_h2": "#4477AA",
    "del_hom": "#114477",
    "dup_h1": "#CC99BB",
    "dup_h2": "#AA4488",
    "dup_hom": "#771155",
    "inv_h1": "#DDDD77",
    "inv_h2": "#AAAA44",
    "inv_hom": "#777711",
    "idup_h1": "#DDAA77",
    "idup_h2": "#AA7744",
    "complex": "#774411",
}

# Read SV file
df = pd.read_csv(snakemake.input.sv_calls, sep="\t")
# df = pd.read_csv("../stringent_filterTRUE.tsv", sep="\t")
df["ID"] = df["chrom"] + "_" + df["start"].astype(str) + "_" + df["end"].astype(str)


if snakemake.config["reference"] != "mm10":
    names = ["chrom", "start", "end", "bin_id"]
else:
    names = ["chrom", "start", "end"]

# Read 200kb bins file
binbed = pd.read_csv(
    # "../bin_200kb_all.bed",
    snakemake.input.binbed,
    sep="\t",
    names=names,
)


binbed["ID"] = binbed["chrom"] + "_" + binbed["start"].astype(str) + "_" + binbed["end"].astype(str)

cats = (
    ["chr{}".format(e) for e in range(1, 23)] + ["chrX", "chrY"]
    if snakemake.config["reference"] != "mm10"
    else ["chr{}".format(e) for e in range(1, 20)] + ["chrX", "chrY"]
)

# Turn chrom into categorical
binbed["chrom"] = pd.Categorical(
    binbed["chrom"],
    categories=cats,
    ordered=True,
)

# Sort bins (optional chrY filtering below is left commented out)
binbed = binbed.sort_values(by=["chrom", "start", "end"]).reset_index(drop=True)
# binbed = binbed.loc[~binbed["chrom"].isin(["chrY"])]

# Instantiate final list
l = list()


def process_row(r):
    """Collect all bins from binbed that overlap a given SV call.

    Args:
        r (pandas.Series): one row of the per-cell SV call table
    """
    # .copy() avoids pandas SettingWithCopyWarning when annotating the slice below
    tmp_r = binbed.loc[(binbed["chrom"] == r["chrom"]) & (binbed["start"] >= r["start"]) & (binbed["end"] <= r["end"])].copy()
    tmp_r["cell"] = r["cell"]
    tmp_r["sv_call_name"] = r["sv_call_name"]
    tmp_r["af"] = r["af"]
    tmp_r["llr_to_ref"] = r["llr_to_ref"]
    # Append result to list
    l.append(tmp_r)


# Apply & loop on each temporary cell dataframe created
def process_sv(tmp_df):
    tmp_df.apply(lambda r: process_row(r), axis=1)


# Create a nested pandas apply
df.groupby("cell").apply(lambda r: process_sv(r))
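
# Illustrative example (hypothetical coordinates): an SV call spanning
# chr1:600000-1200000 in one cell expands into the three 200kb bins
# chr1_600000_800000, chr1_800000_1000000 and chr1_1000000_1200000, each carrying
# that cell's sv_call_name, af and llr_to_ref values.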

# Concat results
processed_df = pd.concat(l)
processed_df["ID"] = processed_df["chrom"].astype(str) + "_" + processed_df["start"].astype(str) + "_" + processed_df["end"].astype(str)

# Extract only empty bins (outer join) from binbed
binbed_not_used = binbed.loc[~binbed["ID"].isin(processed_df.ID.unique().tolist())]
# Concat with previously created dataframe
concat_df = pd.concat([processed_df, binbed_not_used])

# Replace llr inf values by max values
concat_df.loc[concat_df["llr_to_ref"] == np.inf, "llr_to_ref"] = concat_df.loc[concat_df["llr_to_ref"] != np.inf]["llr_to_ref"].max()

# Pivot into matrix
pivot_concat_df = concat_df.pivot(index="ID", values="llr_to_ref", columns="cell")

# Create chrom, start, end columns in a tmp df
tmp = pivot_concat_df.reset_index().ID.str.split("_", expand=True)
tmp.columns = ["chrom", "start", "end"]
tmp["start"] = tmp["start"].astype(int)
tmp["end"] = tmp["end"].astype(int)

# Concat dfs, remove first column, sort, index
pivot_concat_df = (
    pd.concat([pivot_concat_df.reset_index(), tmp], axis=1)
    .drop(pivot_concat_df.columns[0], axis=1)
    .sort_values(by=["chrom", "start", "end"])
    .reset_index(drop=True)
)


# Read clustering index file produced from previous clustering using R ComplexHeatmap
# clustering_index_df = pd.read_csv("test.tsv", sep="\t")
clustering_index_df = pd.read_csv(snakemake.input.cluster_order_df, sep="\t")


## LLR

# Pivot df subset specific to llr
pivot_concat_df = concat_df.pivot(index="ID", values="llr_to_ref", columns="cell")
tmp = pivot_concat_df.reset_index().ID.str.split("_", expand=True)
tmp.columns = ["chrom", "start", "end"]
tmp["start"] = tmp["start"].astype(int)
tmp["end"] = tmp["end"].astype(int)
pivot_concat_df = pd.concat([pivot_concat_df.reset_index(), tmp], axis=1).drop(pivot_concat_df.columns[0], axis=1)

pivot_concat_df["chrom"] = pd.Categorical(
    pivot_concat_df["chrom"],
    categories=cats,
    ordered=True,
)
pivot_concat_df = pivot_concat_df.sort_values(by=["chrom", "start", "end"]).reset_index(drop=True)

chroms = cats
# chroms = chroms[:2]
# chroms = ["chr10", "chr13", "chr22"]

# Extract widths using binbed max values to specify subplots widths scaled according chrom sizes
widths = binbed.loc[binbed["chrom"].isin(chroms)].groupby("chrom")["end"].max().dropna().tolist()


# pdf = PdfPages("multipage_pdf2.pdf")
pdf = PdfPages(snakemake.output.pdf)

# Create subplots
f, axs = plt.subplots(ncols=len(chroms), figsize=(40, 20), gridspec_kw={"width_ratios": widths})

print("LLR plot")
# Iterate over chroms
for j, (chrom, ax) in enumerate(zip(chroms, axs)):
    print(chrom)
    cbar = False

    # If not chr1 = remove y axis
    if j != 0:
        ax.get_yaxis().set_visible(False)
        ax.yaxis.set_ticks_position("none")

    # If last chrom, enable cbar plot
    if j == len(chroms) - 1:
        cbar = True

    # Subset chrom data, set_index, transpose & replace NaN by 0
    data_heatmap = (
        pivot_concat_df.loc[pivot_concat_df["chrom"] == chrom].drop(["chrom", "start", "end"], axis=1).set_index("ID").T.fillna(0)
    )

    # Reorder rows based on clustering index
    data_heatmap = data_heatmap.loc[clustering_index_df.cell.values.tolist()]

    # Plot
    sns.heatmap(
        data=data_heatmap,
        ax=ax,
        vmin=0,
        vmax=concat_df.llr_to_ref.max(),
        cmap="Reds",
        cbar=cbar,
    )
    ax.xaxis.set_ticks_position("none")
    ax.set_xlabel("{}".format(chrom), fontsize=12, rotation=90)
    ax.set_xticklabels([])

plt.suptitle(
    f"Chromosome size scaled LLR heatmap (Sample : {snakemake.wildcards.sample}, Methods used: {snakemake.wildcards.method}, Filter used: {snakemake.wildcards.filter})",
    x=0.4,
    y=1.02,
    fontsize=18,
)

pdf.savefig(f)
plt.close()

## CATEGORICAL

# Map values to categorical names
concat_df["sv_call_name_map"] = concat_df["sv_call_name"].map(d)

# Pivot df subset specific to sv_call_name
pivot_concat_df = concat_df.pivot(index="ID", values="sv_call_name_map", columns="cell")
tmp = pivot_concat_df.reset_index().ID.str.split("_", expand=True)
tmp.columns = ["chrom", "start", "end"]
tmp["start"] = tmp["start"].astype(int)
tmp["end"] = tmp["end"].astype(int)
pivot_concat_df = pd.concat([pivot_concat_df.reset_index(), tmp], axis=1).drop(pivot_concat_df.columns[0], axis=1)

pivot_concat_df["chrom"] = pd.Categorical(
    pivot_concat_df["chrom"],
    categories=cats,
    ordered=True,
)
pivot_concat_df = pivot_concat_df.sort_values(by=["chrom", "start", "end"]).reset_index(drop=True)


# chroms = ["chr{}".format(e) for e in range(1, 23)] + ["chrX", "chrY"]
chroms = cats
# chroms = ["chr10", "chr13"]
# chroms = chroms[:2]

widths = binbed.loc[binbed["chrom"].isin(chroms)].groupby("chrom")["end"].max().dropna().tolist()

f, axs = plt.subplots(ncols=len(chroms), figsize=(30, 15), dpi=50, gridspec_kw={"width_ratios": widths})

print("Categorical plot")
for j, (chrom, ax) in enumerate(zip(chroms, axs)):
    print(chrom)
    data_heatmap = (
        pivot_concat_df.loc[pivot_concat_df["chrom"] == chrom].drop(["chrom", "start", "end"], axis=1).set_index("ID").T.fillna(0)
    )
    data_heatmap = data_heatmap.loc[clustering_index_df.cell.values.tolist()]
    sns.heatmap(data=data_heatmap, ax=ax, vmin=0, cbar=False, cmap=list(colors.values()))
    ax.xaxis.set_ticks_position("none")

    if j != 0:
        ax.get_yaxis().set_visible(False)
        ax.yaxis.set_ticks_position("none")

    ax.set_xlabel("{}".format(chrom), fontsize=12, rotation=90)
    ax.set_xticklabels([])

custom_lines = [Line2D([0], [0], color=v, lw=12) for j, (k, v) in enumerate(colors.items())]

axs[-1].legend(
    custom_lines,
    list(colors.keys()),
    bbox_to_anchor=(1 + 0.15 * len(chroms), 0.65),
    fontsize=16,
)
plt.tight_layout(rect=[0, 0, 0.95, 1])

plt.suptitle(
    f"Chromosome size scaled categorical heatmap (Sample : {snakemake.wildcards.sample}, Methods used: {snakemake.wildcards.method}, Filter used: {snakemake.wildcards.filter})",
    x=0.4,
    y=1.02,
    fontsize=18,
)

pdf.savefig(f)
plt.close()

pdf.close()
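
# The resulting PDF contains two pages: the chromosome-size-scaled LLR heatmap and
# the categorical SV-type heatmap, both with cells (rows) ordered according to the
# clustering exported by the upstream R ComplexHeatmap script.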
log <- file(snakemake@log[[1]], open = "wt")
sink(file = log, type = "message")
sink(file = log, type = "output")

source("workflow/scripts/plotting/plot-clustering.R")
plot.clustering(
    inputfile = snakemake@input[["sv_calls"]],
    bin.bed.filename = snakemake@input[["binbed"]],
    position.outputfile = snakemake@output[["position"]]
)
suppressMessages(library(dplyr))
suppressMessages(library(data.table))
suppressMessages(library(ggplot2))
library(scales) %>% invisible()
library(assertthat) %>% invisible()
library(stringr) %>% invisible()
library(RColorBrewer) %>% invisible()
library(ggnewscale)

# setwd("/Users/tweber/workspace/plot-SV")




# plot.SV_calls <-
#   function(arg.segments,
#            arg.singlecellsegments,
#            arg.strand,
#            arg.complex,
#            arg.groups,
#            arg.calls,
#            arg.counts,
#            arg.chromosome,
#            arg.output,
#            arg.truth,
#            background_segment = FALSE) {
################################################################################
# Settings                                                                     #
################################################################################

# zcat_command <- "zcat"
# FIXME : tmp solution
chroms <-
  c(
    "chr1",
    "chr2",
    "chr3",
    "chr4",
    "chr5",
    "chr6",
    "chr7",
    "chr8",
    "chr9",
    "chr10",
    "chr11",
    "chr12",
    "chr13",
    "chr14",
    "chr15",
    "chr16",
    "chr17",
    "chr18",
    "chr19",
    "chr20",
    "chr21",
    "chr22",
    "chrX",
    "chrY"
  )

format_Mb <- function(x) {
  paste(comma(x / 1e6), "Mb")
}

### Colors for background
manual_colors <- c(
  # duplications
  simul_hom_dup = "firebrick4",
  dup_hom = muted("firebrick4", 70, 50),
  simul_het_dup = "firebrick2",
  dup_h1 = muted("firebrick2", 90, 30),
  dup_h2 = muted("firebrick2", 80, 20),
  # deletions
  simul_hom_del = "dodgerblue4",
  del_hom = muted("dodgerblue4", 50, 60),
  simul_het_del = "dodgerblue2",
  del_h1 = muted("dodgerblue2", 80, 50),
  del_h2 = muted("deepskyblue2", 80, 50),
  # inversions
  simul_hom_inv = "chartreuse4",
  inv_hom = muted("chartreuse4", 80, 50),
  simul_het_inv = "chartreuse2",
  inv_h1 = muted("chartreuse2", 100, 60),
  inv_h2 = muted("darkolivegreen3", 100, 60),
  # other SVs
  simul_false_del = "darkgrey",
  simul_inv_dup = "darkgoldenrod2",
  idup_h1 = muted("darkgoldenrod2", 80, 70),
  idup_h2 = muted("gold", 80, 70),
  complex = "darkorchid1",
  # background
  bg1 = "#ffffff",
  bg2 = "#aaafaa",
  # Strand states
  `State: WW` = "sandybrown",
  `State: CC` = "paleturquoise4",
  `State: WC` = "khaki",
  `State: CW` = "yellow2"
)

manual_colors_sv <- c(
  # duplications
  # simul_hom_dup = "firebrick4",
  dup_hom = muted("firebrick4", 70, 50),
  # simul_het_dup = "firebrick2",
  dup_h1 = muted("firebrick2", 90, 30),
  dup_h2 = muted("firebrick2", 80, 20),
  # deletions
  # simul_hom_del = "dodgerblue4",
  del_hom = muted("dodgerblue4", 50, 60),
  # simul_het_del = "dodgerblue2",
  del_h1 = muted("dodgerblue2", 80, 50),
  del_h2 = muted("deepskyblue2", 80, 50),
  # inversions
  # simul_hom_inv = "chartreuse4",
  inv_hom = muted("chartreuse4", 80, 50),
  # simul_het_inv = "chartreuse2",
  inv_h1 = muted("chartreuse2", 100, 60),
  inv_h2 = muted("darkolivegreen3", 100, 60),
  # other SVs
  # simul_false_del = "darkgrey",
  # simul_inv_dup = "darkgoldenrod2",
  idup_h1 = muted("darkgoldenrod2", 80, 70),
  idup_h2 = muted("gold", 80, 70)
  # complex = "darkorchid1"
)

manual_colors_cx <- c(complex = "darkorchid1")
manual_colors_bg <- c( # background
  bg1 = "#ffffff",
  bg2 = "#aaafaa"
)

manual_colors_ss <- c(
  # Strand states
  `State: WW` = "sandybrown",
  `State: CC` = "paleturquoise4",
  `State: WC` = "khaki",
  `State: CW` = "yellow2"
)

################################################################################
# Usage                                                                        #
################################################################################

print_usage_and_stop <- function(msg = NULL) {
  if (!is.null(msg)) {
    message(msg)
  }
  message("Plot Strand-seq counts of all cells for a single chromosome.                    ")
  message("                                                                                ")
  message("Usage:                                                                          ")
  message("    Rscript chrom.R [OPTIONS] <count-file> <chrom> <out.pdf>                    ")
  message("                                                                                ")
  message("OPTIONS (no spaces around `=`):                                                 ")
  message("    per-page=<int>            Number of cells to be printed per page            ")
  message("    segments=<file>           Show the segmentation in the plots                ")
  message("    singlecellsegments=<file> Show per-cell  segmentation in the plots          ")
  message("    calls=<file>              Highlight SV calls provided in a table            ")
  message("    truth=<file>              Mark the `true`` SVs provided in a table          ")
  message("    strand=<file>             Mark the strand states which calls are based on   ")
  message("    complex=<file>            Mark complex regions given in file                ")
  message("    groups=<file>             Table with SV call grouping                       ")
  message("    no-none                   Do not hightlight black-listed (i.e. None) bins   ")
  message("                                                                                ")
  message("Generates one plot per chromosome listing all cells below another, separated    ")
  message("into pages. If an SV probability file is provided (2), segments are colored     ")
  message("according to the predicted SV classes. Note that only certain classes are       ")
  message("accepted and the script will yield an error if others are provided.             ")
  message("Similarly, a segmentation file can be specified, yet it must contain exactly    ")
  message("one segmentation (MosaiCatcher reports segmentations for various total numbers  ")
  message("of breakpoints).")
  options(show.error.messages = F)
  stop()
}
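
# Example invocation (hypothetical file names), following the usage above:
#   Rscript chrom.R calls=sv_calls.tsv segments=segments.txt strand=strand_states.txt \
#       per-page=8 counts.txt.gz chr1 chr1_plot.pdf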



################################################################################
# Command Line Arguments                                                       #
################################################################################

args <- commandArgs(trailingOnly = T)

if (length(args) < 3) {
  print_usage_and_stop("[Error] Too few arguments!")
}
assembly <- "hg38"
f_counts <- args[length(args) - 2]
CHROM <- args[length(args) - 1]
f_out <- args[length(args)]

f_segments <- NULL
f_calls <- NULL
f_truth <- NULL
f_strand <- NULL
f_complex <- NULL

# f_segments <- arg.segments
# f_scsegments <- arg.singlecellsegments
# f_strand <- arg.strand
# f_complex <- arg.complex
# f_groups <- arg.groups
# f_calls <- arg.calls
# f_counts <- arg.counts
# CHROM <- arg.chromosome
# f_out <- arg.output




cells_per_page <- 8
show_none <- T

if (length(args) > 3) {
  if (!all(
    grepl(
      "^(strand|calls|segments|per-page|truth|no-none|complex|singlecellsegments|groups)=?",
      args[1:(length(args) - 3)]
    )
  )) {
    print_usage_and_stop("[Error]: Options must be one of `calls`, `segments`, `per-page`, or `truth`")
  }
  for (op in args[1:(length(args) - 3)]) {
    if (grepl("^segments=", op)) {
      f_segments <- str_sub(op, 10)
    }
    if (grepl("^calls=", op)) {
      f_calls <- str_sub(op, 7)
    }
    if (grepl("^truth=", op)) {
      f_truth <- str_sub(op, 7)
    }
    if (grepl("^per-page=", op)) {
      pp <- as.integer(str_sub(op, 10))
      if (pp > 0 && pp < 50) {
        cells_per_page <- pp
      }
    }
    if (grepl("^strand=", op)) {
      f_strand <- str_sub(op, 8)
    }
    if (grepl("^complex=", op)) {
      f_complex <- str_sub(op, 9)
    }
    if (grepl("^groups=", op)) {
      f_groups <- str_sub(op, 8)
    }
    if (grepl("^singlecellsegments=", op)) {
      f_scsegments <- str_sub(op, 20)
    }
    if (grepl("^no-none$", op)) {
      show_none <- F
    }
  }
}


################################################################################
# Read & check input data                                                      #
################################################################################

### Check counts table
message(" * Reading count data ", f_counts, "...")
if (grepl("\\.gz$", f_counts)) {
  # counts <- fread(paste(zcat_command, f_counts))
  counts <- data.table::fread(f_counts)
} else {
  counts <- fread(f_counts)
}

# FIXME : tmp
# print(counts)
counts <- counts[counts$chrom %in% chroms, ]
# print(counts)

assert_that(
  "chrom" %in% colnames(counts),
  "start" %in% colnames(counts),
  "end" %in% colnames(counts),
  "class" %in% colnames(counts),
  "sample" %in% colnames(counts),
  "cell" %in% colnames(counts),
  "w" %in% colnames(counts),
  "c" %in% colnames(counts)
) %>% invisible()
counts[, sample_cell := paste(sample, "-", cell)]
setkey(counts, chrom, sample_cell)
bins <- unique(counts[, .(chrom, start, end)])

### Check CHROM:
assert_that(CHROM %in% unique(counts$chrom)) %>% invisible()
counts <- counts[chrom == CHROM]




### Check SV call file
if (!is.null(f_calls)) {
  message(" * Reading SV calls from ", f_calls, "...")
  svs <- fread(f_calls)
  assert_that(
    "chrom" %in% colnames(svs),
    "start" %in% colnames(svs),
    "end" %in% colnames(svs),
    "sample" %in% colnames(svs),
    "cell" %in% colnames(svs),
    (
      "SV_class" %in% colnames(svs) | "sv_call_name" %in% colnames(svs)
    )
  ) %>% invisible()
  if (!("SV_class" %in% colnames(svs))) {
    svs[, SV_class := sv_call_name]
  }
  assert_that(all(svs$SV_class %in% names(manual_colors))) %>% invisible()
  svs[, sample_cell := paste(sample, "-", cell)]

  set_diff <-
    setdiff(unique(svs$sample_cell), unique(counts$sample_cell))
  if (length(set_diff) > 0) {
    message("[Warning] SV calls and Counts differ in cells: ", set_diff)
  }

  svs <- svs[chrom == CHROM]
}

### Check segment table
if (!is.null(f_segments)) {
  message(" * Reading segmentation file from ", f_segments, "...")
  seg <- fread(f_segments)
  # seg_max <- seg[, max(bps), by = chrom]


  # FIXME : tmp
  seg <- seg[seg$chrom %in% chroms, ]
  seg <- seg[seg$bps > 0, ] # SOLVE T2T ISSUE

  # print(seg)

  assert_that(
    "chrom" %in% colnames(seg),
    "bps" %in% colnames(seg)
  ) %>% invisible()
  if ("k" %in% colnames(seg)) {
    seg[, assert_that(length(unique(k)) == 1), by = .(chrom)] %>% invisible()
  }

  print(seg)

  print(bins)
  # print(bins[, .N, by = chrom])
  # # print(bins %>% count(chrom))
  # print(bins[, .N, by = chrom][, .(chrom, N = c(0, cumsum(N)))])
  # print(bins[, .N, by = chrom][, .(chrom, N = c(0, cumsum(N))[1:(.N - 1)])])



  seg <-
    merge(seg, bins[, .N, by = chrom][, .(chrom, N = c(0, cumsum(N)[- .N]))], by = "chrom")

  # print(c(1, bps[1:(.N - 1)] + 1))
  # print(bps)
  print(seg)

  # stop()


  seg[, `:=`(from = c(1, bps[-length(bps)] + 1), to = bps), by = chrom]

  # print(seg)


  seg[, `:=`(
    start = bins[from + N]$start,
    end = bins[to + N]$end
  )]

  # print("TEST")


  seg[, SV_class := rep(c("bg1", "bg2"), .N)[1:.N], by = chrom]
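  # The alternating bg1/bg2 classes above give neighboring segments different
  # background shading when the sample-level segmentation is drawn in the plot.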



  seg <- seg[chrom == CHROM]
}

### Check simulated variants
if (!is.null(f_truth)) {
  message(" * Reading simulated variants from ", f_truth, "...")
  simul <- fread(f_truth)
  assert_that(
    "chrom" %in% colnames(simul),
    "start" %in% colnames(simul),
    "end" %in% colnames(simul),
    "sample" %in% colnames(simul),
    "cell" %in% colnames(simul),
    "SV_type" %in% colnames(simul)
  ) %>% invisible()
  simul[, `:=`(
    SV_class = paste0("simul_", SV_type),
    SV_type = NULL,
    sample_cell = paste(sample, "-", cell)
  )]
  simul[, sample_cell := paste(sample, "-", cell)]

  set_diff <-
    setdiff(unique(simul$sample_cell), unique(counts$sample_cell))
  if (length(set_diff) > 0) {
    message("[Warning] True SVs and Counts differ in cells: ", set_diff)
  }

  simul <- simul[chrom == CHROM]
}

### Check strand states file
if (!is.null(f_strand)) {
  message(" * Reading strand state file from ", f_strand, "...")
  strand <- fread(f_strand)
  assert_that(
    "sample" %in% colnames(strand),
    "cell" %in% colnames(strand),
    "chrom" %in% colnames(strand),
    "start" %in% colnames(strand),
    "end" %in% colnames(strand),
    "class" %in% colnames(strand)
  ) %>% invisible()
  strand[, class := paste("State:", class)]
  strand[, sample_cell := paste(sample, "-", cell)]

  set_diff <-
    setdiff(unique(strand$sample_cell), unique(counts$sample_cell))
  if (length(set_diff) > 0) {
    message(
      "[Warning] Strand states and Counts differ in cells: ",
      set_diff
    )
  }

  strand <- strand[chrom == CHROM]
}

### Check complex regions file
if (!is.null(f_complex)) {
  message(" * Reading complex regions state file from ", f_complex, "...")
  complex <- fread(f_complex)
  assert_that(
    "chrom" %in% colnames(complex),
    "start" %in% colnames(complex),
    "end" %in% colnames(complex)
  ) %>% invisible()

  complex <- complex[chrom == CHROM]
  message(
    "   --> Found ",
    nrow(complex),
    " complex region(s) in chromosome ",
    CHROM
  )
}

### Check SV groups file
if (!is.null(f_groups)) {
  message(" * Reading SV group file from ", f_groups, "...")
  groups <- fread(f_groups)
  assert_that(
    "chrom" %in% colnames(groups),
    "start" %in% colnames(groups),
    "end" %in% colnames(groups),
    "group_id" %in% colnames(groups)
  ) %>% invisible()
  groups[, group_id := paste("SV group", group_id)]
  groups <- groups[chrom == CHROM]
  message("   --> Found ", nrow(groups), " SV groups in chromosome ", CHROM)
}

### Check single cell segmentation file
if (!is.null(f_scsegments)) {
  message(
    " * Reading per-cell segmentation regions state file from ",
    f_scsegments,
    "..."
  )
  scsegments <- fread(f_scsegments)
  assert_that(
    "sample" %in% colnames(scsegments),
    "cell" %in% colnames(scsegments),
    "chrom" %in% colnames(scsegments),
    "position" %in% colnames(scsegments)
  ) %>% invisible()
  scsegments[, sample_cell := paste(sample, "-", cell)]

  scsegments <- scsegments[chrom == CHROM]
}


################################################################################
# Actual plot                                                                  #
################################################################################


# Plot always a few cells per page!
y_lim <- 3 * counts[, median(w + c)]
# x_lim <- seg_max[seg_max$chrom == "chr1", "V1"][[1]] * 100000
n_cells <- length(unique(counts[, sample_cell]))
i <- 1


message(" * Plotting ", CHROM, " (", f_out, ")")
cairo_pdf(f_out,
  width = 14,
  height = 10,
  onefile = T
)
while (i <= n_cells) {
  message(" * Processing cells from ", i, " to ", i + cells_per_page - 1)

  # Subset to this set of cells:
  CELLS <-
    unique(counts[, .(sample_cell)])[i:(min(i + cells_per_page - 1, n_cells))]
  setkey(CELLS, sample_cell)

  # Subset counts
  local_counts <-
    counts[CELLS, on = .(sample_cell), nomatch = 0]

  # Start major plot
  plt <- ggplot(local_counts)

  # Add background colors for segments, if available:
  # if (background_segment == TRUE) {
  if (!is.null(f_segments)) {
    message("   * Adding segment colors")
    # Segments need to be multiplied by "CELLS"
    local_seg <- CELLS[, as.data.table(seg), by = sample_cell]
    if (nrow(local_seg) > 0) {
      plt <- plt +
        geom_rect(
          data = local_seg,
          alpha = 0.4,
          aes(
            xmin = start,
            xmax = end,
            ymin = -Inf,
            ymax = Inf,
            fill = SV_class
          )
        ) + labs(fill = "Sample level\nsegmentation") + scale_fill_manual(values = manual_colors_bg) + new_scale_fill()
    }
    # }
  }

  # Add colors for SV calls, if available
  if (!is.null(f_calls)) {
    message("   * Adding SV calls")
    local_svs <- svs[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_svs) > 0) {
      plt <- plt + new_scale("fill") +

        geom_rect(
          data = local_svs,
          alpha = 0.6,
          aes(
            xmin = start,
            xmax = end,
            ymin = -Inf,
            ymax = Inf,
            fill = SV_class
          )
        ) + labs(fill = "SV class") + scale_fill_manual(values = manual_colors_sv, drop = FALSE) + new_scale_fill()
    }
  }

  # Add bars for true SVs, if available
  if (!is.null(f_truth)) {
    message("   * Adding true SVs")
    local_sim <- simul[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_sim) > 0) {
      plt <- plt +
        geom_rect(
          data = local_sim,
          aes(
            xmin = start,
            xmax = end,
            ymin = y_lim,
            ymax = Inf,
            fill = SV_class
          )
        )
    }
  }

  # Add lines for single cell segmentation, if available
  # if (background_segment == TRUE) {
  if (!is.null(f_scsegments)) {
    message("   * Adding single cell segments")
    local_scsegments <-
      scsegments[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_scsegments) > 0) {
      local_scsegments$cat <- rep("dashed", nrow(local_scsegments))
      plt <- plt +
        geom_segment(
          data = local_scsegments,
          aes(
            x = position,
            xend = position,
            y = -Inf,
            yend = y_lim,
            linetype = "dashed"
          ),
          linetype = "dashed",
          color = "black"
        ) + scale_linetype_manual("Single-cell\nsegmentation", values = c("dashed" = 1))
    }
  }
  # }

  # Add bars for strand states, if available
  if (!is.null(f_strand)) {
    message("   * Adding strand states")
    local_strand <-
      strand[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_strand) > 0) {
      plt <- plt +
        geom_rect(
          data = local_strand,
          aes(
            xmin = start,
            xmax = end,
            ymin = -Inf,
            ymax = -y_lim,
            fill = class
          )
        ) + labs(fill = "Strand State") + scale_fill_manual(values = manual_colors_ss, drop = FALSE) + new_scale_fill()
    }
  }

  # Add bars for SV group, if available
  if (!is.null(f_groups)) {
    message("   * Adding SV groups")
    if (nrow(groups) > 0) {
      plt <- plt +
        geom_rect(
          data = groups,
          aes(
            xmin = start,
            xmax = end,
            ymin = .85 * y_lim,
            ymax = Inf,
            fill = group_id
          )
        )
      # # Add colors for SV classes

      # print(manual_colors)
      manual_colors <-
        c(
          manual_colors,
          setNames(
            colorRampPalette(brewer.pal(12, "Set2"))(nrow(groups)),
            groups$group_id
          )
        ) # Add colors for SV classes
      # print(manual_colors)
    }
  }

  # Add bars for complex states, if available
  if (!is.null(f_complex)) {
    message("   * Adding complex intervals")
    if (nrow(complex) > 0) {
      complex$type <- "complex"
      print(complex)
      plt <- plt +
        geom_rect(
          data = complex,
          aes(
            xmin = start,
            xmax = end,
            ymin = y_lim,
            ymax = Inf,
            fill = type
          )
        ) + labs(fill = "Complex event") + scale_fill_manual(values = manual_colors_cx)
    }
  }

  message("   * Adding actual W/C counts")
  plt <- plt +
    geom_rect(aes(
      xmin = start,
      xmax = end,
      ymin = 0,
      ymax = -w
    ), fill = "sandybrown") +
    geom_rect(aes(
      xmin = start,
      xmax = end,
      ymin = 0,
      ymax = c
    ), fill = "paleturquoise4")


  # Highlight None bins, if requested
  none_bins <- local_counts[class == "None"]
  if (show_none == T && nrow(none_bins) > 0) {
    message("   * Highlighting None bins")
    plt <- plt +
      geom_segment(
        data = none_bins,
        aes(
          x = start,
          xend = end,
          y = 0,
          yend = 0
        ),
        col = "black",
        size = 2
      )
  }


  message("   * Adding labels, etc.")
  # print(manual_colors)
  plt <- plt +
    facet_wrap(~sample_cell, ncol = 1) +
    ylab("Watson | Crick") + xlab(NULL) +
    scale_x_continuous(breaks = pretty_breaks(12), labels = format_Mb) +
    scale_y_continuous(breaks = pretty_breaks(3)) +
    coord_cartesian(ylim = c(-y_lim, y_lim)) +

    theme_minimal() +
    theme(
      panel.spacing = unit(0, "lines"),
      axis.ticks.x = element_blank(),
      # strip.background = element_rect(color = "#eeeeee", fill = "#eeeeee"),
      strip.text = element_text(size = 8),
      # legend.position = "none"
      legend.position = "right",
      legend.key = element_rect(color = "black"),
      legend.text = element_text(size = 8),
      legend.title = element_text(size = 10)
    ) +
    ggtitle(paste(
      "Sample:",
      tools::file_path_sans_ext(basename(f_counts), compression = TRUE),
      ", chrom:",
      CHROM,
      ", Assembly:",
      assembly
    ))

  message("   * outputting")
  print(plt)
  i <- i + cells_per_page
} # while
dev.off()
suppressMessages(library(dplyr))
suppressMessages(library(data.table))
suppressMessages(library(ggplot2))
library(scales) %>% invisible()
library(assertthat) %>% invisible()
library(stringr) %>% invisible()
library(RColorBrewer) %>% invisible()

################################################################################
# Settings                                                                     #
################################################################################

zcat_command <- "zcat"
# FIXME : tmp solution
chroms <- c("chr1", "chr2", "chr3", "chr4", "chr5", "chr6", "chr7", "chr8", "chr9", "chr10", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", "chr18", "chr19", "chr20", "chr21", "chr22", "chrX")

format_Mb <- function(x) {
  paste(comma(x / 1e6), "Mb")
}

### Colors for background
manual_colors <- c(
  # duplications
  simul_hom_dup = "firebrick4",
  dup_hom = muted("firebrick4", 70, 50),
  simul_het_dup = "firebrick2",
  dup_h1 = muted("firebrick2", 90, 30),
  dup_h2 = muted("firebrick2", 80, 20),
  # deletions
  simul_hom_del = "dodgerblue4",
  del_hom = muted("dodgerblue4", 50, 60),
  simul_het_del = "dodgerblue2",
  del_h1 = muted("dodgerblue2", 80, 50),
  del_h2 = muted("deepskyblue2", 80, 50),
  # inversions
  simul_hom_inv = "chartreuse4",
  inv_hom = muted("chartreuse4", 80, 50),
  simul_het_inv = "chartreuse2",
  inv_h1 = muted("chartreuse2", 100, 60),
  inv_h2 = muted("darkolivegreen3", 100, 60),
  # other SVs
  simul_false_del = "darkgrey",
  simul_inv_dup = "darkgoldenrod2",
  idup_h1 = muted("darkgoldenrod2", 80, 70),
  idup_h2 = muted("gold", 80, 70),
  complex = "darkorchid1",
  # background
  bg1 = "#ffffff",
  bg2 = "#aaafaa",
  # Strand states
  `State: WW` = "sandybrown",
  `State: CC` = "paleturquoise4",
  `State: WC` = "khaki",
  `State: CW` = "yellow2"
)




################################################################################
# Usage                                                                        #
################################################################################

print_usage_and_stop <- function(msg = NULL) {
  if (!is.null(msg)) {
    message(msg)
  }
  message("Plot Strand-seq counts of all cells for a single chromosome.                    ")
  message("                                                                                ")
  message("Usage:                                                                          ")
  message("    Rscript chrom.R [OPTIONS] <count-file> <chrom> <out.pdf>                    ")
  message("                                                                                ")
  message("OPTIONS (no spaces around `=`):                                                 ")
  message("    per-page=<int>            Number of cells to be printed per page            ")
  message("    segments=<file>           Show the segmentation in the plots                ")
  message("    singlecellsegments=<file> Show per-cell  segmentation in the plots          ")
  message("    calls=<file>              Highlight SV calls provided in a table            ")
  message("    truth=<file>              Mark the `true`` SVs provided in a table          ")
  message("    strand=<file>             Mark the strand states which calls are based on   ")
  message("    complex=<file>            Mark complex regions given in file                ")
  message("    groups=<file>             Table with SV call grouping                       ")
  message("    no-none                   Do not hightlight black-listed (i.e. None) bins   ")
  message("                                                                                ")
  message("Generates one plot per chromosome listing all cells below another, separated    ")
  message("into pages. If an SV probability file is provided (2), segments are colored     ")
  message("according to the predicted SV classes. Note that only certain classes are       ")
  message("accepted and the script will yield an error if others are provided.             ")
  message("Similarly, a segmentation file can be specified, yet it must contain exactly    ")
  message("one segmentation (MosaiCatcher reports segmentations for various total numbers  ")
  message("of breakpoints).")
  options(show.error.messages = F)
  stop()
}



################################################################################
# Command Line Arguments                                                       #
################################################################################

args <- commandArgs(trailingOnly = T)

if (length(args) < 3) print_usage_and_stop("[Error] Too few arguments!")

f_counts <- args[length(args) - 2]
CHROM <- args[length(args) - 1]
f_out <- args[length(args)]

f_segments <- NULL
f_calls <- NULL
f_truth <- NULL
f_strand <- NULL
f_complex <- NULL
cells_per_page <- 8
show_none <- T

if (length(args) > 3) {
  if (!all(grepl("^(strand|calls|segments|per-page|truth|no-none|complex|singlecellsegments|groups)=?", args[1:(length(args) - 3)]))) {
    print_usage_and_stop("[Error]: Options must be one of `calls`, `segments`, `per-page`, or `truth`")
  }
  for (op in args[1:(length(args) - 3)]) {
    if (grepl("^segments=", op)) f_segments <- str_sub(op, 10)
    if (grepl("^calls=", op)) f_calls <- str_sub(op, 7)
    if (grepl("^truth=", op)) f_truth <- str_sub(op, 7)
    if (grepl("^per-page=", op)) {
      pp <- as.integer(str_sub(op, 10))
      if (pp > 0 && pp < 50) {
        cells_per_page <- pp
      }
    }
    if (grepl("^strand=", op)) f_strand <- str_sub(op, 8)
    if (grepl("^complex=", op)) f_complex <- str_sub(op, 9)
    if (grepl("^groups=", op)) f_groups <- str_sub(op, 8)
    if (grepl("^singlecellsegments=", op)) f_scsegments <- str_sub(op, 20)
    if (grepl("^no-none$", op)) show_none <- F
  }
}


################################################################################
# Read & check input data                                                      #
################################################################################

### Check counts table
message(" * Reading count data ", f_counts, "...")
if (grepl("\\.gz$", f_counts)) {
  counts <- fread(paste(zcat_command, f_counts))
} else {
  counts <- fread(f_counts)
}

# FIXME : tmp
# print(counts)
counts <- counts[counts$chrom %in% chroms, ]
# print(counts)

assert_that(
  "chrom" %in% colnames(counts),
  "start" %in% colnames(counts),
  "end" %in% colnames(counts),
  "class" %in% colnames(counts),
  "sample" %in% colnames(counts),
  "cell" %in% colnames(counts),
  "w" %in% colnames(counts),
  "c" %in% colnames(counts)
) %>% invisible()
counts[, sample_cell := paste(sample, "-", cell)]
setkey(counts, chrom, sample_cell)
bins <- unique(counts[, .(chrom, start, end)])

### Check CHROM:
assert_that(CHROM %in% unique(counts$chrom)) %>% invisible()
counts <- counts[chrom == CHROM]




### Check SV call file
if (!is.null(f_calls)) {
  message(" * Reading SV calls from ", f_calls, "...")
  svs <- fread(f_calls)
  assert_that(
    "chrom" %in% colnames(svs),
    "start" %in% colnames(svs),
    "end" %in% colnames(svs),
    "sample" %in% colnames(svs),
    "cell" %in% colnames(svs),
    ("SV_class" %in% colnames(svs) | "sv_call_name" %in% colnames(svs))
  ) %>% invisible()
  if (!("SV_class" %in% colnames(svs))) {
    svs[, SV_class := sv_call_name]
  }
  assert_that(all(svs$SV_class %in% names(manual_colors))) %>% invisible()
  svs[, sample_cell := paste(sample, "-", cell)]

  set_diff <- setdiff(unique(svs$sample_cell), unique(counts$sample_cell))
  if (length(set_diff) > 0) message("[Warning] SV calls and Counts differ in cells: ", set_diff)

  svs <- svs[chrom == CHROM]
}

### Check segment table
if (!is.null(f_segments)) {
  message(" * Reading segmentation file from ", f_segments, "...")
  seg <- fread(f_segments)


  # FIXME : tmp
  seg <- seg[seg$chrom %in% chroms, ]
  seg <- seg[seg$bps > 0, ] # SOLVE T2T ISSUE

  # print(seg)

  assert_that(
    "chrom" %in% colnames(seg),
    "bps" %in% colnames(seg)
  ) %>% invisible()
  if ("k" %in% colnames(seg)) {
    seg[, assert_that(length(unique(k)) == 1), by = .(chrom)] %>% invisible()
  }

  # print(seg)

  # print(bins)
  # print(bins[, .N, by = chrom])
  # # print(bins %>% count(chrom))
  # print(bins[, .N, by = chrom][, .(chrom, N = c(0, cumsum(N)))])
  # print(bins[, .N, by = chrom][, .(chrom, N = c(0, cumsum(N))[1:(.N - 1)])])



  seg <- merge(seg, bins[, .N, by = chrom][, .(chrom, N = c(0, cumsum(N))[1:(.N - 1)])], by = "chrom")

  # print(c(1, bps[1:(.N - 1)] + 1))
  # print(bps)
  # print(seg)

  # stop()


  seg[, `:=`(from = c(1, bps[1:(.N - 1)] + 1), to = bps), by = chrom]

  # print(seg)


  seg[, `:=`(
    start = bins[from + N]$start,
    end = bins[to + N]$end
  )]

  # print("TEST")


  seg[, SV_class := rep(c("bg1", "bg2"), .N)[1:.N], by = chrom]



  seg <- seg[chrom == CHROM]
}



### Check simulated variants
if (!is.null(f_truth)) {
  message(" * Reading simulated variants from ", f_truth, "...")
  simul <- fread(f_truth)
  assert_that(
    "chrom" %in% colnames(simul),
    "start" %in% colnames(simul),
    "end" %in% colnames(simul),
    "sample" %in% colnames(simul),
    "cell" %in% colnames(simul),
    "SV_type" %in% colnames(simul)
  ) %>% invisible()
  simul[, `:=`(SV_class = paste0("simul_", SV_type), SV_type = NULL, sample_cell = paste(sample, "-", cell))]
  simul[, sample_cell := paste(sample, "-", cell)]

  set_diff <- setdiff(unique(simul$sample_cell), unique(counts$sample_cell))
  if (length(set_diff) > 0) message("[Warning] True SVs and Counts differ in cells: ", set_diff)

  simul <- simul[chrom == CHROM]
}

### Check strand states file
if (!is.null(f_strand)) {
  message(" * Reading strand state file from ", f_strand, "...")
  strand <- fread(f_strand)
  assert_that(
    "sample" %in% colnames(strand),
    "cell" %in% colnames(strand),
    "chrom" %in% colnames(strand),
    "start" %in% colnames(strand),
    "end" %in% colnames(strand),
    "class" %in% colnames(strand)
  ) %>% invisible()
  strand[, class := paste("State:", class)]
  strand[, sample_cell := paste(sample, "-", cell)]

  set_diff <- setdiff(unique(strand$sample_cell), unique(counts$sample_cell))
  if (length(set_diff) > 0) message("[Warning] Strand states and Counts differ in cells: ", set_diff)

  strand <- strand[chrom == CHROM]
}

### Check complex regions file
if (!is.null(f_complex)) {
  message(" * Reading complex regions state file from ", f_complex, "...")
  complex <- fread(f_complex)
  assert_that(
    "chrom" %in% colnames(complex),
    "start" %in% colnames(complex),
    "end" %in% colnames(complex)
  ) %>% invisible()

  complex <- complex[chrom == CHROM]
  message("   --> Found ", nrow(complex), " complex region(s) in chromosome ", CHROM)
}

### Check SV groups file
if (!is.null(f_groups)) {
  message(" * Reading SV group file from ", f_groups, "...")
  groups <- fread(f_groups)
  assert_that(
    "chrom" %in% colnames(groups),
    "start" %in% colnames(groups),
    "end" %in% colnames(groups),
    "group_id" %in% colnames(groups)
  ) %>% invisible()
  groups[, group_id := paste("SV group", group_id)]
  groups <- groups[chrom == CHROM]
  message("   --> Found ", nrow(groups), " SV groups in chromosome ", CHROM)
}

### Check single cell segmentation file
if (!is.null(f_scsegments)) {
  message(" * Reading per-cell segmentation regions state file from ", f_scsegments, "...")
  scsegments <- fread(f_scsegments)
  assert_that(
    "sample" %in% colnames(scsegments),
    "cell" %in% colnames(scsegments),
    "chrom" %in% colnames(scsegments),
    "position" %in% colnames(scsegments)
  ) %>% invisible()
  scsegments[, sample_cell := paste(sample, "-", cell)]

  scsegments <- scsegments[chrom == CHROM]
}


################################################################################
# Actual plot                                                                  #
################################################################################


# Plot always a few cells per page!
y_lim <- 3 * counts[, median(w + c)]
n_cells <- length(unique(counts[, sample_cell]))
i <- 1


message(" * Plotting ", CHROM, " (", f_out, ")")
cairo_pdf(f_out, width = 14, height = 10, onefile = T)
while (i <= n_cells) {
  message(" * Processing cells from ", i, " to ", i + cells_per_page - 1)

  # Subset to this set of cells:
  CELLS <- unique(counts[, .(sample_cell)])[i:(min(i + cells_per_page - 1, n_cells))]
  setkey(CELLS, sample_cell)

  # Subset counts
  local_counts <- counts[CELLS, on = .(sample_cell), nomatch = 0]

  # Start major plot
  plt <- ggplot(local_counts)

  # Add background colors for segments, if available:
  if (!is.null(f_segments)) {
    message("   * Adding segment colors")
    # Segments need to be multiplied by "CELLS"
    local_seg <- CELLS[, as.data.table(seg), by = sample_cell]
    if (nrow(local_seg) > 0) {
      plt <- plt +
        geom_rect(
          data = local_seg, alpha = 0.4,
          aes(xmin = start, xmax = end, ymin = -Inf, ymax = Inf, fill = SV_class)
        )
    }
  }

  # Add colors for SV calls, if available
  if (!is.null(f_calls)) {
    message("   * Adding SV calls")
    local_svs <- svs[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_svs) > 0) {
      plt <- plt +
        geom_rect(
          data = local_svs, alpha = 1,
          aes(xmin = start, xmax = end, ymin = -Inf, ymax = Inf, fill = SV_class)
        )
    }
  }

  # Add bars for true SVs, if available
  if (!is.null(f_truth)) {
    message("   * Adding true SVs")
    local_sim <- simul[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_sim) > 0) {
      plt <- plt +
        geom_rect(
          data = local_sim,
          aes(xmin = start, xmax = end, ymin = y_lim, ymax = Inf, fill = SV_class)
        )
    }
  }

  # Add lines for single cell segmentation, if available
  if (!is.null(f_scsegments)) {
    message("   * Adding single cell segments")
    local_scsegments <- scsegments[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_scsegments) > 0) {
      plt <- plt +
        geom_segment(
          data = local_scsegments,
          aes(x = position, xend = position, y = -Inf, yend = -.8 * y_lim), color = "blue"
        )
    }
  }

  # Add bars for strand states, if available
  if (!is.null(f_strand)) {
    message("   * Adding strand states")
    local_strand <- strand[CELLS, on = .(sample_cell), nomatch = 0]
    if (nrow(local_strand) > 0) {
      plt <- plt +
        geom_rect(
          data = local_strand,
          aes(xmin = start, xmax = end, ymin = -Inf, ymax = -y_lim, fill = class)
        )
    }
  }

  # Add bars for SV group, if available
  if (!is.null(f_groups)) {
    message("   * Adding SV groups")
    if (nrow(groups) > 0) {
      plt <- plt +
        geom_rect(
          data = groups,
          aes(xmin = start, xmax = end, ymin = .85 * y_lim, ymax = Inf, fill = group_id)
        )
      # Add colors for SV classes
      manual_colors <- c(manual_colors, setNames(colorRampPalette(brewer.pal(12, "Set2"))(nrow(groups)), groups$group_id))
    }
  }

  # Add bars for complex states, if available
  if (!is.null(f_complex)) {
    message("   * Adding complex intervals")
    if (nrow(complex) > 0) {
      plt <- plt +
        geom_rect(
          data = complex,
          aes(xmin = start, xmax = end, ymin = y_lim, ymax = Inf), fill = "darkorchid1"
        )
    }
  }

  message("   * Adding actual W/C counts")
  plt <- plt +
    geom_rect(aes(xmin = start, xmax = end, ymin = 0, ymax = -w), fill = "sandybrown") +
    geom_rect(aes(xmin = start, xmax = end, ymin = 0, ymax = c), fill = "paleturquoise4")


  # Highlight None bins, if requested
  none_bins <- local_counts[class == "None"]
  if (show_none == T && nrow(none_bins) > 0) {
    message("   * Highlighting None bins")
    plt <- plt +
      geom_segment(data = none_bins, aes(x = start, xend = end, y = 0, yend = 0), col = "black", size = 2)
  }


  message("   * Adding labels, etc.")
  plt <- plt +
    facet_wrap(~sample_cell, ncol = 1) +
    ylab("Watson | Crick") + xlab(NULL) +
    scale_x_continuous(breaks = pretty_breaks(12), labels = format_Mb) +
    scale_y_continuous(breaks = pretty_breaks(3)) +
    coord_cartesian(ylim = c(-y_lim, y_lim)) +
    scale_fill_manual(values = manual_colors) +
    theme_minimal() +
    theme(
      panel.spacing = unit(0, "lines"),
      axis.ticks.x = element_blank(),
      strip.background = element_rect(color = "#eeeeee", fill = "#eeeeee"),
      strip.text = element_text(size = 5),
      legend.position = "bottom"
    ) +
    ggtitle(paste("data:", basename(f_counts), "chromosome:", CHROM))

  message("   * outputting")
  print(plt)
  i <- i + cells_per_page
} # while
dev.off()
suppressMessages(library(data.table))
suppressMessages(library(assertthat))
suppressMessages(library(ggplot2))
suppressMessages(library(scales))
suppressMessages(library(cowplot))


add_overview_plot <- T


args <- commandArgs(trailingOnly = T)
print(args)
if (length(args) < 2 || length(args) > 4 || !grepl("\\.pdf$", args[length(args)]) || any(!file.exists(args[1:(length(args) - 1)]))) {
    warning("Usage: Rscript R/qc.R input-file [SCE-file] [cell-info-file] output-pdf")
    quit(status = 1)
}
f_in <- args[1]
# The output PDF is always the last argument; optional SCE/info tables come in between.
pdf_out <- args[length(args)]
info <- NULL



# Detect info or SCE file.
sces <- NULL
is_sce_file <- function(x) {
    all(c("sample", "cell", "chrom", "start", "end", "state") %in% colnames(x))
}
is_info_file <- function(x) {
    all(c("sample", "cell", "pass1", "dupl", "mapped", "nb_p", "nb_r", "nb_a") %in% colnames(x))
}
if (length(args) > 2) {
    x <- fread(args[2])
    if (is_sce_file(x)) {
        message("* Using SCE file ", args[2])
        sces <- x
    } else if (is_info_file(x)) {
        message("* Using INFO file ", args[2])
        info <- x
    }
    if (length(args) > 3) {
        x <- fread(args[3])
        if (is_sce_file(x)) {
            message("* Using SCE file ", args[3])
            sces <- x
        } else if (is_info_file(x)) {
            message("* Using INFO file ", args[3])
            info <- x
        }
    }
}




# if (!is.null(sces)) {
#     sces[, chrom := sub("^chr", "", chrom)]
#     sces[, chrom := factor(chrom, levels = as.character(c(1:22, "X", "Y")), ordered = T)]
# }



format_Mb <- function(x) {
    paste(comma(x / 1e6), "Mb")
}


# if gzip
zcat_command <- "zcat"
if (substr(f_in, nchar(f_in) - 2, nchar(f_in)) == ".gz") {
    f_in <- paste(zcat_command, f_in)
}

# Read counts & filter chromosomes (human and mouse chromosome sets are supported)
d <- fread(f_in)
print(d)
# Infer the chromosome set from the data: chr22 present => human (22 autosomes),
# otherwise assume mouse (19 autosomes).
is_human <- any(d$chrom == "chr22")
if (is_human) {
    chrom_levels <- as.character(c(1:22, "X", "Y"))
} else {
    chrom_levels <- as.character(c(1:19, "X", "Y"))
}
print(chrom_levels)

print(info)
# stop()

# Check that correct files are given:
invisible(assert_that(
    "chrom" %in% colnames(d),
    "start" %in% colnames(d) && is.integer(d$start),
    "end" %in% colnames(d) && is.integer(d$end),
    "sample" %in% colnames(d),
    "cell" %in% colnames(d),
    "w" %in% colnames(d) && is.numeric(d$w),
    "c" %in% colnames(d) && is.numeric(d$c),
    "class" %in% colnames(d)
))

# Re-name and re-order chromosomes (strip the "chr" prefix; ordering follows the chromosome set detected above)
d <- d[, chrom := sub("^chr", "", chrom)][]
d <- d[grepl("^([1-9]|[12][0-9]|X|Y)$", chrom), ]
# d <- d[, chrom := factor(chrom, levels = as.character(c(1:22, "X", "Y")), ordered = T)]
d <- d[, chrom := factor(chrom, levels = chrom_levels, ordered = T)]
# d[, c(6, 7)] <- sapply(d[, c(6, 7)], as.double)

print(d)

message("* Writing plot ", pdf_out)

if (add_overview_plot == T) {
    message("* Plotting an overview page")

    n_samples <- nrow(unique(d[, .(sample)]))
    n_cells <- nrow(unique(d[, .(sample, cell)]))
    n_bins <- nrow(unique(d[, .(chrom, start, end)]))
    mean_bin <- unique(d[, .(chrom, start, end)])[, mean(end - start)]
    print(d)
    n_excl <- nrow(d[, .N, by = .(chrom, start, end, class)][class == "None" & N == n_cells, ])

    # Bin sizes
    ov_binsizes <- ggplot(unique(d[, .(chrom, start, end)])) +
        geom_histogram(aes(end - start), bins = 50) +
        theme_minimal() +
        scale_y_log10(breaks = c(1, 10, 100, 1000, 10e3, 100e3, 1e6)) +
        scale_x_log10(labels = comma) +
        ggtitle(paste0("Bin sizes (", n_bins, " bins, mean ", round(mean_bin / 1000, 1), " kb)")) +
        xlab("Bin size (bp)")

    # analyse how many bins are "None"
    ov_excbins <- ggplot(d[, .N, by = .(chrom, start, end, class)][class == "None" & N == n_cells, ]) +
        aes(chrom) +
        geom_bar() +
        theme_minimal() +
        ggtitle(paste0("Excluded bins per chromosome (total = ", n_excl, ")"))


    # coverage
    ov_coverage <- ggplot(d[, .(total = sum(as.double(w) + as.double(c))), by = .(sample, cell)]) +
        geom_histogram(aes(total, fill = sample), bins = 50) +
        scale_x_continuous(
            breaks = pretty_breaks(5),
            labels = comma
        ) +
        xlab("Total number of reads per cell") +
        theme_minimal() +
        theme(legend.position = "bottom") +
        scale_fill_brewer(type = "qual", palette = 6)

    # Overview mean / variance
    d_mv <- d[class != "None", .(mean = mean(w + c), var = var(w + c)), by = .(sample, cell)]
    print(d_mv)

    if(nrow(d_mv) > 0) {
        d_p <- d_mv[, .(p = sum(mean * mean) / sum(mean * var)), by = sample]

        ov_meanvar <- ggplot(d_mv) +
            geom_point(aes(mean, var), alpha = 0.4) +
            facet_wrap(~sample, nrow = 1) +
            theme_minimal() +
            geom_abline(data = d_p, aes(slope = 1 / p, intercept = 0), col = "dodgerblue") +
            geom_label(data = d_p, aes(x = 0, y = Inf, label = paste("p =", round(p, 3))), hjust = 0, vjust = 1) +
            ggtitle("Mean variance relationship of reads per bin") +
            xlab("Mean") +
            ylab("Variance")

        # Arranging overview plot
        content <- ggdraw() +
            draw_plot(ov_binsizes, x = 0, y = .66, width = .5, height = .33) +
            draw_plot(ov_excbins, x = .5, y = .66, width = .5, height = .33) +
            draw_plot(ov_coverage, x = 0, y = .33, width = .5, height = .33) +
            draw_plot(ov_meanvar, x = 0, y = 0, width = min(n_samples / 3, 1), height = .33)
    } else {
        print("d_mv is empty. Skipping the overview mean/variance plot...")

        # Arranging overview plot without ov_meanvar
        content <- ggdraw() +
            draw_plot(ov_binsizes, x = 0, y = .66, width = .5, height = .33) +
            draw_plot(ov_excbins, x = .5, y = .66, width = .5, height = .33) +
            draw_plot(ov_coverage, x = 0, y = .33, width = .5, height = .33)
    }


    # Add duplicate rates if available
    if (exists("info")) {
        ov_duplicate <- ggplot(info) +
            aes(dupl / (mapped - suppl), fill = sample) +
            geom_histogram(bins = 50) +
            xlab("Duplicate rate") +
            scale_x_continuous(labels = percent) +
            theme_minimal() +
            theme(legend.position = "bottom") +
            scale_fill_brewer(type = "qual", palette = 6)
        content <- content +
            draw_plot(ov_duplicate, x = 0.55, y = .33, width = .45, height = .33)
    }

    title <- ggdraw() + draw_label(paste("Overview across", n_cells, "cells from", n_samples, "samples"), fontface = "bold")
    side <- ggdraw() + draw_label(label = paste0(args[1], "\n", date()), angle = 90, size = 10, vjust = 1)

    final <- plot_grid(title, content, ncol = 1, rel_heights = c(0.07, 1))
    xxx <- plot_grid(side, final, nrow = 1, rel_widths = c(0.05, 1))
}

cairo_pdf(pdf_out, width = 14, height = 10, onefile = T)
if (add_overview_plot == T) {
    print(xxx)
}


# Plot all cells
for (s in unique(d$sample))
{
    # for (ce in unique(d[sample == s, ]$cell)[3])
    for (ce in unique(d[sample == s, ]$cell))
    {
        message(paste("* Plotting sample", s, "cell", ce))


        e <- d[sample == s & cell == ce, ]
        e$total <- e$c + e$w
        # print(e)

        suppressMessages(library(dplyr)) # dplyr::filter() is used below
        # e_sum <- e %>%
        #     group_by(chrom) %>%
        #     summarise(total = sum(total))

        # print(e_sum, n = 40)


        # e_sum <- e_sum[e_sum$total > 0, ]$chrom

        # print(e_sum)

        # e_lite <- filter(e, chrom %in% e_sum)
        # print(e_lite, n = 40)

        # e_lite <- e
        e_lite <- filter(e, bin_id == "")
        print(e_lite)



        # Calculate some information
        info_binwidth <- median(e_lite$end - e_lite$start)
        info_reads_per_bin <- median(e_lite$w + e_lite$c)
        if (!is.integer(info_reads_per_bin)) info_reads_per_bin <- round(info_reads_per_bin, 2)
        info_chrom_sizes <- e_lite[, .(xend = max(end)), by = chrom]
        info_num_bins <- nrow(e_lite)
        info_total_reads <- sum(e_lite$c + e_lite$w)
        info_y_limit <- 2 * info_reads_per_bin + 1
        info_sample_name <- s
        info_cell_name <- ce
        # if (!is.integer(info_y_limit)) info_y_limit <- round(info_y_limit, 2)
        # info_sample_name <- substr(s, 1, 25)
        # if (nchar(s) > 25) info_sample_name <- paste0(info_sample_name, "...")
        # info_cell_name <- substr(ce, 1, 25)
        # if (nchar(ce) > 25) info_cell_name <- paste0(info_cell_name, "...")

        # start main plot:
        plt <- ggplot(e) +
            aes(x = (start + end) / 2)


        # prepare consecutive rectangles for a better plotting experience
        consecutive <- cumsum(c(0, abs(diff(as.numeric(as.factor(e$class))))))
        e$consecutive <- consecutive
        f <- e[, .(start = min(start), end = max(end), class = class[1]), by = .(consecutive, chrom)][]
        print(f)

        plt <- plt +
            geom_rect(data = f, aes(xmin = start, xmax = end, ymin = -Inf, ymax = Inf, fill = class), inherit.aes = F, alpha = 0.2) +
            scale_fill_manual(values = c(WW = "sandybrown", CC = "paleturquoise4", WC = "yellow", None = NA))

        # Show SCEs
        if (!is.null(sces)) {
            sces_local <- sces[sample == s & cell == ce][, .SD[.N > 1], by = chrom]
            if (nrow(sces_local) > 0) {
                sces_local <- sces_local[, .(pos = (end[1:(.N - 1)] + start[2:(.N)]) / 2), by = chrom]
                plt <- plt + geom_point(data = sces_local, aes(x = pos, y = -info_y_limit), size = 3, shape = 18)
            }
        }


        # Watson/Crick bars
        plt <- plt +
            geom_rect(aes(xmin = start, xmax = end, ymin = -w, ymax = 0), fill = "sandybrown") +
            geom_rect(aes(xmin = start, xmax = end, ymin = 0, ymax = c), fill = "paleturquoise4") +
            # geom_bar(aes(y = -w, width=(end-start)), stat='identity', position = 'identity', fill='sandybrown') +
            # geom_bar(aes(y = c, width=(end-start)), stat='identity', position = 'identity', fill='paleturquoise4') +
            # Trim image to 2*median cov
            coord_flip(expand = F, ylim = c(-info_y_limit, info_y_limit)) +
            facet_grid(. ~ chrom, switch = "x") +
            ylab("Watson | Crick") + xlab(NULL) +
            scale_x_continuous(breaks = pretty_breaks(12), labels = format_Mb) +
            scale_y_continuous(breaks = pretty_breaks(3)) +
            theme_classic() +
            theme(
                panel.spacing = unit(0.2, "lines"),
                axis.text.x = element_blank(),
                axis.ticks.x = element_blank(),
                strip.background = element_rect(fill = NA, colour = NA)
            ) +
            guides(fill = FALSE) +
            # Dotted lines at median bin count
            geom_segment(
                data = info_chrom_sizes, aes(xend = xend, x = 0, y = -info_reads_per_bin, yend = -info_reads_per_bin),
                linetype = "dotted", col = "darkgrey", size = 0.5
            ) +
            geom_segment(
                data = info_chrom_sizes, aes(xend = xend, x = 0, y = +info_reads_per_bin, yend = +info_reads_per_bin),
                linetype = "dotted", col = "darkgrey", size = 0.5
            ) +
            geom_segment(data = info_chrom_sizes, aes(xend = xend, x = 0), y = 0, yend = 0, size = 0.5)

        # Rename classes:
        labels <- e[, .N, by = class][, label := paste0(class, " (n=", N, ")")][]

        e[, class := factor(class, levels = labels$class, labels = labels$label)]

        # Histogram in upper right corner
        e.melt <- melt.data.table(e, c("chrom", "start", "end", "class"), measure.vars = c("w", "c"), variable.name = "strand", value.name = "coverage")
        plt_hist_xlim <- 10 + 3 * info_reads_per_bin
        plt_hist <- ggplot(e.melt) +
            aes(coverage, fill = strand) +
            geom_histogram(binwidth = 1, position = position_dodge(), alpha = 0.9) +
            scale_x_continuous(limits = c(-1, plt_hist_xlim), breaks = pretty_breaks(5), labels = comma) +
            theme(text = element_text(size = 10), axis.text = element_text(size = 8)) +
            scale_fill_manual(values = c(w = "sandybrown", c = "paleturquoise4")) +
            guides(fill = FALSE, col = FALSE) +
            ylab("bin count") +
            xlab("reads per bin") +
            facet_wrap(~class, nrow = 1, scales = "free")
        if (!is.null(info)) {
            Ie <- info[sample == s & cell == ce, ]
            if (nrow(Ie) != 1 || Ie$pass1 != 1 || !all(c("WW", "WC", "CC") %in% unique(Ie$class))) {
                message("  Problem finding additional info for ", s, " - ", ce)
            } else {
                p <- Ie$nb_p
                r <- Ie$nb_r
                a <- Ie$nb_a
                x <- seq(0, plt_hist_xlim)
                scale_factors <- e.melt[, .N, by = .(class, strand, coverage)][, .(scale = max(N)), by = class]
                nb <- data.table(
                    x = rep(x, 6),
                    strand = rep(c(rep("w", length(x)), rep("c", length(x))), 3),
                    class = c(
                        rep(labels[class == "WW", ]$label, 2 * length(x)),
                        rep(labels[class == "WC", ]$label, 2 * length(x)),
                        rep(labels[class == "CC", ]$label, 2 * length(x))
                    ),
                    scale = c(
                        rep(scale_factors[class == labels[class == "WW", ]$label, ]$scale, 2 * length(x)),
                        rep(scale_factors[class == labels[class == "WC", ]$label, ]$scale, 2 * length(x)),
                        rep(scale_factors[class == labels[class == "CC", ]$label, ]$scale, 2 * length(x))
                    ),
                    y = c(
                        dnbinom(x, a * r, p),
                        dnbinom(x, (1 - a) * r, p),
                        dnbinom(x, r / 2, p),
                        dnbinom(x, r / 2, p),
                        dnbinom(x, (1 - a) * r, p),
                        dnbinom(x, a * r, p)
                    )
                )
                plt_hist <- plt_hist + geom_line(data = nb, aes(x, y * scale, col = strand))
            }
        }


        plot_hst_width <- .03 + .13 * length(unique(e$class))
        x <- 0.25
        all <- ggdraw() + draw_plot(plt) +
            draw_plot(plt_hist, x = .45, y = .76, width = plot_hst_width, height = .23) +
            draw_label(paste("Sample:", info_sample_name), x = x, y = .97, vjust = 1, hjust = 0, size = 9) +
            draw_label(paste("Cell:", info_cell_name), x = x, y = .94, vjust = 1, hjust = 0, size = 8) +
            draw_label(paste("Median binwidth:", format(round(info_binwidth / 1000, 0), big.mark = ",", scientific = F), "kb"),
                x = x, y = .91, vjust = 1, hjust = 0, size = 8
            ) +
            draw_label(paste("Number bins:", format(info_num_bins, big.mark = ",")),
                x = x, y = .89, vjust = 1, hjust = 0, size = 8
            ) +
            draw_label(paste("Total number of reads:", format(info_total_reads, big.mark = ",")),
                x = x, y = .87, vjust = 1, hjust = 0, size = 8
            ) +
            draw_label(paste("Median reads/bin (dotted):", info_reads_per_bin),
                x = x, y = .85, vjust = 1, hjust = 0, size = 8
            ) +
            draw_label(paste0("Plot limits: [-", info_y_limit, ",", info_y_limit, "]"),
                x = x, y = .83, vjust = 1, hjust = 0, size = 8
            )
        # If available, add additional info like duplicate rate and NB params!
        if (!is.null(info)) {
            Ie <- info[sample == s & cell == ce, ]
            if (nrow(Ie) == 1) {
                all <- all +
                    draw_label(paste0("Duplicate rate: ", round(Ie$dupl / Ie$mapped, 2) * 100, "%"),
                        x = x, y = .80, vjust = 1, hjust = 0, size = 8
                    )
                if (Ie$pass1 == 1) {
                    all <- all +
                        draw_label(paste0("NB parameters (p,r,a): ", round(Ie$nb_p, 2), ",", round(Ie$nb_r, 2), ",", round(Ie$nb_a, 2)),
                            x = x, y = .78, vjust = 1, hjust = 0, size = 8
                        )
                }
            }
        }

        # If available, write number of SCEs detected
        if (!is.null(sces)) {
            sces_local <- sces[sample == s & cell == ce][, .SD[.N > 1], by = chrom]
            if (nrow(sces_local) > 0) sces_local <- sces_local[, .(pos = (end[1:(.N - 1)] + start[2:(.N)]) / 2), by = chrom]
            all <- all + draw_label(paste("SCEs detected:", nrow(sces_local)),
                x = x, y = .76, vjust = 1, hjust = 0, size = 8
            )
        }

        print(all)
    }
}
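As the usage message in this script indicates, the count table comes first and the output PDF last, with the optional SCE and cell-info tables in between. A minimal invocation sketch follows; the script path and file names are illustrative placeholders, not the pipeline's canonical locations.

# Minimal invocation sketch; paths and file names are placeholders.
import subprocess

subprocess.run(
    [
        "Rscript", "workflow/scripts/plotting/qc.R",  # hypothetical script location
        "counts.txt.gz",   # binned Watson/Crick count table (optionally gzipped)
        "info.txt",        # optional per-cell info table with NB parameters
        "qc.pdf",          # output PDF (must end in .pdf)
    ],
    check=True,
)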
log <- file(snakemake@log[[1]], open = "wt")
sink(file = log, type = "message")
sink(file = log, type = "output")

source("workflow/scripts/plotting/sv_consistency_barplot.R")

SVplotting(inputfile = snakemake@input[["sv_calls"]], outputfile.byPOS = snakemake@output[["barplot_bypos"]], outputfile.byVAF = snakemake@output[["barplot_byaf"]])
import pandas as pd
import os, sys, glob, gzip

colors = {
    "none": "248,248,248",  # #F8F8F8
    "del_h1": "119,170,221",  # #77AADD
    "del_h2": "68,119,170",  # #4477AA
    "del_hom": "17,68,119",  # #114477
    "dup_h1": "204,153,187",  # #CC99BB
    "dup_h2": "170,68,136",  # #AA4488
    "dup_hom": "119,17,85",  # #771155
    "inv_h1": "221,221,119",  # #DDDD77
    "inv_h2": "170,170,68",  # #AAAA44
    "inv_hom": "119,119,17",  # #777711
    "idup_h1": "221,170,119",  # #DDAA77
    "idup_h2": "170,119,68",  # #AA7744
    "complex": "119,68,17",  # #774411
}


def create_bed_row(row, category, color):
    chrom, start, end, cell = row["chrom"], row["start"], row["end"], row["cell"]
    score = "0"
    strand = "+"
    return f"{chrom}\t{start}\t{end}\t{category}\t{score}\t{strand}\t{start}\t{end}\t{color}\n"


# def process_file(input_file, df_sv, output):


def main(input_counts, input_sv_file_stringent, input_sv_file_lenient, output):
    # Concatenate the W, C, and SV_call_name DataFrames
    df_sv_stringent = pd.read_csv(input_sv_file_stringent, sep="\t") 
    df_sv_stringent["color"] = df_sv_stringent["sv_call_name"].map(colors)
    df_sv_stringent = df_sv_stringent.sort_values(by=["cell"])
    df_sv_lenient = pd.read_csv(input_sv_file_lenient, sep="\t") 
    df_sv_lenient["color"] = df_sv_lenient["sv_call_name"].map(colors)
    df_sv_lenient = df_sv_lenient.sort_values(by=["cell"])

    # Get the list of input files in the input folder
    # input_files = glob.glob(os.path.join(input_counts_folder, "*.txt.percell.gz"))
    df_mosaic = pd.read_csv(input_counts, sep="\t")
    cell_list = df_mosaic.cell.unique().tolist()
    print(df_mosaic)
    print(cell_list)

    # Process each input file
    for cell_name in sorted(cell_list):
        # process_file(input_file, df_sv, output_file)

        # Extract cell name
        # cell_name = os.path.basename(input_file).replace(".txt.percell.gz", "")

        # Read the input gzipped file
        # df = pd.read_csv(input_file, sep="\t")
        df = df_mosaic.loc[df_mosaic["cell"] == cell_name]

        # Create separate DataFrames for the 'c' and 'w' columns
        # (Crick counts are negated so they are drawn below the axis in the browser)
        df_c = df[["chrom", "start", "end", "c"]].copy()
        df_c["c"] = df_c["c"] * -1
        df_w = df[["chrom", "start", "end", "w"]]

        # Filter df_sv
        df_sv_cell_stringent = df_sv_stringent.loc[df_sv_stringent["cell"] == cell_name]
        df_sv_cell_lenient = df_sv_lenient.loc[df_sv_lenient["cell"] == cell_name]

        with gzip.open(output, "at") as output_file:
            output_file.write(
                f"track type=bedGraph name={cell_name}_W maxHeightPixels=40 description=BedGraph_{cell_name}_w.sort.mdup.bam_allChr visibility=full color=244,163,97\n"
            )
            df_w.to_csv(output_file, sep="\t", header=False, index=False, mode="a")  # the handle is already gzip-compressed

            output_file.write(
                f"track type=bedGraph name={cell_name}_C maxHeightPixels=40 description=BedGraph_{cell_name}_c.sort.mdup.bam_allChr visibility=full color=102,139,138\n"
            )
            df_c.to_csv(output_file, sep="\t", header=False, index=False, mode="a")  # the handle is already gzip-compressed

            output_file.write(f'track name="{cell_name}_SV_stringent" description="Stringent - SV_call_name for cell {cell_name}" visibility=squish itemRgb="On"\n')
            for _, row in df_sv_cell_stringent.iterrows():
                bed_row = create_bed_row(row, row["sv_call_name"], row["color"])
                output_file.write(bed_row)
            # output_file.write(f'track name="{cell_name}_SV_lenient" description="Lenient - SV_call_name for cell {cell_name}" visibility=squish itemRgb="On"\n')
            # for _, row in df_sv_cell_lenient.iterrows():
            #     bed_row = create_bed_row(row, row["sv_call_name"], row["color"])
            #     output_file.write(bed_row)


if __name__ == "__main__":
    if len(sys.argv) != 5:
        print("Usage: python script.py <input_counts>  <input_sv_stringent_file> <input_sv_lenient_file> <output_file>")
        # print("Usage: python script.py <input_counts>  <input_sv_stringent_file>  <output_file>")
        sys.exit(1)

    input_counts = sys.argv[1]
    input_sv_stringent_file = sys.argv[2]
    input_sv_lenient_file = sys.argv[3]
    output_file = sys.argv[4]
    # main(input_counts, input_sv_stringent_file, output_file)
    main(input_counts, input_sv_stringent_file, input_sv_lenient_file, output_file)
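create_bed_row emits one nine-column BED line per SV call, reusing start/end as thickStart/thickEnd and taking the itemRgb string from the colors mapping above. A self-contained sketch of the resulting row follows; the coordinates and cell name are invented for illustration.

# Self-contained sketch of the BED9 line produced by create_bed_row (values invented).
def create_bed_row(row, category, color):
    chrom, start, end = row["chrom"], row["start"], row["end"]
    return f"{chrom}\t{start}\t{end}\t{category}\t0\t+\t{start}\t{end}\t{color}\n"

row = {"chrom": "chr1", "start": 1000000, "end": 1200000, "cell": "cellA"}
print(create_bed_row(row, "del_h1", "119,170,221"), end="")
# chr1    1000000 1200000 del_h1  0   +   1000000 1200000 119,170,221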
import pandas as pd
from tensorflow.keras.models import load_model

Nucleosome_data_new = pd.read_csv(
    snakemake.input.features,
    delimiter="\t",
    header=None,
)
TSS_matrix_new = pd.read_csv(
    snakemake.input.TSS_annot,
    delimiter="\t",
    header=None,
)

i = str(snakemake.wildcards.chrom)

j = str(snakemake.wildcards.i)

TSS_matrix_new_index = TSS_matrix_new.loc[TSS_matrix_new[0] == i].index.tolist()
x_test = Nucleosome_data_new.loc[TSS_matrix_new_index].values.reshape(len(TSS_matrix_new_index), 150, 5)

Nucleosome_model_fixed = load_model("workflow/data/scNOVA/models_CNN/DNN_train{}_".format(j) + str(i) + ".h5")
# NOTE: predict_proba() only exists on older Keras/TensorFlow versions;
# on recent TensorFlow (>= 2.6) use Nucleosome_model_fixed.predict(x_test) instead.
y_pred = Nucleosome_model_fixed.predict_proba(x_test)
df = pd.DataFrame(y_pred, columns=["prob1", "prob2"])

df.to_csv(snakemake.output.train)
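The reshape call above implies that the nucleosome feature table stores, for each TSS, 150 bins with 5 values each, flattened into one row per TSS. A minimal sketch of that input geometry using synthetic data follows; the shapes are inferred from the reshape call, not documented elsewhere.

# Synthetic illustration of the expected feature-matrix geometry:
# one flattened row of 150 bins x 5 features per TSS.
import numpy as np

n_tss = 4                                    # e.g. TSSs selected for one chromosome
flat = np.random.rand(n_tss, 150 * 5)        # shape of the rows taken from the feature table
x_test = np.reshape(flat, (n_tss, 150, 5))   # CNN input: (samples, 150, 5)
print(x_test.shape)                          # -> (4, 150, 5)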
import pandas as pd
import os

l = list(snakemake.input)

l_df = list()

l_df.append(
    pd.read_csv(
        sorted(list(l))[0],
        sep="\t",
        names=[
            "'chr'",
            "'start'",
            "'end'",
            "Feature",
            "'{file}'".format(file=os.path.basename(sorted(list(l))[0])).replace(".tab", ".bam"),
        ],
    )[["'chr'", "'start'", "'end'"]]
)

l_df.extend(
    [
        pd.read_csv(
            file,
            sep="\t",
            names=[
                "'chr'",
                "'start'",
                "'end'",
                "Feature",
                "'{file}'".format(file=os.path.basename(file)).replace(".tab", ".bam"),
            ],
        )[["'{file}'".format(file=os.path.basename(file)).replace(".tab", ".bam")]]
        for file in sorted(list(l))
    ]
)
# print(len(l_df))

pd.concat(l_df, axis=1).to_csv(snakemake.output.tab, sep="\t", index=False)
import pandas as pd

df = pd.read_csv(snakemake.input[0], sep="\t")
df.loc[df["Subclonality"] == snakemake.wildcards.clone].to_csv(
    snakemake.output[0], sep="\t", index=False
)
import pandas as pd

df = pd.read_csv(snakemake.input[0], sep="\t")
df.loc[df["chrom"] != "chrY"].to_csv(snakemake.output[0], sep="\t", index=False)
import pandas as pd

pd.concat([pd.read_csv(e) for e in sorted(list(snakemake.input))])[["prob1", "prob2"]].reset_index().to_csv(
    snakemake.output[0], index=False
)
import sys
from argparse import ArgumentParser
from collections import namedtuple, defaultdict
import bisect
import gzip
from math import ceil
import copy
import logging as log

logger = log.getLogger(__name__)


class Segmentation:
    def __init__(self, filename):
        self.sse = dict()
        self.breaks = defaultdict(list)
        n = 0
        self.binwidth = None
        for line in open(filename):
            if line.startswith("#"):
                continue
            if n == 0:
                f = line.split()
                assert f == [
                    "sample",
                    "cells",
                    "chrom",
                    "bins",
                    "maxcp",
                    "maxseg",
                    "none_bins",
                    "none_regs",
                    "action",
                    "k",
                    "sse",
                    "bps",
                    "start",
                    "end",
                ]
                # Fields = namedtuple('Fields', f)
            else:
                f = line.split()
                sample = f[0]
                cells = f[1]
                chrom = f[2]
                bins = int(f[3])
                maxcp = int(f[4])
                maxseg = int(f[5])
                none_bins = int(f[6])
                none_regs = int(f[7])
                action = f[8]
                k = int(f[9])
                sse = float(f[10])
                bps = int(f[11])
                start = int(f[12])
                end = int(f[13])
                if (self.binwidth is None) and (k > 1) and (start == 0):
                    self.binwidth = end / (bps + 1)
                self.sse[(chrom, k)] = sse
                if len(self.breaks[(chrom, k)]) == 0:
                    self.breaks[(chrom, k)].append(0)
                self.breaks[(chrom, k)].append(end)
            n += 1
        self.chromosomes = sorted(set(chrom for chrom, k in self.sse))

    def __str__(self):
        s = "Segmentation"
        for chrom, k in sorted(self.sse.keys()):
            s += "\n  chrom={}, k={}, sse={}, breaks={}".format(chrom, k, self.sse[(chrom, k)], self.breaks[(chrom, k)])
        return s

    def select_k(self, min_diff=1, max_abs_value=500000):
        """Select number of breakpoints for each chromosome such that the difference in squared error
        drops below min_diff."""
        self.selected_k = dict()
        for chromosome in self.chromosomes:
            k = 1
            while (
                ((chromosome, k + 1) in self.sse)
                and ((self.sse[(chromosome, k)] - self.sse[(chromosome, k + 1)]) > min_diff)
                or (self.sse[(chromosome, k)] > max_abs_value)
            ):
                k += 1
            self.selected_k[chromosome] = k

    def closest_breakpoint(self, chromosome, position):
        """Return the closest breakpoint to a given position in the selected segmentation."""
        breaks = self.breaks[(chromosome, self.selected_k[chromosome])]
        i = bisect.bisect_right(breaks, position)
        if i == 0:
            return breaks[0]
        elif i == len(breaks):
            return breaks[i - 1]
        elif abs(position - breaks[i - 1]) < abs(position - breaks[i]):
            return breaks[i - 1]
        else:
            return breaks[i]

    def get_selected_segmentation(self, chromosome):
        return self.breaks[(chromosome, self.selected_k[chromosome])]

    def write_selected_to_file(self, filename):
        print("binwidth", self.binwidth, file=sys.stderr)
        f = open(filename, "w")
        print("k", "chrom", "bps", sep="\t", file=f)
        for chromosome in self.chromosomes:
            breaks = self.breaks[(chromosome, self.selected_k[chromosome])]
            start = 0
            for position in breaks[1:]:
                end = position
                bps = ((position - start) / self.binwidth) - 1
                print(len(breaks) - 1, chromosome, ceil(bps), sep="\t", file=f)
        f.close()


class CountTable:
    def __init__(self, filename):
        # maps (cell,chromosome) to a list of counts (start, end, w, c)
        self.counts = defaultdict(list)
        for i, line in enumerate(gzip.open(filename)):
            if i == 0:
                fieldnames = list(x.decode() + "_" for x in line.split())
                Fields = namedtuple("Fields", fieldnames)
            else:
                f = list(x.decode() for x in line.split())
                fields = Fields(*f)
                self.counts[(fields.cell_, fields.chrom_)].append(
                    (int(fields.start_), int(fields.end_), float(fields.w_), float(fields.c_))
                )

    def get_counts(self, cell, chromosome, breaks):
        assert len(breaks) >= 2
        w_sums = [0] * (len(breaks) - 1)
        c_sums = [0] * (len(breaks) - 1)
        # fetch first segment
        i = 0
        segment_start, segment_end = breaks[i], breaks[i + 1]
        for bin_start, bin_end, w, c in self.counts[(cell, chromosome)]:
            while (bin_start >= segment_end) and (i + 2 < len(breaks)):
                i += 1
                segment_start, segment_end = breaks[i], breaks[i + 1]
            if segment_start <= bin_start < bin_end <= segment_end:
                w_sums[i] += w
                c_sums[i] += c
        return w_sums, c_sums


def read_info_file(filename):
    """Read info file and return a dict that maps cell names to NB parameters"""
    nb_params = dict()
    NB = namedtuple("NB", ["r", "p"])
    n = 0
    for line in open(filename):
        if line.startswith("#"):
            continue
        if n == 0:
            f = line.split()
            assert f == [
                "sample",
                "cell",
                "medbin",
                "mapped",
                "suppl",
                "dupl",
                "mapq",
                "read2",
                "good",
                "pass1",
                "nb_p",
                "nb_r",
                "nb_a",
                "bam",
            ]
        else:
            f = line.split()
            # sample = f[0]
            cell = f[1]
            # medbin = f[2]
            # mapped = f[3]
            # suppl = f[4]
            # dupl = f[5]
            # mapq = f[6]
            # read2 = f[7]
            # good = f[8]
            # pass1 = f[9]
            nb_p = float(f[10])
            nb_r = float(f[11])
            # nb_a = f[12]
            # bam = f[13]
            nb_params[cell] = NB(r=nb_r, p=nb_p)
        n += 1
    return nb_params


# TODO: use proper NB distribution in the future
def get_strand_state(w, c):
    """Returns the strand state a tuple (w,c), where (2,0) means WW, (1,1) means WC, etc."""
    if (w is None) or (c is None) or (w + c == 0):
        return (0, 0)
    r = w / (w + c)
    if r < 0.2:
        return (0, 2)
    elif r > 0.8:
        return (2, 0)
    else:
        return (1, 1)


def safe_div(a, b):
    if b == 0:
        return float("nan")
    else:
        return a / b


def evaluate_sce_list(sce_list, strand_state_list, breaks):
    """Pick initial state (i.e. at the start of the chromosome) such that the total distance where the
    state is off is minimized. Additionally evaluate whether to add one more SCE to avoid long stretches
    of wrong cell states."""
    best_mismatch_distance = None
    best_ground_state = None
    best_is_valid = None
    best_sce_list = None
    for w_ground_state, c_ground_state in [(2, 0), (1, 1), (0, 2)]:
        w_state, c_state = w_ground_state, c_ground_state
        mismatch_distance = 0
        valid = True
        for i in range(len(breaks) - 1):
            start = breaks[i]
            end = breaks[i + 1]
            w_actual_state, c_actual_state = strand_state_list[i]
            for sce_pos, w_state_diff, c_state_diff in sce_list:
                if sce_pos == start:
                    w_state += w_state_diff
                    c_state += c_state_diff
            # Test whether this sequence of SCEs has led to an impossible ground state
            # (at least under the assumption that the cell is diploid).
            if (w_state < 0) or (c_state < 0):
                valid = False
            if (w_actual_state, c_actual_state) != (w_state, c_state):
                mismatch_distance += end - start
        if (best_mismatch_distance is None) or ((valid, -mismatch_distance) > (best_is_valid, -best_mismatch_distance)):
            best_is_valid = valid
            best_mismatch_distance = mismatch_distance
            best_ground_state = (w_ground_state, c_ground_state)
            best_sce_list = copy.copy(sce_list)
    return best_is_valid, best_ground_state, best_mismatch_distance


def main():

    # ARGS
    parser = ArgumentParser(prog="detect_strand_states.py", description=__doc__)
    parser.add_argument("--samplename", default="UNNAMED", help="Sample name (to be mentioned in output files)")

    parser.add_argument(
        "--cellnames", default=None, help="Comma-separated list of single cell names, in the same order as the SINGLESEG files are given."
    )
    parser.add_argument(
        "--sce_min_distance", default=200000, type=int, help="Minimum distance of an SCE to a break in the joint segmentation."
    )
    parser.add_argument(
        "--sce_add_cutoff", default=20000000, type=int, help="Minimum gain in mismatch distance needed to add an additional SCE."
    )
    parser.add_argument("--output_jointseg", default=None, help="Filename to output selected joint segmentation to.")
    parser.add_argument("--output_singleseg", default=None, help="Filename to output selected single cell segmentations to.")
    parser.add_argument("--output_strand_states", default=None, help="Filename to output strand states to.")
    parser.add_argument(
        "--min_diff_jointseg",
        default=0.5,
        type=float,
        help="Minimum difference in error term to include another breakpoint in the joint segmentation (default=0.5).",
    )
    parser.add_argument(
        "--min_diff_singleseg",
        default=1,
        type=float,
        help="Minimum difference in error term to include another breakpoint in the single-cell segmentation (default=1).",
    )

    parser.add_argument("info", metavar="INFO", help="Info file with NB parameters for each single cell")
    parser.add_argument("counts", metavar="COUNT", help="Gzipped, tab-separated table with counts")
    parser.add_argument("jointseg", metavar="JOINTSEG", help="Tab-separated table with joint segmentation of all cells")
    parser.add_argument(
        "singleseg", nargs="+", metavar="SINGLESEG", help="Tab-separated table with single cell segmentation (one file per cell)"
    )
    args = parser.parse_args()

    # log.basicConfig(
    #     level=log.DEBUG,
    #     format="%(asctime)s %(levelname)-8s %(message)s",
    #     datefmt="%a, %d %b %Y %H:%M:%S",
    #     filename=args.log,
    #     filemode="w",
    # )

    if args.cellnames is None:
        # use filenames in the absence of given single cell names
        cell_names = args.singleseg
    else:
        l = args.cellnames.split(",")
        assert len(l) == len(args.singleseg)
        cell_names = l

    print(args.counts, args.jointseg, args.singleseg, file=sys.stderr)

    nb_params = read_info_file(args.info)
    # print(nb_params['TALL2x2PE20420'])

    print("Reading count table from", args.counts, file=sys.stderr)
    count_table = CountTable(args.counts)
    print(" ... done.", file=sys.stderr)

    # START

    jointseg = Segmentation(args.jointseg)
    jointseg.select_k(min_diff=args.min_diff_jointseg)
    print("Selected breakpoint numbers for joint segmentation:", file=sys.stderr)
    for chromosome in sorted(jointseg.selected_k.keys()):
        print(chromosome, jointseg.selected_k[chromosome], file=sys.stderr)
    if args.output_jointseg is not None:
        jointseg.write_selected_to_file(args.output_jointseg)

    output_strand_states_file = None
    if args.output_strand_states is not None:
        output_strand_states_file = open(args.output_strand_states, "w")
        print("sample", "cell", "chrom", "start", "end", "class", sep="\t", file=output_strand_states_file)

    output_single_cell_seg_file = None
    if args.output_singleseg is not None:
        output_single_cell_seg_file = open(args.output_singleseg, "w")
        print("sample", "cell", "chrom", "position", sep="\t", file=output_single_cell_seg_file)

    for filename, cell in zip(args.singleseg, cell_names):
        print("=" * 100, filename, file=sys.stderr)
        print("Processing", filename, file=sys.stderr)
        singleseg = Segmentation(filename)
        singleseg.select_k(min_diff=args.min_diff_singleseg)

        for chromosome in singleseg.chromosomes:
            print(" -- chromosome", chromosome, file=sys.stderr)
            breaks = singleseg.get_selected_segmentation(chromosome)
            if output_single_cell_seg_file is not None:
                for position in breaks:
                    print(args.samplename, cell, chromosome, position, sep="\t", file=output_single_cell_seg_file)

            w_counts, c_counts = count_table.get_counts(cell, chromosome, breaks)
            w, c = 0, 0
            strand_state_list = []
            strand_state = (0, 0)
            # all potential SCEs
            all_sce_candidates = []
            # indices of SCEs that have been selected
            selected_sce_indices = set()
            # iterate through all breaks and gather a list of potential SCEs
            # based on whether the strand state left and right of the segment is the same
            # and on whether the breakpoints coincide with breakpoints in the joint
            # segmentation of all cells.
            for i, b in enumerate(breaks):
                nearest_joint_breakpoint = jointseg.closest_breakpoint(chromosome, b)
                if i < len(w_counts):
                    w = w_counts[i]
                    c = c_counts[i]
                new_strand_state = get_strand_state(w, c)
                # if strand state could not be called (e.g. due to absence of reads), then
                # we assume the strand state to have stayed the same
                if new_strand_state == (0, 0):
                    new_strand_state = strand_state
                if (i > 0) and (new_strand_state != strand_state):
                    w_state_old, c_state_old = strand_state
                    w_state_new, c_state_new = new_strand_state
                    all_sce_candidates.append((b, w_state_new - w_state_old, c_state_new - c_state_old))
                    if abs(b - nearest_joint_breakpoint) >= args.sce_min_distance:
                        selected_sce_indices.add(len(all_sce_candidates) - 1)
                strand_state = new_strand_state
                strand_state_list.append(strand_state)
                print(
                    "    breakpoint: {}, nearest breakpoint (jointseg): {} (distance={}), W={}, C={} (ratio:{}), state: {}".format(
                        b, nearest_joint_breakpoint, abs(b - nearest_joint_breakpoint), w, c, safe_div(w, w + c), strand_state
                    ),
                    file=sys.stderr,
                )
            print("    strand states", strand_state_list, file=sys.stderr)
            print("    All SCE candidates:", all_sce_candidates, file=sys.stderr)
            # Compile initial list of SCEs
            sce_list = [all_sce_candidates[i] for i in sorted(selected_sce_indices)]
            print("    SCE list:", sce_list, file=sys.stderr)
            sce_list_is_valid, ground_state, mismatch_distance = evaluate_sce_list(sce_list, strand_state_list, breaks)
            print("    SCE list valid:", sce_list_is_valid, file=sys.stderr)
            print("    best ground (leftmost) state:", ground_state, "mismatch distance:", mismatch_distance, file=sys.stderr)
            # Refine SCE list:
            #  - add one more breakpoint if it substantially improves the concordance
            #  - add one or more breakpoints if the set of SCEs is invalid
            added_sces = 0
            while (added_sces <= 1) or (not sce_list_is_valid):
                best_i = None
                best_new_sce_list = None
                best_new_list_is_valid = None
                best_new_mismatch_distance = None
                best_new_ground_state = None

                # try out the effect of adding each SCE (one by one)
                for i in range(len(all_sce_candidates)):
                    if i in selected_sce_indices:
                        continue
                    print("      condidering adding SCE:", all_sce_candidates[i], file=sys.stderr)
                    new_selected_sce_indices = copy.copy(selected_sce_indices)
                    new_selected_sce_indices.add(i)
                    new_sce_list = [all_sce_candidates[i] for i in sorted(new_selected_sce_indices)]
                    new_sce_list_is_valid, new_ground_state, new_mismatch_distance = evaluate_sce_list(
                        new_sce_list, strand_state_list, breaks
                    )
                    if (best_new_mismatch_distance is None) or (
                        (new_sce_list_is_valid, -new_mismatch_distance) > (best_new_list_is_valid, -best_new_mismatch_distance)
                    ):
                        best_i = i
                        best_new_list_is_valid = new_sce_list_is_valid
                        best_new_mismatch_distance = new_mismatch_distance
                        best_new_ground_state = new_ground_state
                        best_new_sce_list = new_sce_list

                # Quit if there were no more candidates to be added potentially
                if best_new_sce_list is None:
                    break

                # Determine whether to reject the best possible change we found and stop
                if (not sce_list_is_valid) or ((mismatch_distance - best_new_mismatch_distance) >= args.sce_add_cutoff):
                    selected_sce_indices.add(best_i)
                    sce_list = best_new_sce_list
                    sce_list_is_valid = best_new_list_is_valid
                    mismatch_distance = best_new_mismatch_distance
                    ground_state = best_new_ground_state
                    added_sces += 1
                    print(
                        "      accepting change to SCE list:",
                        sce_list,
                        "new distance:",
                        mismatch_distance,
                        "new ground state:",
                        ground_state,
                        file=sys.stderr,
                    )
                else:
                    break

            # The procedure above should find a valid SCE selection
            # (in the worst case, it can simply select all potential SCEs, which is valid).
            assert sce_list_is_valid

            if output_strand_states_file is not None:
                start = 0
                w_state, c_state = ground_state
                for sce_pos, w_state_diff, c_state_diff in sce_list:
                    end = sce_pos
                    strand_state_str = "W" * w_state + "C" * c_state
                    print(args.samplename, cell, chromosome, start, end, strand_state_str, sep="\t", file=output_strand_states_file)
                    w_state += w_state_diff
                    c_state += c_state_diff
                    start = sce_pos
                end = breaks[-1]
                strand_state_str = "W" * w_state + "C" * c_state
                print(args.samplename, cell, chromosome, start, end, strand_state_str, sep="\t", file=output_strand_states_file)

    if output_strand_states_file is not None:
        output_strand_states_file.close()

    if output_single_cell_seg_file is not None:
        output_single_cell_seg_file.close()


if __name__ == "__main__":
    main()
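get_strand_state() classifies a segment purely from the Watson fraction w / (w + c): below 0.2 is CC, above 0.8 is WW, anything in between is WC, and segments without reads return (0, 0) so that the previous state is carried over. A standalone sketch of those thresholds follows (re-implemented here for illustration only, not imported from the script).

# Standalone re-implementation of the thresholds used by get_strand_state(), for illustration.
def strand_state(w, c):
    if w + c == 0:
        return (0, 0)   # no reads: state unknown, caller keeps the previous state
    r = w / (w + c)
    if r < 0.2:
        return (0, 2)   # mostly Crick  -> CC
    if r > 0.8:
        return (2, 0)   # mostly Watson -> WW
    return (1, 1)       # balanced      -> WC

print(strand_state(38, 2))   # (2, 0): WW
print(strand_state(20, 21))  # (1, 1): WC
print(strand_state(3, 40))   # (0, 2): CC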
import subprocess
p = []
try:
    f = snakemake.config["ground_truth_clonal"][snakemake.wildcards.sample]
    if len(f) > 0:
        p.append('--true-events-clonal')
        p.append(f)
except KeyError:
    pass
try:
    f = snakemake.config["ground_truth_single_cell"][snakemake.wildcards.sample]
    if len(f) > 0:
        p.append('--true-events-single-cell')
        p.append(f)
except KeyError:
    pass
if snakemake.wildcards.filter == 'TRUE':
    p.append('--merged-file')
    p.append(snakemake.input.merged)
additional_params = ' '.join(p)
subprocess.call('workflow/scripts/stats/callset_summary_stats.py --segmentation {} --strandstates {} --complex-regions {} {} {}  > {} '.format(
    snakemake.input.segmentation,
    snakemake.input.strandstates,
    snakemake.input.complex,
    additional_params,
    snakemake.input.sv_calls,
    snakemake.output.tsv
    ), shell=True
)
import pandas as pd
import os


def write_to_html_file(df, title):
    inp = """
        <html>
        <head>
        <style>
            h2 {
                text-align: center;
                font-family: Helvetica, Arial, sans-serif;
            }
            table { 
                margin-left: auto;
                margin-right: auto;
            }
            table, th, td {
                border: 1px solid black;
                border-collapse: collapse;
            }
            th, td {
                padding: 5px;
                text-align: center;
                font-family: Helvetica, Arial, sans-serif;
                font-size: 90%;
            }
            table tbody tr:hover {
                background-color: #dddddd;
            }
            .wide {
                width: 90%; 
            }
        </style>
        </head>
        <body>
        """
    out = """
            </body>
            </html>
            """
    result = inp
    result += "<h2> {} statistics summary </h2>\n".format(str(title))
    result += df.to_html(classes="wide", escape=False, index=False)
    result += out
    return result


df = pd.read_csv(snakemake.input[0], sep="\t")
df["callset"] = df["callset"].apply(lambda r: os.path.basename(r).replace(".tsv", ""))
df = df.set_index("callset")
df = df.fillna(0).T.reset_index()
pd.options.display.float_format = "{:,.1f}".format
df_out = write_to_html_file(df, snakemake.wildcards.sample)
with open(snakemake.output.html, "w") as o:
    o.write(df_out)
import pandas as pd

l_df = list()
print(snakemake.input.files)
for j, file in enumerate(snakemake.input.files):
    print(j, file)
    tmp_df = pd.read_csv(file, sep="\t")
    print(tmp_df)
    l_df.append(tmp_df)
df = pd.concat(l_df)
df.to_csv(snakemake.output[0], sep="\t", index=False)
sink(snakemake@log[[1]])
library(data.table)

d <- fread(snakemake@input[["states"]])

e <- fread(snakemake@input[["info"]])
e$bam <- basename(e$bam)
f <- merge(d, e, by = c("sample", "cell"))[class == "WC", .(chrom, start, end, bam)]

write.table(f, file = snakemake@output[[1]], quote = F, row.names = F, col.names = F, sep = "\t")
sink(snakemake@log[[1]])
library(data.table)
library(assertthat)
e <- fread(snakemake@input[["phased_states"]])
e
d <- fread(snakemake@input[["info"]])
d
g <- fread(snakemake@input[["initial_states"]])
g

d$bam <- basename(d$bam)
e$bam <- e$cell
e$cell <- NULL
e$sample <- NULL
f <- merge(d, e, by = "bam")[, .(chrom, start, end, sample, cell, class)]
f

# Note that there is still a bug in Venla's strand state detection.
g <- merge(g, f, by = c("chrom", "start", "end", "sample", "cell"), all.x = T)
g


# Overwrite with David's phased strand state if available!
g <- g[, class := ifelse(!is.na(class.y), class.y, class.x)][]
g$class.x <- NULL
g$class.y <- NULL
g <- g[, .(chrom, start, end, sample, cell, class)]
g

write.table(g, file = snakemake@output[[1]], quote = F, row.names = F, col.names = T, sep = "\t")
with open(snakemake.output[0], "w") as f:
    print("[General]", file=f)
    print("numCPU           = 1", file=f)
    print("chromosomes      = '" + snakemake.wildcards.chrom + "'", file=f)
    print("pairedEndReads   = '" + [e.strip() for e in open(snakemake.input.single_paired_end_detect, "r").readlines()][0] + "'", file=f)
    print("min.mapq         = 10", file=f)
    print("", file=f)
    print("[StrandPhaseR]", file=f)
    print("positions        = NULL", file=f)
    print("WCregions        = NULL", file=f)
    print("min.baseq        = 20", file=f)
    print("num.iterations   = 2", file=f)
    print("translateBases   = TRUE", file=f)
    print("fillMissAllele   = NULL", file=f)
    print("splitPhasedReads = TRUE", file=f)
    print("compareSingleCells = FALSE", file=f)
    print("callBreaks       = FALSE", file=f)
    print("exportVCF        = '", snakemake.wildcards.sample, "'", sep="", file=f)
    print("bsGenome         = '", snakemake.config["references_data"][snakemake.config["reference"]]["R_reference"], "'", sep="", file=f)
    # print("bsGenome         = '", snakemake.config["R_reference"], "'", sep="", file=f)
options(error = traceback)
args <- commandArgs(TRUE)

# add user defined path to load needed libraries
.libPaths(c(.libPaths(), args[6]))

suppressPackageStartupMessages(library(StrandPhaseR))

# FIXME : tmp debuging local repo
# library(devtools)

# load package w/o installing
# load_all("/g/korbel2/weber/Gits/StrandPhaseR")
# strandphaser_path <- "/g/korbel2/weber/Gits/Strandphaser_clean/StrandPhaseR"
# print(strandphaser_path)
# load_all(strandphaser_path)

strandPhaseR(inputfolder = args[1], outputfolder = args[2], configfile = args[3], WCregions = args[4], positions = args[5], fillMissAllele = args[5])
import os, sys
import subprocess
from tqdm import tqdm
import parmap
import multiprocessing as mp
import numpy as np


# snakemake_log = open(snakemake.log[0], "w")

l_files_selected = snakemake.input.bam


# Initiate MP
m = mp.Manager()
l_df = m.list()


def loop(file, l_df):
    """MP function

    Args:
        file (str): bam file path
        l_df (list): MP shared list
    """
    p = subprocess.Popen("samtools view -c -f 1 {file}".format(file=file), shell=True, stdout=subprocess.PIPE)
    p_out = int(p.communicate()[0].decode("utf-8").strip())

    # Add to shared MP list
    l_df.append(p_out)


# Launch function in parallel on list of files
parmap.starmap(loop, list(zip(l_files_selected)), l_df, pm_pbar=True, pm_processes=10)

l_df = list(l_df)

paired_end = True
if np.mean(l_df) == 0:
    # samtools reported zero paired-flagged reads for every BAM -> single-end data
    paired_end = False
elif 0 in l_df:
    # some BAMs have paired-flagged reads and some have none -> inconsistent input
    sys.exit("Mix of single-end and paired-end files")

# snakemake_log.write("Paired-end: {paired_end} for sample: {sample}".format(paired_end=paired_end, sample=snakemake.wildcards.sample))

with open(snakemake.output.single_paired_end_detect, "w") as output:
    output.write(str(paired_end).upper())
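samtools view -c -f 1 counts alignments that carry the "read paired" SAM flag (0x1), so single-end BAMs yield 0. A standalone sketch of the resulting decision rule follows (re-implemented for illustration):

# Standalone sketch of the paired-end decision rule used above.
import numpy as np

def detect_paired_end(paired_read_counts):
    counts = np.asarray(paired_read_counts)
    if counts.mean() == 0:
        return False                                        # all BAMs single-end
    if (counts == 0).any():
        raise ValueError("Mix of single-end and paired-end files")
    return True

print(detect_paired_end([0, 0, 0]))     # False -> single-end
print(detect_paired_end([1200, 980]))   # True  -> paired-end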
import subprocess
import pandas as pd
import sys, os

# snakemake_log = open(snakemake.log[0], "w")

# Prepare header of info files
subprocess.call("grep '^#' {} > {}".format(snakemake.input.info_raw, snakemake.output.info), shell=True)
subprocess.call("grep '^#' {} > {}".format(snakemake.input.info_raw, snakemake.output.info_removed), shell=True)

# Read mosaic count info
df = pd.read_csv(snakemake.input.info_raw, skiprows=13, sep="\t")
df["pass1"] = df["pass1"].astype(int)

labels_path = snakemake.input.labels
labels = pd.read_csv(labels_path, sep="\t")
labels["cell"] = labels["cell"].str.replace(".sort.mdup.bam", "")
df["cell"] = df["cell"].str.replace(".sort.mdup.bam", "")

# print(df)
# print(labels)


# if snakemake.config["use_light_data"] is True and snakemake.wildcards.sample == "RPE-BM510":
#     df = pd.concat([df, pd.DataFrame([{"sample": "RPE-BM510", "cell": "BM510x04_PE20320.sort.mdup.bam", "pass1": 0}])])

# snakemake_log.write(labels.to_str())

# b_ashleys = "ENABLED" if snakemake.config["ashleys_pipeline"] is True else "DISABLED"
# b_old = "ENABLED" if snakemake.config["input_bam_legacy"] is True else "DISABLED"

# snakemake_log.write("ASHLEYS preprocessing module: {}".format(b_ashleys))
# snakemake_log.write("input_bam_legacy parameter: {}".format(b_old))
# snakemake_log.write("Computing intersection between lists ...")

# IF BOTH MOSAIC INFO FILE & LABELS DF ARE AVAILABLE + SAME SIZE
if labels.shape[0] == df.shape[0]:
    if len(set(labels.cell.values.tolist()).intersection(set(df.cell.values.tolist()))) == labels.shape[0]:
        print("labels.shape[0] == df.shape[0]")
        cells_to_keep_labels = labels.loc[labels["prediction"] == 1]["cell"].str.replace(".sort.mdup.bam", "").sort_values().tolist()
        cells_to_keep_mosaic = df.loc[df["pass1"] == 1]["cell"].unique().tolist()
        cells_to_keep = list(sorted(list(set(cells_to_keep_labels).intersection(cells_to_keep_mosaic))))
    else:
        sys.exit("Ashleys labels & Mosaicatcher count info file do not share the same cell naming format")


else:
    # CATCH ERROR IF DIFFERENT SIZES AND CONFIG ENABLED
    if (snakemake.config["ashleys_pipeline"] is True) or (snakemake.config["input_bam_legacy"] is True):
        sys.exit("Dataframes do not have the same dimensions:")
        sys.exit("mosaic info: {} ; labels: {}".format(str(df.shape[0]), str(labels.shape[0])))

    # ELSE NORMAL MODE
    else:
        print("df.shape[0] only")
        # snakemake_log.write("Standard mode using only 'mosaic count info' file")
        cells_to_keep = df.loc[df["pass1"] == 1]["cell"].unique().tolist()


# cells_to_keep = labels.loc[labels["prediction"] == 1]["cell"].str.replace(".sort.mdup.bam", "").tolist()
df_kept = df.loc[df["cell"].isin(cells_to_keep)]
df_removed = df.loc[~df["cell"].isin(cells_to_keep)]

# snakemake_log.write("List of cells kept: ")
# for cell in sorted(cells_to_keep):
# snakemake_log.write("- {cell}".format(cell=cell))

# snakemake_log.write("List of cells removed:")
# for cell in sorted(df_removed["cell"].values.tolist()):
# snakemake_log.write("- {cell}".format(cell=cell))


df_counts = pd.read_csv(snakemake.input.counts_sort, compression="gzip", sep="\t")
df_counts = df_counts.loc[df_counts["cell"].isin(cells_to_keep)]
df_counts.to_csv(snakemake.output.counts, compression="gzip", sep="\t", index=False)

df_kept.to_csv(snakemake.output.info, index=False, sep="\t", mode="a")
df_removed.to_csv(snakemake.output.info_removed, index=False, sep="\t", mode="a")
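
When both the ashleys labels and the mosaic count info cover the same cells, a cell is kept only if it passes both filters (prediction == 1 and pass1 == 1). A toy example of that intersection, with made-up cell names and values, is sketched below.

# Toy illustration of the kept-cell intersection (made-up cell names and values).
import pandas as pd

labels = pd.DataFrame({"cell": ["cell_01", "cell_02", "cell_03"], "prediction": [1, 0, 1]})
info = pd.DataFrame({"cell": ["cell_01", "cell_02", "cell_03"], "pass1": [1, 1, 0]})

keep_labels = set(labels.loc[labels["prediction"] == 1, "cell"])
keep_mosaic = set(info.loc[info["pass1"] == 1, "cell"])
cells_to_keep = sorted(keep_labels & keep_mosaic)

print(cells_to_keep)  # ['cell_01'] -- only cells passing both filters survive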
import pandas as pd
import pysam

import parmap
import multiprocessing as mp
import logging

# Set up logging
logging.basicConfig(filename=snakemake.log[0], level=logging.INFO, format="%(asctime)s %(levelname)s:%(message)s")
# logging.basicConfig(filename="/Users/frank/mosaicatcher-pipeline/debug.log", level=logging.INFO, format="%(asctime)s %(levelname)s:%(message)s")


m = mp.Manager()
l_df = m.list()
# l_df = list()


def filter_chrom(bam, l):
    # READ BAM FILE HEADER OF FIRST BAM IN THE PANDAS DF
    h = pysam.view("-H", bam)
    # h = pysam.view("-H", os.listdir(snakemake.input.bam + "selected")[0])
    h = [e.split("\t") for e in h.split("\n") if "@SQ" in e]

    l.extend(h)


parmap.starmap(filter_chrom, list(zip(list(snakemake.input.bam))), l_df, pm_pbar=False, pm_processes=10)
# for file in list(snakemake.input.bam):
#     filter_chrom(file, l_df)


# CONVERT TO PANDAS DF
df_h = pd.DataFrame((list(l_df)), columns=["TAG", "Contig", "LN"])

logging.info(f"Processed raw DataFrame: \n {df_h.to_string()}")


# PROCESS CONTIGS
output_h = pd.DataFrame(df_h["Contig"].str.replace("SN:", ""))


logging.info(f'List of chromosomes provided in the configuration: \n {snakemake.params["chroms"]}')


output_h = output_h.loc[~output_h["Contig"].isin(snakemake.params["chroms"])]


# Log the content of the output_h DataFrame
logging.info(f"Processed list of chromosomes to be removed: \n {output_h.to_string()}")


# EXPORT
output_h = output_h["Contig"].drop_duplicates()

logging.info(f"Processed list of chromosomes to be removed without duplicates: \n {output_h.to_string()}")


output_h.to_csv(snakemake.output[0], index=False, sep="\t", header=False)
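
The script collects every @SQ header line across the input BAMs and writes out the contigs that are not in the configured chromosome list, so downstream rules can exclude them. The core filtering step can be illustrated on toy header data, as sketched below (made-up header lines and chromosome list).

# Toy illustration of the contig-exclusion logic (made-up header lines and chromosome list).
import pandas as pd

sq_lines = [
    ["@SQ", "SN:chr1", "LN:248956422"],
    ["@SQ", "SN:chr2", "LN:242193529"],
    ["@SQ", "SN:chrUn_KI270302v1", "LN:2274"],
]
chroms_config = ["chr1", "chr2"]  # stand-in for snakemake.params["chroms"]

df_h = pd.DataFrame(sq_lines, columns=["TAG", "Contig", "LN"])
contigs = df_h["Contig"].str.replace("SN:", "", regex=False)
to_remove = contigs[~contigs.isin(chroms_config)].drop_duplicates()

print(to_remove.tolist())  # ['chrUn_KI270302v1']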
import sys, os
import pandas as pd

folder = snakemake.wildcards.folder
sample = snakemake.wildcards.sample

# folder = "/g/korbel2/weber/MosaiCatcher_files/HGSVC_WH"
# sample = "TEST"
ext = ".sort.mdup.bam"

# ASSERTIONS TO CHECK IF FOLDERS EXIST OR NOT
assert os.path.isdir("{folder}/{sample}/bam/".format(folder=folder, sample=sample)), "Folder all for sample {sample} does not exist".format(
    sample=sample
)
assert os.path.isdir(
    "{folder}/{sample}/selected/".format(folder=folder, sample=sample)
), "Folder selected for sample {sample}  does not exist".format(sample=sample)

# RETRIEVE LIST OF FILES
l_files_all = [f for f in os.listdir("{folder}/{sample}/bam/".format(folder=folder, sample=sample)) if f.endswith(ext)]
l_files_selected = [f for f in os.listdir("{folder}/{sample}/selected/".format(folder=folder, sample=sample)) if f.endswith(ext)]

# CHECK IF FILE EXTENSION IS CORRECT
if (
    len(l_files_all) == 0
    and len([f for f in os.listdir("{folder}/{sample}/bam/".format(folder=folder, sample=sample)) if f.endswith(".bam")]) > 0
):
    sys.exit("BAM files extension were correctly set: .bam instead of .sort.mdup.bam (prevent further issues)")


# pd.options.display.max_rows = 100

# CREATE PANDAS DATAFRAME
df = pd.DataFrame([l_files_all, [1] * len(l_files_all), [1] * len(l_files_all)]).T
df.columns = ["cell", "probability", "prediction"]

# FLAG UNSELECTED FILES
unselected = set(l_files_all).difference(set(l_files_selected))
df.loc[df["cell"].isin(unselected), "probability"] = 0
df.loc[df["cell"].isin(unselected), "prediction"] = 0

# OUTPUT
df.sort_values(by="cell").to_csv(snakemake.output[0], sep="\t", index=False)
package <- snakemake@params[["selected_package"]]
# package <- "workflow/data/ref_genomes/BSgenome.T2T.CHM13.V2_1.0.0.tar.gz"
# print(grepl("BSgenome.T2T.CHM13.V2_1.0.0.tar.gz", package, fixed = TRUE, perl = FALSE))

is_package_available <- require(package, character.only = TRUE)

if (!isTRUE(is_package_available)) {
    if (!require("BiocManager", quietly = TRUE)) {
        install.packages("BiocManager", repos = "http://cran.us.r-project.org")
    }
    if (grepl("BSgenome.T2T.CHM13.V2_1.0.0.tar.gz", package, fixed = TRUE, perl = FALSE)) {
        print("T2T")
        BiocManager::install("GenomeInfoDbData", update = FALSE)
        install.packages(package, repos = NULL, type = "source")
    } else {
        BiocManager::install(package, update = FALSE)
    }
    quit(save = "no")
}
import pandas as pd

# Read 200kb bins file

binbed = pd.read_csv(
    snakemake.input.bin_bed,
    # "../../../../mosaicatcher-update/workflow/data/bin_200kb_all.bed",
    sep="\t",
    names=["chrom", "start", "end", "bin_id"],
)
binbed["ID"] = binbed["chrom"] + "_" + binbed["start"].astype(str) + "_" + binbed["end"].astype(str)

# Turn chrom into categorical
binbed["chrom"] = pd.Categorical(
    binbed["chrom"],
    categories=["chr{}".format(e) for e in range(1, 23)] + ["chrX", "chrY"],
    ordered=True,
)

# Sort & filter out chrY #TMP / can be changed
binbed = binbed.sort_values(by=["chrom", "start", "end"]).reset_index(drop=True)
binbed["w"], binbed["c"], binbed["class"] = 0, 0, None


# Read SV file
# df = pd.read_csv("../../../../mosaicatcher-update/.tests/data_CHR17/RPE-BM510/counts/RPE-BM510.txt.raw.gz", sep="\t")

# sep = "," if "/multistep_normalisation/" in snakemake.input.counts else "\t"
sep = "\t"
df = pd.read_csv(snakemake.input.counts, sep=sep, compression="gzip")
df["ID"] = df["chrom"] + "_" + df["start"].astype(str) + "_" + df["end"].astype(str)
df["w"] = df["w"].round(0).astype(int)
df["c"] = df["c"].round(0).astype(int)
if sep == ",":
    df["tot_count"] = df["tot_count"].round(0).astype(int)

## Populate counts df for each cell in order to have all bins represented
l = list()

# Loop over cells
for cell in df.cell.unique().tolist():

    # Outer join to retrieve both real count values from specified chromosome and empty bins
    tmp_df = pd.concat(
        [
            binbed.loc[~binbed["ID"].isin(df.loc[df["cell"] == cell].ID.values.tolist())],
            df.loc[df["cell"] == cell],
        ]
    )

    # Fill cell & sample columns
    tmp_df["cell"] = cell
    tmp_df["sample"] = df.loc[df["cell"] == cell, "sample"].values.tolist()[0]
    l.append(tmp_df)

# Concat list of DF and output
populated_df = pd.concat(l).sort_values(by=["cell", "chrom", "start"])
# populated_df.to_csv("test.txt.gz", compression="gzip", sep="\t", index=False)
populated_df.to_csv(snakemake.output.populated_counts, compression="gzip", sep="\t", index=False)
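
For every cell, the script concatenates the bins that already carry counts with the empty bins from the genome-wide 200kb bin file, so that each cell ends up with one row per bin. A toy version of this per-cell population step is sketched below (made-up bins and counts).

# Toy illustration of populating missing bins for a single cell (made-up data).
import pandas as pd

binbed = pd.DataFrame({"ID": ["chr1_0_200000", "chr1_200000_400000", "chr1_400000_600000"],
                       "w": 0, "c": 0, "class": None})
counts = pd.DataFrame({"ID": ["chr1_0_200000"], "w": [12], "c": [7], "class": ["WC"]})

# Keep observed bins as-is and append the bins with no counts for this cell
populated = pd.concat([counts, binbed.loc[~binbed["ID"].isin(counts["ID"])]])

print(populated)  # 3 rows: one observed bin plus two zero-count bins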
import pandas as pd

df = pd.read_csv(snakemake.input[0], sep=",", compression="gzip")
df = df[["chrom", "start", "end", "sample", "cell", "c", "w", "class"]]
df["w"] = df["w"].fillna(0)
df["c"] = df["c"].fillna(0)
df["class"] = df["class"].fillna("None")
df["class"] = df["class"].astype(str)
df.to_csv(snakemake.output[0], sep="\t", index=False, compression="gzip")
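
This small script converts the comma-separated normalised counts back into the tab-separated layout expected downstream, filling missing Watson/Crick counts with 0 and missing strand-state classes with the string "None". A toy view of that fill step is sketched below (made-up values).

# Toy illustration of the fill-missing step (made-up values).
import pandas as pd

df = pd.DataFrame({"w": [3, None], "c": [None, 5], "class": ["WC", None]})
df["w"] = df["w"].fillna(0)
df["c"] = df["c"].fillna(0)
df["class"] = df["class"].fillna("None").astype(str)

print(df)  # NaN counts become 0 and missing classes become the string "None"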
import pandas as pd
import yaml

labels = snakemake.input.labels
info_raw = snakemake.input.info_raw
ploidy_summary = snakemake.input.ploidy_summary
single_paired_end_detect = snakemake.input.single_paired_end_detect

single_paired_end_detect_content = open(single_paired_end_detect, "r").readlines()[0]

df_labels = pd.read_csv(labels, sep="\t")
df_labels = df_labels[["cell", "prediction"]]
df_labels = df_labels.rename({"prediction": "hand_labels"}, axis=1)
df_labels["cell"] = df_labels["cell"].str.replace(".sort.mdup.bam", "")

df_info = pd.read_csv(info_raw, skiprows=13, sep="\t")[["cell", "pass1"]]
df_info = df_info.rename({"pass1": "mosaic_cov_pass"}, axis=1)

if df_labels.shape[0] > 0:
    final_df = pd.merge(df_labels, df_info, on="cell")
    final_df.loc[(final_df["hand_labels"] == 1) & (final_df["mosaic_cov_pass"] == 1), "Final_keep"] = 1
    final_df["Final_keep"] = final_df["Final_keep"].fillna(0)

else:
    final_df = df_info
    final_df["Final_keep"] = final_df["mosaic_cov_pass"]
final_df = final_df.rename({"hand_labels": "Ashleys/hand labels"}, axis=1).sort_values(by="cell", ascending=True)


df_ploidy = pd.read_csv(ploidy_summary, sep="\t")[["#chrom", "50%"]]
df_ploidy = df_ploidy.loc[df_ploidy["#chrom"] != "genome"]
chroms = ["chr" + str(c) for c in list(range(1, 23))] + ["chrX", "chrY"]
df_ploidy["#chrom"] = pd.Categorical(df_ploidy["#chrom"], categories=chroms, ordered=True)
df_ploidy = df_ploidy.sort_values(by=["#chrom"]).rename({"#chrom": "chrom", "50%": "ploidy_estimation"}, axis=1)
df_ploidy.loc[df_ploidy["ploidy_estimation"] == 1, "StrandPhaseR_processed"] = 0
df_ploidy["StrandPhaseR_processed"] = df_ploidy["StrandPhaseR_processed"].fillna(1)

with open(snakemake.output.summary, "w") as o:
    o.write("\n==============Library quality summary==============\n")
    o.write("\n")
    o.write(final_df.to_markdown(tablefmt="github", index=False))
    o.write("\n")
    o.write("\n==============Ploidy summary==============\n")
    o.write("\n")
    o.write(df_ploidy.to_markdown(tablefmt="github", index=False))
    o.write("\n")
    o.write("\n==============YAML configuration used==============\n")
    o.write("\n")
    o.write(yaml.dump(snakemake.config))
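
The run summary is assembled as plain-text sections: the per-cell and per-chromosome tables are rendered with pandas' to_markdown (GitHub flavour, which relies on the tabulate package), and the full Snakemake configuration is dumped as YAML at the end. A toy rendering of one such table is sketched below (made-up values).

# Toy illustration of the GitHub-flavoured table rendering used in the summary
# (requires the `tabulate` package, which pandas' to_markdown relies on).
import pandas as pd

final_df = pd.DataFrame({"cell": ["cell_01", "cell_02"],
                         "Ashleys/hand labels": [1, 0],
                         "mosaic_cov_pass": [1, 1],
                         "Final_keep": [1, 0]})

print(final_df.to_markdown(tablefmt="github", index=False))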
import pandas as pd

df = pd.read_csv(snakemake.input[0], sep="\t", compression="gzip")
df["start"] = df["start"].astype(int)
df["end"] = df["end"].astype(int)
# chroms = ["chr{}".format(str(c)) for c in list(range(1, 23))] + ["chrX", "chrY"]
chroms = snakemake.config["chromosomes"]
print(chroms)
df["chrom"] = pd.Categorical(df["chrom"], categories=chroms, ordered=True)
df = df.loc[(df["chrom"].str.contains("chr") == True) & (df["chrom"].isna() == False) & (~df["chrom"].isin(["NaN", "nan", "NA", None, ""]))]
df.sort_values(by=["cell", "chrom", "start", "end"]).to_csv(snakemake.output[0], compression="gzip", sep="\t", index=False)
import subprocess

# For real data, symlink the BAM/BAI into place; for the light test data, copy them.
# subprocess.run(..., check=True) waits for the command and raises on failure,
# so the rule cannot finish before the link/copy is complete.
if snakemake.config["use_light_data"] is False:
    subprocess.run(
        "ln -s {input_bam} {output_bam}".format(input_bam=snakemake.input.bam, output_bam=snakemake.output.bam),
        shell=True,
        check=True,
    )
    subprocess.run(
        "ln -s {input_bai} {output_bai}".format(input_bai=snakemake.input.bai, output_bai=snakemake.output.bai),
        shell=True,
        check=True,
    )
else:
    subprocess.run(
        "cp {input_bam} {output_bam}".format(input_bam=snakemake.input.bam, output_bam=snakemake.output.bam),
        shell=True,
        check=True,
    )
    subprocess.run(
        "cp {input_bai} {output_bai}".format(input_bai=snakemake.input.bai, output_bai=snakemake.output.bai),
        shell=True,
        check=True,
    )



Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/friendsofstrandseq/mosaicatcher-pipeline
Name: mosaicatcher-pipeline
Version: 2.2.1
Copyright: Public Domain
License: MIT License
