Multiplex Accurate Sensitive Quantitation (MASQ) analysis and primer design


Multiplex Accurate Sensitive Quantitation (MASQ) Analysis and Primer Design Pipelines

Set-up

# Clone workflow into working directory
git clone https://github.com/amoffitt/MASQ /path/to/workdir
# Set-up virtual environment using conda
# Installs required Python and R packages
conda env create -f environment.yaml
conda activate MASQ
# Set-up reference genome files for examples
bash ./prepare_example_references.sh
# Edit configuration as needed for MASQ analysis
vi config.yaml
# Execute MASQ analysis workflow (in dry-run mode)
snakemake -n

MASQ Analysis

The MASQ analysis pipeline is implemented as a Snakemake workflow (https://snakemake.readthedocs.io/en/stable/). The workflow is defined by the rules in the Snakefile; the individual scripts called from the Snakefile are located in the scripts folder.
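
Scripts invoked from a rule's script: directive receive a global snakemake object carrying that rule's declared inputs, outputs, parameters, and log file; the code snippets later on this page read everything through it. A minimal sketch of the access pattern (the attribute names are defined per rule in the Snakefile; the ones shown mirror the first snippet below):

# Scripts run via a Snakefile rule's "script:" directive get a global
# `snakemake` object; attribute names mirror the rule definition.
vt_counter_filename = snakemake.input.vt_counter  # a declared input file
region = snakemake.params.region                  # a rule parameter
log_path = snakemake.log                          # the rule's log file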

Example input files are included for testing the workflow installation. These include small snippets of FASTQ and BAM files and a corresponding example locus table. The config.yaml file included here is set up to run the example files. To execute the workflow on the example, run the snakemake command from the working directory.

To run on a cluster, a cluster.yaml file can be added to specify parameters specific to your cluster setup, as described in the Snakemake documentation (https://snakemake.readthedocs.io/en/stable/snakefiles/configuration.html).

To run on data other than the examples, edit the config.yaml file as necessary to point to your FASTQ files, the locus table, WGS BAM files, matching reference files, and other sample and run parameters.

MASQ Primer Design

The MASQ primer design pipeline is run with the following command:

cd primer_design
python select_enzymes_for_snps.py config.primerdesign.example.yaml 2>&1 | tee log.primerdesign.txt

An example configuration file and example input SNV lists are included in the primer_design folder. Enzyme cut site files are included for hg19. Edit the configuration file to point to different SNV files or to change the primer design parameters.

BLAT and Primer3 are installed via the provided conda environment file, but they can also be installed separately and placed on the PATH.
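
Both tools must be discoverable at runtime. A quick check (a minimal sketch; 'blat' and 'primer3_core' are assumed executable names here, and they can vary by installation):

import shutil

# Report whether each external tool is discoverable on the PATH.
# 'blat' and 'primer3_core' are assumed names; adjust for your install.
for tool in ("blat", "primer3_core"):
    path = shutil.which(tool)
    print("%s: %s" % (tool, path if path else "NOT found on PATH"))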

Code Snippets

Snippet 1 (source lines 7-919 of the all_base_report script):

import os
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import logging

from masq_helper_functions import tabprint, hamming
from masq_helper_functions import hamming_withNs
from masq_helper_functions import test_pair
from masq_helper_functions import test_pair_withNs
from masq_helper_functions import reverseComplement
from masq_helper_functions import setup_logger, load_snv_table
########################################################################
# Start timer
t0 = time.time()
# Setup log file
log = setup_logger(snakemake.log,'all_base_report')
log.info('Starting process')

########################################################################
# INPUT FILES AND PARAMETERS
log.info('Getting input files and parameters from snakemake object')

# Read data stored in counters
vt_counter_filename = snakemake.input.vt_counter
vt_seq_counter_filename = snakemake.input.vt_seq_counter

# Input SNV table
SNV_table = open(snakemake.params.SNV_table,'r')

# WHICH REGION ARE WE PROCESSING
REGION = snakemake.params.region

# Sequence and hamming parameters
MAX_HAMMING_TARGET = snakemake.params.target_hamming
# Possible minimum coverage cutoffs
# need to see this many reads at that position to call a base for each tag
MIN_COVERS = snakemake.params.coverage_list
# Error rate allowed within reads of the same tag
MAX_ERROR = snakemake.params.base_error_rate
CUTOFF = (1. - MAX_ERROR)
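# e.g. base_error_rate = 0.2 gives CUTOFF = 0.8: the consensus base for a tag
# is called only when it accounts for at least 80% of that tag's reads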

# Masking low quality bases
MASK_LOWQUAL = snakemake.params.mask_lowqual_bases
MAX_N_RATIO = snakemake.params.max_N_ratio
FILTER_NS = snakemake.params.filter_ns


########################################################################
EXACT_COVERS = np.arange(1,51)

########################################################################
POSSIBLE_PROPORTIONS = []
PROPORTIONS_BY_COV = dict()
for coverage in MIN_COVERS:
    a=np.linspace(0,1,coverage+1)
    POSSIBLE_PROPORTIONS.extend(a)
    PROPORTIONS_BY_COV[coverage]=a
POSSIBLE_PROPORTIONS = np.unique(POSSIBLE_PROPORTIONS)
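# e.g. a coverage cutoff of 2 contributes proportions [0, 0.5, 1]; the union
# over all cutoffs in MIN_COVERS is deduplicated with np.unique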

########################################################################
# OUTPUT FILES
log.info('Getting output files from snakemake object')

# All base report for specific region
# Original report style
base_count_per_pos_file = open(snakemake.output.base_count_report, 'w')
# One base per line (with >=X reads per tag, and =X reads per tag)
base_count_per_base_file = open(snakemake.output.base_count_report_per_base, 'w')

# Alignment counters report for region
alignment_rate_file = open(snakemake.output.alignment_report, 'w')

# Unaligned reads
unaligned_reads_file = open(snakemake.output.unaligned_reads, 'w')

# Within tag error
withintagerrors_file = open(snakemake.output.withintagerrors_table, 'w')


########################################################################
## load the genome sequence from a pickle file
log.info('Loading reference genome pickle')
seq_pickle = snakemake.params.ref_genome
seq_dic = pickle.load(open(seq_pickle, 'rb'))
log.info('Done loading reference genome pickle')

########################################################################
# Nucleotide <-> Integer mapping
NUCS = ["A", "C", "G", "T"]
BASES = ["A", "C", "G", "T", "N"]
BASE2INT = dict(x[::-1] for x in enumerate(BASES))
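# enumerate yields (index, base) pairs; x[::-1] flips each to (base, index),
# giving BASE2INT = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}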

# FOR WITHIN TAG ERROR
REF_TRINUCS = [x+y+z for x in NUCS for y in NUCS for z in NUCS]
# All pairs of REFS and ALT trinucs
# PAIRED_TRINUCS=[ (a , a[0]+x+a[2], b) for b in ('R1','R2') for a in REF_TRINUCS for x in NUCS if x!=a[1] ]
PAIRED_TRINUCS=[ (a , a[0]+x+a[2]) for a in REF_TRINUCS for x in NUCS if x!=a[1] ]
# Trinuc to Index
TRINUC2INT = dict(x[::-1] for x in enumerate(PAIRED_TRINUCS)) # 192
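# 64 reference trinucleotides x 3 alternate middle bases = 192 (ref, alt) pairs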
# Array of error counts
ERR_RANGE1 = np.arange(-0.1,1.0,0.1)
ERR_RANGE2 = np.arange(0,1.01,0.1)
s1 = len(TRINUC2INT)
s2 = len(ERR_RANGE1)
s3 = len(('R1','R2'))
READS_PER_TAG_CUTOFF = 10
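# Error fractions are later binned by ceil(fraction*10) into the 11 bins whose
# lower/upper edges are ERR_RANGE1/ERR_RANGE2; within-tag errors are tallied
# only for tags with more than READS_PER_TAG_CUTOFF reads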

########################################################################
# Load sequence data
log.info('Loading read data from pickles')
vt_counter = pickle.load(open(vt_counter_filename, 'rb')) # only one region
vt_seq_counter = pickle.load(open(vt_seq_counter_filename, 'rb'))
log.info('Done loading read data from pickles')

########################################################################
# Write headers to output files
log.info('Writing headers to file')
# Alignment counters report for region
header=["Region Index","Ref Sequence Index","Number of Tags", "Number of Aligned Tags", "Fraction Aligned Tags","Number of Processed Reads", "Number of Aligned Reads", "Fraction Aligned Reads"]
alignment_rate_file.write(tabprint(header) + "\n")

# Original report style
position_header = ['target_base', 'locus_index', 'reference_index', 'target_strand', 'read_index', 'read_pos', 'template_pos', 'expected_read_base', 'expected_template_base', 'variant_read_base', 'variant_template_base']
hlists = [["".join([nuc, str(cover)]) for nuc in NUCS] for cover in MIN_COVERS]
header = position_header + [item for sublist in hlists for item in sublist]
base_count_per_pos_file.write(tabprint(header) + "\n")

# One base per line (with >=X reads per tag, and =X reads per tag)
position_header = ['target_base', 'locus_index', 'reference_index', 'target_strand', 'read_index', 'read_pos', 'template_pos', 'expected_read_base', 'expected_template_base', 'variant_read_base', 'variant_template_base', 'read_ref_trinuc','template_ref_trinuc']
base_header = ['read_alt_trinuc','template_alt_trinuc','read_alt_base','proportion_of_RPT','total_count','alt_count','at_least_x_reads','exactly_x_reads','proportions_exactly_x_reads']
header = position_header + base_header
base_count_per_base_file.write(tabprint(header) + "\n")


########################################################################
# Load the input SNV table
log.info('Loading SNV table')
snv_info = load_snv_table(SNV_table)
for key,value in snv_info.items():
    log.debug(key)
    log.debug(tabprint(value))
log.info('Done loading SNV table')

log.info('Parsing specific SNV info fields')
SP1 = snv_info['specific-primer-1']
SP2 = snv_info['specific-primer-2']
region_array = snv_info['trimmed-target-seq'] # target region sequence
target_locs_array = [list(map(int, x.split(";") ) ) if len(x)>0 else [] for x in snv_info['target_locs']]  # Store AML loci in target_locs_array
add_locs_array = [list(map(int, x.split(";") ) ) if len(x)>0 else [] for x in snv_info['add-targets']]  # Store other loci in add_locs_array
region_strand_array = snv_info['strand'] # strand of target region sequence
refbase = [x[0] for x in snv_info['ref-alt_allele']]
altbase = [x[2] for x in snv_info['ref-alt_allele']]
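# 'ref-alt_allele' entries are formatted like 'A/G': character 0 is the
# reference base and character 2 is the alternate base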

########################################################################
# Process region specific information, including indels
# Store multiple reference sequences and positions resulting from indels in list
log.info('Processing region specific info')
# Only processing one specific region in this script: REGION
region_index = int(REGION)
## load region specific information
region_strand = region_strand_array[region_index]
var_base = altbase[region_index]

# Now things that vary with indels
list_region_seq = [region_array[region_index]]
list_region_rc  = [reverseComplement(region_array[region_index])]
list_region_len = [len(region_array[region_index])]

## get the AML target positions (may be more than one)
# add additional loci to target list
## and their index for both forward (read1) and reverse (read2)
targets    = target_locs_array[region_index] + add_locs_array[region_index]
targets_rc = [len(list_region_seq[0]) - pos - 1 for pos in targets]
## aml only targets
aml_only    = target_locs_array[region_index]
aml_only_rc = [len(list_region_seq[0]) - pos - 1 for pos in aml_only]
# so we can add another entry if there is an indel
list_targets = [targets]
list_targets_rc = [targets_rc]
list_aml_only = [aml_only]
list_aml_only_rc = [aml_only_rc]

if 'indel_start' in snv_info:
    if len(snv_info['indel_start'][region_index])>0:
        log.warning('Indels included for this region. Assuming only one indel is present')

        istart = int(snv_info['indel_start'][region_index])
        ilen = int(snv_info['indel_length'][region_index])
        iseq = snv_info['indel_seq'][region_index]

        new_region_seq = list_region_seq[0]
        if ilen<0: # deletion
            new_region_seq = new_region_seq[0:istart] + new_region_seq[(istart-ilen):]
        else: # insertion
            new_region_seq = new_region_seq[0:istart] + iseq + new_region_seq[istart:]

        list_region_seq.append(new_region_seq)
        list_region_rc.append(reverseComplement(new_region_seq))
        list_region_len.append(len(new_region_seq))

        # Adjust targets
        new_targets = [x if x<istart else (x+ilen) for x in targets]
        new_targets_rc = [len(new_region_seq) - pos - 1 for pos in new_targets]
        new_aml_only = [x if x<istart else (x+ilen) for x in aml_only]
        new_aml_only_rc = [len(new_region_seq) - pos - 1 for pos in new_aml_only]
        list_targets.append(new_targets)
        list_targets_rc.append(new_targets_rc)
        list_aml_only.append(new_aml_only)
        list_aml_only_rc.append(new_aml_only_rc)

########################################################################
# Output variables
list_WITHIN_TAG_ERRS = []

# Loop over each possible reference sequence and go through tags to align
log.info('Parsing each tag and associated reads')
for refiter in range(len(list_targets)):
    log.info('Target sequence iteration: %d' % refiter)
    log.info('Parsing %d tags in vt_counter' % len(vt_counter))
    log.info('Parsing %d tags in vt_seq_counter' % len(vt_seq_counter))

    # Extract info for this reference sequence
    log.info('Extracting info specific to this reference sequence')
    region_seq = list_region_seq[refiter]
    log.info('Reference sequence: %s' % region_seq)
    region_rc = list_region_rc[refiter]
    region_len = list_region_len[refiter]
    targets = list_targets[refiter]
    targets_rc = list_targets_rc[refiter]
    aml_only = list_aml_only[refiter]
    aml_only_rc = list_aml_only_rc[refiter]

    # Max sequence lengths to count bases on
    # R1: min(regionlen, trimlen - SP1len)
    # R2: min(regionlen, trimlen - SP2len - UP2len - Taglen)
    L1 = min(region_len, snakemake.params.trim_len - len(SP1[region_index]))
    L2 = min(region_len, snakemake.params.trim_len - len(SP2[region_index]) - len(snakemake.params.UP2) - len(snakemake.params.tag))
    log.debug('L1 is %d' % L1)
    log.debug('L2 is %d' % L2)
    ## and establish default ranges
    range_L1 = np.arange(L1)
    range_L2 = np.arange(L2)

    ## and figure out which the AML target positions are active for each read
    r1_targets = [target for target in targets if target < L1]
    r2_targets = [target for target in targets_rc if target < L2]

    r1_aml_only = [target for target in aml_only if target < L1]
    r2_aml_only = [target for target in aml_only_rc if target < L2]

    # Setup counters
    log.info('Setting up counters')
    # Basic alignment counters
    tag_counter = 0
    all_counter = 0
    hit_counter = 0
    aligned_tag_counter = 0
    # Within tag errors
    WITHIN_TAG_ERRS=np.zeros(shape=(s1,s2,s3), dtype=int)
    # store count data at each position subject to the coverage constraints in MIN_COVERS
    # read pos x coverage values x base
    base_counter1 = np.zeros(shape=(L1, len(MIN_COVERS), 4), dtype=int)
    base_counter2 = np.zeros(shape=(L2, len(MIN_COVERS), 4), dtype=int)

    base_counter1_exact = np.zeros(shape=(L1, len(EXACT_COVERS), 4), dtype=int) # changed to include all values 1 to 50
    base_counter2_exact = np.zeros(shape=(L2, len(EXACT_COVERS), 4), dtype=int)

    base_counter1_proportions = np.zeros(shape=(L1, len(MIN_COVERS), 4, len(POSSIBLE_PROPORTIONS)), dtype=int)
    base_counter2_proportions = np.zeros(shape=(L2, len(MIN_COVERS), 4, len(POSSIBLE_PROPORTIONS)), dtype=int)

    ########################################################################

    ## for each tag
    for tag, tag_count in vt_counter.most_common():
        tag_counter += 1
        tag_aligned = False
        RPT = 0
        if tag_counter % 10000 == 0:
            log.info(tabprint([tag_counter,
                            np.max(base_counter1),
                            np.max(base_counter2),
                            all_counter,
                            hit_counter,
                            float(hit_counter) / max(1,all_counter)]))
        ## get sequence counter associated with the tag
        seq_counter = vt_seq_counter[tag]
        ## record aggregate base call data (tag-data-1 and -2)
        td1    = np.zeros(shape=(5, L1), dtype=int)
        td2    = np.zeros(shape=(5, L2), dtype=int)
        ## iterate through read pairs (r1, r2)
        for (r1, r2), seq_count in seq_counter.most_common():
            RPT += seq_count
            ## trim read pairs as needed to fit target region
            r1 = r1[:L1]
            r2 = r2[:L2]
            aligned=False
            # If reads are shorter than expected, add N's...
            if len(r1)<L1:
                R1_long=r1+''.join(['N' for z in range(len(r1),L1)])
                log.debug("R1 length modified: %s" % R1_long)
                r1=R1_long
            if len(r2)<L2:
                R2_long=r2+''.join(['N' for z in range(len(r2),L2)])
                log.debug("R2 length modified: %s" % R2_long)
                r2=R2_long

            all_counter += seq_count # number of reads seen


            ## measure the match to reference sequences
            if MASK_LOWQUAL or FILTER_NS:
                score = test_pair_withNs(r1, r2, region_seq, region_rc, maxNratio=MAX_N_RATIO)
                ## omit the target positions from the score
                # N's are not penalized so don't fix them
                for pos in r1_targets:
                    if (pos<len(r1)):
                        score -= int((r1[pos] != region_seq[pos]) and (r1[pos] != 'N'))
                for pos in r2_targets:
                    if (pos<len(r2)):
                        score -= int((r2[pos] != region_rc [pos]) and (r2[pos] != 'N'))
            else:
                score = test_pair(r1, r2, region_seq, region_rc)
                ## omit the target positions from the score
                for pos in r1_targets:
                    if (pos<len(r1)):
                        score -= int(r1[pos] != region_seq[pos])
                for pos in r2_targets:
                    if (pos<len(r2)):
                        score -= int(r2[pos] != region_rc [pos])


            if score  <= MAX_HAMMING_TARGET:
                hit_counter += seq_count # reads match sequence of interest!
                r1_int = [BASE2INT[x] for x in r1]
                r2_int = [BASE2INT[x] for x in r2]
                ## and add to the base call data according to the count
                np.add.at(td1, [r1_int, range_L1], seq_count)
                np.add.at(td2, [r2_int, range_L2], seq_count)
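                # (np.add.at increments td1/td2[base, position] by this read
                # pair's count for every position at once)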
                aligned=True
                tag_aligned=True

            else: # If Hamming Distance wasn't a good match, try a 1-2 base shift in the sequence
                origscore = score
                r1_shifted = 'N' + r1[:-1] # shift one to the right
                r2_shifted = 'N' + r2[:-1] # shift one to the right
                score = test_pair_withNs(r1_shifted, r2_shifted, region_seq, region_rc, maxNratio=MAX_N_RATIO)
                for pos in r1_targets:
                    if (pos<len(r1)):
                        score -= int((r1_shifted[pos] != region_seq[pos]) and (r1_shifted[pos] != 'N'))
                for pos in r2_targets:
                    if (pos<len(r2)):
                        score -= int((r2_shifted[pos] != region_rc [pos]) and (r2_shifted[pos] != 'N'))
                if score  <= MAX_HAMMING_TARGET: # Now, after shifting, if the score is good enough...
                    log.debug("Left shift 1 worked")
                    hit_counter += seq_count # reads match sequence of interest!
                    r1_int = [BASE2INT[x] for x in r1_shifted]
                    r2_int = [BASE2INT[x] for x in r2_shifted]
                    ## and add to the base call data according to the count
                    np.add.at(td1, [r1_int, range_L1], seq_count) # problem here
                    np.add.at(td2, [r2_int, range_L2], seq_count)
                    aligned=True
                    tag_aligned=True
                else:
                    r1_shifted = r1[1:] + 'N' # shift one to the left
                    r2_shifted = r2[1:] + 'N' # shift one to the left
                    score = test_pair_withNs(r1_shifted, r2_shifted, region_seq, region_rc, maxNratio=MAX_N_RATIO)
                    for pos in r1_targets:
                        if (pos<len(r1)):
                            score -= int((r1_shifted[pos] != region_seq[pos]) and (r1_shifted[pos] != 'N'))
                    for pos in r2_targets:
                        if (pos<len(r2)):
                            score -= int((r2_shifted[pos] != region_rc [pos]) and (r2_shifted[pos] != 'N'))
                    if score  <= MAX_HAMMING_TARGET: # Now, after shifting, if the score is good enough...
                        log.debug("Right shift 1 worked")
                        hit_counter += seq_count # reads match sequence of interest!
                        r1_int = [BASE2INT[x] for x in r1_shifted]
                        r2_int = [BASE2INT[x] for x in r2_shifted]
                        ## and add to the base call data according to the count
                        np.add.at(td1, [r1_int, range_L1], seq_count)
                        np.add.at(td2, [r2_int, range_L2], seq_count)
                        aligned=True
                        tag_aligned=True

            # Write out unaligned reads
            if not aligned:
                unaligned_reads_file.write(str(REGION)+"\t"+r1+"\t"+r2+"\t"+str(origscore)+"\t"+str(score)+"\t"+tag+"\n")

        if tag_aligned:
            aligned_tag_counter +=1
        ## now compute some statistics over the base call information
        total1   = np.sum(td1[:4], axis = 0)    # total coverage (r1) at each position
        max_ind1 = np.argmax(td1, axis=0)       # index of maximal base (r1)
        max_val1 = np.max(td1, axis=0)          # count value at maximal base (r1)
        total2   = np.sum(td2[:4], axis = 0)    # as above for r2.
        max_ind2 = np.argmax(td2, axis=0)
        max_val2 = np.max(td2, axis=0)

        log.debug(["tagalign",tag,str(tag_aligned),str(RPT)])

        ## for each coverage cut off
        for cind, EXACT_COVER in enumerate(EXACT_COVERS):

            # Also want to look at counts for exactly x reads per tag        
            if EXACT_COVER==EXACT_COVERS[-1]: # catch all for 50 or more
                strong1  = (total1 >= EXACT_COVER) * (max_val1 >= CUTOFF*total1) * (max_ind1 != 4)
                strong2  = (total2 >= EXACT_COVER) * (max_val2 >= CUTOFF*total2) * (max_ind2 != 4)
            else:
                strong1  = (total1 == EXACT_COVER) * (max_val1 >= CUTOFF*total1) * (max_ind1 != 4)
                strong2  = (total2 == EXACT_COVER) * (max_val2 >= CUTOFF*total2) * (max_ind2 != 4)

            base_counter1_exact[strong1, cind, max_ind1[strong1]] += 1
            base_counter2_exact[strong2, cind, max_ind2[strong2]] += 1


        for cind, MIN_COVER in enumerate(MIN_COVERS):
            ## test against coverage and base_ratio cutoffs
            # Check that max index is not an N
            strong1  = (total1 >= MIN_COVER) * (max_val1 >= CUTOFF*total1) * (max_ind1 != 4)
            strong2  = (total2 >= MIN_COVER) * (max_val2 >= CUTOFF*total2) * (max_ind2 != 4)
            ## increment those positions that pass the tests
            base_counter1[strong1, cind, max_ind1[strong1]] += 1
            base_counter2[strong2, cind, max_ind2[strong2]] += 1

            # Also want to keep track of bases that don't pass the 0.8 cutoff
            strong1  = (total1 == MIN_COVER) * (max_ind1 != 4)
            strong2  = (total2 == MIN_COVER) * (max_ind2 != 4)

            for B in range(0,4):
                p_obs1 = np.true_divide(td1[B,:],total1)
                p_obs2 = np.true_divide(td2[B,:],total2)

                for pind, PROPORTION in enumerate(POSSIBLE_PROPORTIONS):
                    if PROPORTION in PROPORTIONS_BY_COV[MIN_COVER]:

                            strong1p = (strong1) * (p_obs1==PROPORTION)
                            strong2p = (strong2) * (p_obs2==PROPORTION)

                            base_counter1_proportions[strong1p, cind, B, pind] += 1
                            base_counter2_proportions[strong2p, cind, B, pind] += 1


        #######################################################################
        # WITHIN TAG ERROR
        # For R1, use td1, region_seq
        # For R2, use td2, region_rc
        # Positions to skip: r1_targets, r2_targets

        if tag_count > READS_PER_TAG_CUTOFF:
            # R1
            fractions = td1[:4,]/total1 # Get A/C/G/T fractions from counts
            errind = np.ceil(fractions*10) # gets index for the 10 err bins
            for pos in range(1,(L1-1)): # skip first and last to get trinucs
                if pos not in r1_targets:
                    ref_base= region_seq[pos]
                    ref_trinuc= region_seq[(pos-1):(pos+2)]
                    alt_bases = [x for x in NUCS if x!=ref_base]

                    for alt in alt_bases:
                        alt_trinuc = ref_trinuc[0] + alt + ref_trinuc[2]
                        try:
                            e = int(errind[BASE2INT[alt],pos])
                        except ValueError: # errind is NaN where total coverage is 0
                            e = None
                        if e is not None:
                            WITHIN_TAG_ERRS[TRINUC2INT[(ref_trinuc,alt_trinuc)],e,0] += 1


            # R2
            fractions = td2[:4,]/total2 # Get A/C/G/T fractions from counts
            errind = np.ceil(fractions*10) # gets index for the 10 err bins
            for pos in range(1,(L2-1)): # skip first and last to get trinucs
                if pos not in r2_targets:
                    ref_base= region_rc[pos]
                    ref_trinuc= region_rc[(pos-1):(pos+2)]
                    alt_bases = [x for x in NUCS if x!=ref_base]

                    for alt in alt_bases:
                        alt_trinuc = ref_trinuc[0] + alt + ref_trinuc[2]
                        try:
                            e = int(errind[BASE2INT[alt],pos])
                        except ValueError: # errind is NaN where total coverage is 0
                            e = None
                        if e is not None:
                            WITHIN_TAG_ERRS[TRINUC2INT[(ref_trinuc,alt_trinuc)],e,1] += 1


    #######################################################################
    log.info('Done with all tags for this reference sequence')
    list_WITHIN_TAG_ERRS.append(WITHIN_TAG_ERRS)

    # Write counters results to file
    counters=[region_index, refiter, tag_counter, aligned_tag_counter, aligned_tag_counter / float(max(1,tag_counter)), all_counter, hit_counter, hit_counter / float(max(1,all_counter))]
    alignment_rate_file.write(tabprint(counters) + "\n")

    ########################################################################
    # ORIGINAL REPORT FORMAT

    ## write the output for each position over all coverage levels
    for pos, data in enumerate(base_counter1):
        # indicator for what type of variant position it is
        # 0 - ref base
        # 1 - non-aml variant base
        # 2 - AML base
        aml_loc = int(pos in r1_targets) + int(pos in r1_aml_only)
        if aml_loc==2:
            variant_read_base = var_base
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = pos
        outline = tabprint([aml_loc,
                            region_index,
                            refiter,
                            region_strand,
                            1,
                            pos,
                            pos_in_region,
                            region_seq[pos],
                            region_seq[pos_in_region],
                            variant_read_base,
                            variant_template_base] + list(data.ravel()))
        base_count_per_pos_file.write(outline)
        base_count_per_pos_file.write("\n")

    for pos, data in enumerate(base_counter2):
        aml_loc = int(pos in r2_targets) + int(pos in r2_aml_only)
        if aml_loc==2:
            variant_read_base = reverseComplement(var_base)
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = region_len - pos - 1
        outline = tabprint([aml_loc,
                            region_index,
                            refiter,
                            region_strand,
                            2,
                            pos,
                            pos_in_region,
                            region_rc[pos],
                            region_seq[pos_in_region],
                            variant_read_base,
                            variant_template_base] + list(data.ravel()))
        base_count_per_pos_file.write(outline)
        base_count_per_pos_file.write("\n")

    ########################################################################
    # NEW REPORT FORMAT
    ## write the output for each position over all coverage levels
    ## Separate one line per base and per coverage level
    ## Do for exactly and at least X reads per base

    ########################################################################
    # counts by proportion of tag
    print("BC1_proportions")
    print(base_counter1_proportions.shape)
    for pos, data in enumerate(base_counter1_proportions):
        aml_loc = int(pos in r1_targets) + int(pos in r1_aml_only)
        if aml_loc==2:
            variant_read_base = var_base
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = pos

        if pos==0:
            read_trinuc='X'+region_seq[0:(pos+2)]
            template_trinuc='X'+region_seq[0:(pos_in_region+2)]
        elif pos==(len(region_seq)-1):
            read_trinuc=region_seq[(pos-1):(pos+2)]+'X'
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]+'X'
        else:
            read_trinuc=region_seq[(pos-1):(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]

        print("Position: %d: read_trinuc: %s" % (pos,read_trinuc))
        print("Position: %d: template_trinuc: %s" % (pos,template_trinuc))

        for cindex,cov in enumerate(MIN_COVERS):
            for nindex,nuc in enumerate(NUCS):
                for pind, PROPORTION in enumerate(POSSIBLE_PROPORTIONS):
                    if PROPORTION in PROPORTIONS_BY_COV[cov]:

                        count = data[(cindex,nindex,pind)]
                        total = np.sum(data[cindex,0]) # sum up proportions for A (all tags are either 100% A, 50% A or 0% A for 2 RPT)

                        alt_read_trinuc = read_trinuc[0] + nuc + read_trinuc[2]
                        alt_template_trinuc = template_trinuc[0] + nuc + template_trinuc[2]

                        baseinfo = [alt_read_trinuc,alt_template_trinuc,nuc,PROPORTION,total,count,'.','.',cov] 

                        outline = tabprint([aml_loc,
                                            region_index,
                                            refiter,
                                            region_strand,
                                            1,
                                            pos,
                                            pos_in_region,
                                            region_seq[pos], # read base
                                            region_seq[pos_in_region], # template base
                                            variant_read_base,
                                            variant_template_base,
                                            read_trinuc,
                                            template_trinuc] +
                                            baseinfo)
                        base_count_per_base_file.write(outline)
                        base_count_per_base_file.write("\n")
    ##########################################
    print("BC2_proportions")
    print(base_counter2_proportions.shape)
    for pos, data in enumerate(base_counter2_proportions):
        aml_loc = int(pos in r2_targets) + int(pos in r2_aml_only)
        if aml_loc==2:
            variant_read_base = reverseComplement(var_base)
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = region_len - pos - 1

        if pos==0:
            read_trinuc='X'+region_rc[0:(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]+'X'
        elif pos==(len(region_seq)-1):
            read_trinuc=region_rc[(pos-1):(pos+2)]+'X'
            template_trinuc='X'+region_seq[0:(pos_in_region+2)]
        else:
            read_trinuc=region_rc[(pos-1):(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]
            print("Position: %d: read_trinuc: %s" % (pos,read_trinuc))
            print("Position: %d: template_trinuc: %s" % (pos,template_trinuc))

        for cindex,cov in enumerate(MIN_COVERS):
            for nindex,nuc in enumerate(NUCS):
                for pind, PROPORTION in enumerate(POSSIBLE_PROPORTIONS):
                    if PROPORTION in PROPORTIONS_BY_COV[cov]:

                        count = data[(cindex,nindex,pind)]
                        total = np.sum(data[cindex,0])

                        alt_read_trinuc = read_trinuc[0] + nuc + read_trinuc[2]
                        alt_template_trinuc = template_trinuc[0] + reverseComplement(nuc) + template_trinuc[2]

                        baseinfo = [alt_read_trinuc,alt_template_trinuc,nuc,PROPORTION,total,count,'.','.',cov] 

                        outline = tabprint([aml_loc,
                                            region_index,
                                            refiter,
                                            region_strand,
                                            2,
                                            pos,
                                            pos_in_region,
                                            region_rc[pos], # read base
                                            region_seq[pos_in_region], # template base
                                            variant_read_base,
                                            variant_template_base,
                                            read_trinuc,
                                            template_trinuc] +
                                            baseinfo)
                        base_count_per_base_file.write(outline)
                        base_count_per_base_file.write("\n")



    ########################################################################
    ## EXACTLY X reads per base
    print(region_seq)
    print(len(region_seq))

    print("BC1exact")
    print(base_counter1_exact.shape)
    for pos, data in enumerate(base_counter1_exact):
        aml_loc = int(pos in r1_targets) + int(pos in r1_aml_only)
        if aml_loc==2:
            variant_read_base = var_base
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = pos

        if pos==0:
            read_trinuc='X'+region_seq[0:(pos+2)]
            template_trinuc='X'+region_seq[0:(pos_in_region+2)]
        elif pos==(len(region_seq)-1):
            read_trinuc=region_seq[(pos-1):(pos+2)]+'X'
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]+'X'
        else:
            read_trinuc=region_seq[(pos-1):(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]

        print("Position: %d: read_trinuc: %s" % (pos,read_trinuc))
        print("Position: %d: template_trinuc: %s" % (pos,template_trinuc))

        for cindex,cov in enumerate(EXACT_COVERS):
            for nindex,nuc in enumerate(NUCS):
                count = data[(cindex,nindex)]
                total = np.sum(data[cindex])

                alt_read_trinuc = read_trinuc[0] + nuc + read_trinuc[2]
                alt_template_trinuc = template_trinuc[0] + nuc + template_trinuc[2]

                baseinfo = [alt_read_trinuc,alt_template_trinuc,nuc,'.',total,count,'.',cov,'.'] 

                outline = tabprint([aml_loc,
                                    region_index,
                                    refiter,
                                    region_strand,
                                    1,
                                    pos,
                                    pos_in_region,
                                    region_seq[pos], # read base
                                    region_seq[pos_in_region], # template base
                                    variant_read_base,
                                    variant_template_base,
                                    read_trinuc,
                                    template_trinuc] +
                                    baseinfo)
                base_count_per_base_file.write(outline)
                base_count_per_base_file.write("\n")

    print("BC2exact")
    print(base_counter2_exact.shape)
    for pos, data in enumerate(base_counter2_exact):
        aml_loc = int(pos in r2_targets) + int(pos in r2_aml_only)
        if aml_loc==2:
            variant_read_base = reverseComplement(var_base)
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = region_len - pos - 1

        if pos==0:
            read_trinuc='X'+region_rc[0:(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]+'X'
        elif pos==(len(region_seq)-1):
            read_trinuc=region_rc[(pos-1):(pos+2)]+'X'
            template_trinuc='X'+region_seq[0:(pos_in_region+2)]
        else:
            read_trinuc=region_rc[(pos-1):(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]
            print("Position: %d: read_trinuc: %s" % (pos,read_trinuc))
            print("Position: %d: template_trinuc: %s" % (pos,template_trinuc))

        for cindex,cov in enumerate(EXACT_COVERS):
            for nindex,nuc in enumerate(NUCS):
                count = data[(cindex,nindex)]
                total = np.sum(data[cindex])

                alt_read_trinuc = read_trinuc[0] + nuc + read_trinuc[2]
                alt_template_trinuc = template_trinuc[0] + reverseComplement(nuc) + template_trinuc[2]

                baseinfo = [alt_read_trinuc,alt_template_trinuc,nuc,'.',total,count,'.',cov,'.'] 

                outline = tabprint([aml_loc,
                                    region_index,
                                    refiter,
                                    region_strand,
                                    2,
                                    pos,
                                    pos_in_region,
                                    region_rc[pos], # read base
                                    region_seq[pos_in_region], # template base
                                    variant_read_base,
                                    variant_template_base,
                                    read_trinuc,
                                    template_trinuc] +
                                    baseinfo)
                base_count_per_base_file.write(outline)
                base_count_per_base_file.write("\n")

    ## AT LEAST X reads per base
    print("BC1over")
    print(base_counter1.shape)
    for pos, data in enumerate(base_counter1):
        aml_loc = int(pos in r1_targets) + int(pos in r1_aml_only)
        if aml_loc==2:
            variant_read_base = var_base
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = pos

        if pos==0:
            read_trinuc='X'+region_seq[0:(pos+2)]
            template_trinuc='X'+region_seq[0:(pos_in_region+2)]
        elif pos==(len(region_seq)-1):
            read_trinuc=region_seq[(pos-1):(pos+2)]+'X'
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]+'X'
        else:
            read_trinuc=region_seq[(pos-1):(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]
        print("Position: %d: read_trinuc: %s" % (pos,read_trinuc))
        print("Position: %d: template_trinuc: %s" % (pos,template_trinuc))

        for cindex,cov in enumerate(MIN_COVERS):
            for nindex,nuc in enumerate(NUCS):
                count = data[(cindex,nindex)]
                total = np.sum(data[cindex])

                alt_read_trinuc = read_trinuc[0] + nuc + read_trinuc[2]
                alt_template_trinuc = template_trinuc[0] + nuc + template_trinuc[2]

                baseinfo = [alt_read_trinuc,alt_template_trinuc,nuc,'.',total,count,cov,'.','.'] 

                outline = tabprint([aml_loc,
                                    region_index,
                                    refiter,
                                    region_strand,
                                    1,
                                    pos,
                                    pos_in_region,
                                    region_seq[pos], # read base
                                    region_seq[pos_in_region], # template base
                                    variant_read_base,
                                    variant_template_base,
                                    read_trinuc,
                                    template_trinuc] +
                                    baseinfo)
                base_count_per_base_file.write(outline)
                base_count_per_base_file.write("\n")

    print("BC2over")
    print(base_counter2.shape)
    for pos, data in enumerate(base_counter2):
        aml_loc = int(pos in r2_targets) + int(pos in r2_aml_only)
        if aml_loc==2:
            variant_read_base = reverseComplement(var_base)
            variant_template_base = var_base
        else:
            variant_read_base = '-'
            variant_template_base = '-'
        pos_in_region = region_len - pos - 1

        if pos==0:
            read_trinuc='X'+region_rc[0:(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]+'X'
        elif pos==(len(region_seq)-1):
            read_trinuc=region_rc[(pos-1):(pos+2)]+'X'
            template_trinuc='X'+region_seq[0:(pos_in_region+2)]
        else:
            read_trinuc=region_rc[(pos-1):(pos+2)]
            template_trinuc=region_seq[(pos_in_region-1):(pos_in_region+2)]
        print("Position: %d: read_trinuc: %s" % (pos,read_trinuc))
        print("Position: %d: template_trinuc: %s" % (pos,template_trinuc))

        for cindex,cov in enumerate(MIN_COVERS):
            for nindex,nuc in enumerate(NUCS):
                count = data[(cindex,nindex)]
                total = np.sum(data[cindex])

                alt_read_trinuc = read_trinuc[0] + nuc + read_trinuc[2]
                alt_template_trinuc = template_trinuc[0] + reverseComplement(nuc) + template_trinuc[2]

                baseinfo = [alt_read_trinuc,alt_template_trinuc,nuc,'.',total,count,cov,'.','.'] 

                outline = tabprint([aml_loc,
                                    region_index,
                                    refiter,
                                    region_strand,
                                    2,
                                    pos,
                                    pos_in_region,
                                    region_rc[pos], # read base
                                    region_seq[pos_in_region], # template base
                                    variant_read_base,
                                    variant_template_base,
                                    read_trinuc,
                                    template_trinuc] +
                                    baseinfo)
                base_count_per_base_file.write(outline)
                base_count_per_base_file.write("\n")



########################################################################
# Write the within tag errors to pickle file and to table
# Just save this for the first reference sequence.... for now
log.info('Writing within tag error to a file')
WITHIN_TAG_ERRS = list_WITHIN_TAG_ERRS[0]
pickle.dump(WITHIN_TAG_ERRS, open(snakemake.output.withintagerrors, 'wb'), pickle.HIGHEST_PROTOCOL)
for i,pt in enumerate(PAIRED_TRINUCS):
    withintagerrors_file.write(tabprint(['R1']+list(pt)+list(WITHIN_TAG_ERRS[i,:,0]))+'\n')
    withintagerrors_file.write(tabprint(['R2']+list(pt)+list(WITHIN_TAG_ERRS[i,:,1]))+'\n')

#######################################################################
# Close output files
base_count_per_pos_file.close()
base_count_per_base_file.close()
alignment_rate_file.close()
withintagerrors_file.close()
unaligned_reads_file.close()

########################################################################
# End timer
t1 = time.time()
td = (t1 - t0) / 60
log.info("Done in %0.2f minutes" % td)

Snippet 2 (source lines 6-386 of the check_loci script):

import os,sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import editdistance
import pysam
from masq_helper_functions import tabprint
from masq_helper_functions import reverseComplement
from masq_helper_functions import convert_cigar_string
from masq_helper_functions import setup_logger, load_snv_table, write_snv_table

########################################################################
# Start timer
t0 = time.time()
# Setup log file
log = setup_logger(snakemake.log,'check_loci')
log.info('Starting process')

########################################################################
# INPUT FILES AND PARAMETERS
log.info('Getting input files and parameters from snakemake object')
# WGS path to bam
# Now allows multiple BAMs
if isinstance(snakemake.input.bam,list) and isinstance(snakemake.params.wgs_name,list):
    WGS_BAM = snakemake.input.bam
    WGS_NAME = snakemake.params.wgs_name
elif isinstance(snakemake.input.bam,str) and isinstance(snakemake.params.wgs_name,str):
    WGS_BAM = [snakemake.input.bam]
    WGS_NAME = [snakemake.params.wgs_name]
else:
    log.error("snakemake.input.bam and snakemake.params.wgs_name must both be lists or both be strings")

# Input SNV table
SNV_table = open(snakemake.input.SNV_table,'r')

########################################################################
# OUTPUT FILES
log.info('Getting output files from snakemake object')
# Output plots
list_of_plot_files = snakemake.output.plots
# Updated SNP file
updated_SNV_table = snakemake.output.new_SNV_table

# Make plot folder if it doesn't exist
plotfolder = os.path.dirname(list_of_plot_files[0])
os.makedirs(plotfolder, exist_ok=True)

########################################################################
# Plotting colors
BASE_COLORS = ["#00ABBA",
               "#9ACA3C",
               "#F26421",
               "#672D8F"]

########################################################################
## load the genome sequence from a pickle file
log.info('Loading reference genome pickle')
seq_pickle = snakemake.params.wgs_ref
seq_dic = pickle.load(open(seq_pickle, 'rb'))
log.info('Done loading reference genome pickle')

########################################################################
# Load the input SNV table
log.info('Loading SNV table')
snv_info = load_snv_table(SNV_table)
for key,value in snv_info.items():
    log.debug(key)
    log.debug(tabprint(value))
log.info('Done loading SNV table')

########################################################################

# Process input file to extend it with strand info and get coordinates of seq
log.info('Parsing specific SNV info fields')

target_info = []
target_locs_array = [list(map(int, x.split(";") ) ) if len(x)>0 else [] for x in snv_info['target_locs']]

for i,loc in enumerate(snv_info['loc']):
    chrom=snv_info['chr'][i]
    pos = int(snv_info['posi'][i])
    target_locs = target_locs_array[i]
    aml_loc = target_locs[0]
    log.debug('Position: %d' % pos)
    log.debug('AML Loc: %d' % aml_loc)
    int_seq = snv_info['trimmed-target-seq'][i]
    length = len(int_seq)

    if chrom=="0":
        strand="+"
        start=1
        end=100
        targets=target_locs
    elif ('strand' in snv_info.keys()) and ('fragment-start' in snv_info.keys()) and ('fragment-end' in snv_info.keys()) and ('add-targets' in snv_info.keys()):
        strand=snv_info['strand'][i]
        start=int(snv_info['fragment-start'][i])
        end=int(snv_info['fragment-end'][i])
        targets=target_locs # table already provides coordinates; target positions are relative to the stored sequence
    else:
        log.info('Inferring strand, start, and end from match to reference genome')
        # identifying start and end genomic positions of the targeted region
        # and getting target sequence from ref genome
        top_start = pos - 1 - aml_loc
        top_end = top_start + length
        top_match = seq_dic[chrom][top_start:top_end]
        # if the target sequence was on the bottom strand...
        bottom_start = pos - (length - aml_loc)
        bottom_end = bottom_start + length
        bottom_match = reverseComplement(seq_dic[chrom][bottom_start:bottom_end])

        log.debug('Target seq : %s' % int_seq)
        log.debug('Top match  : %s' % top_match)
        log.debug('Btm match  : %s' % bottom_match)

        # check which strand the sequence came from and set coordinates
        if top_match == int_seq:
            log.debug('Locus %s: positive strand' % (loc))
            strand = "+"
            start = top_start
            end = top_end
            targets = target_locs
        else:
            log.debug('Locus %s: negative strand' % (loc))
            strand = "-"
            start = bottom_start
            end = bottom_end
            targets = [length - x - 1 for x in target_locs]

    target_info.append([chrom, strand, start, end, targets])

########################################################################
# Now go through WGS data and add non-ref bases to the input file...
log.info('Finding additional variant positions from WGS BAM')
BASES = ["A", "C", "G", "T"]
#sample_names = [WGS_NAME] # EDIT TO ALLOW MULTIPLE BAMS...
BASE2INT = dict([x[::-1] for x in enumerate(BASES)])
MIN_QUAL = 20
MIN_MAP = 40
buff = 30
region_buffer = 30
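# region_buffer widens the reference window on each side of the target;
# buff additionally pads the BAM fetch so reads overlapping the window
# edges are still retrieved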


all_targets = []
loc_ind = 0  # counter for which loci we're on
#
# load each region from the target info read in and processed
for chrom, strand, true_start, true_end, true_targets in target_info:

    if chrom=="0": # not human / mouse sequence (GFP seq for example)
        all_targets.append([])
        image_filename = list_of_plot_files[loc_ind]
        fig = plt.figure(figsize=(20, 18))
        plt.savefig(image_filename, dpi=200, facecolor='w', edgecolor='w',
                    transparent=False)
        plt.close()
        loc_ind += 1
    else:

        log.debug('Locus %d' % (loc_ind+1))
        image_filename = list_of_plot_files[loc_ind]
        loc_ind += 1

        # add buffer to target sequence on either side and update values
        start = true_start - region_buffer
        end = true_end + region_buffer
        targets = [x + region_buffer for x in true_targets]
        # get local sequence
        local_seq = seq_dic[chrom][start:end]
        log.debug('Chrom: %s' % chrom)
        log.debug('start: %d' % start)
        log.debug('end: %d' % end)
        log.debug('Local_seq: %s' % local_seq)
        L = len(local_seq)
        # also in integer form
        local_int = np.array([BASE2INT[x] for x in local_seq])
        # initialize for keeping track of variants at this position
        has_something = np.zeros(L, dtype=bool)

        # for each sample
        fig = plt.figure(figsize=(20, 18))
        locus_string = "%s:%d-%d %s" % (chrom, start, end, strand)
        fig.suptitle(locus_string, fontsize=20)
        ind = 1
        for sample_name,sample_bam_fn in zip(WGS_NAME,WGS_BAM):
            log.info('Plotting locus %d' % loc_ind)
            # Set up the plot
            ax1 = fig.add_subplot(len(WGS_NAME), 1, ind)
            ax2 = ax1.twinx()
            ax1.tick_params(axis='both', labelsize=15)
            ax2.tick_params(axis='both', labelsize=15)
            ind += 1

            # open the bam file
            sample_bam = pysam.AlignmentFile(sample_bam_fn, "rb")

            # for each position of interest
            # get an iterator that walks over the reads that include our event
            read_iterator = sample_bam.fetch(
                chrom, start - buff, end + buff)  # ref has chr1,chr2,etc.

            # store as pairs
            read_names = []
            read_pair_dict = dict()
            for read in read_iterator:
                try:
                    read_pair_dict[read.qname].append(read)  # mate already seen
                except KeyError:
                    read_pair_dict[read.qname] = [read]  # first read of the pair
                    read_names.append(read.qname)  # keys to dictionary

            N = len(read_names)  # number of read pairs in this region
            # store an integer at every position for each read(pair) in interval
            # matrix (number of reads (N) by number of positions (end-start))
            read_stack = np.zeros(shape=(N, end - start), dtype=int)
            # get all read pairs
            for rcounter, read_name in enumerate(read_names):
                reads = read_pair_dict[read_name]
                for read in reads:
                    read_sequence = read.query_sequence
                    read_qualities = read.query_qualities
                    # get_aligned_pairs:
                    # returns paired list of ref seq position and read seq position
                    for rindex, pindex in read.get_aligned_pairs():
                        # rindex and pindex return ALL positions, including soft-clipped
                        # rindex: position in read (1-150 for example)
                        # pindex: position in reference genome
                        # if there's a base in the read
                        # and we are in range
                        if rindex is not None and pindex is not None:
                            # both read and reference positions exist
                            # (excludes soft-clips and indels)
                            if pindex >= start and pindex < end:
                                # compute the base score
                                base_score = read_qualities[rindex]
                                # if it's good enough
                                if base_score >= MIN_QUAL:
                                    # check if there is already a value in place
                                    current = read_stack[rcounter, pindex - start]
                                    base = read_sequence[rindex]
                                    if base == "N":
                                        baseint = 0
                                    else:
                                        # A-1, C-2, G-3, T-4
                                        baseint = BASE2INT[base] + 1
                                    # if there is no value, store the value
                                    if current == 0:
                                        read_stack[rcounter, pindex - start] = baseint
                                    else:
                                        # if there is a mismatch between the two reads,
                                        # set value back to 0
                                        # this value is just for the 2 paired reads
                                        if current != baseint:
                                            read_stack[rcounter, pindex - start] = 0
            summary = []
            # iterating over numpy array - by rows
            # transpose first to iterate over positions
            for x in read_stack.transpose():
                # gets counts of N,A,C,G,T per position as array
                # append to summary list
                summary.append(np.bincount(x, minlength=5))
            # convert summary to array, fills by row
            # drop the N count, and transpose again
            # .T transposes, only if number of dimensions > 1
            summary = np.array(summary)[:, 1:].T
            # now we have base (4) by position array as summary
            # base_cover: coverage at each position (sum of A/C/G/T counts)
            base_cover = np.sum(summary, axis=0)
            # base_ratio: A/C/G/T counts over total coverage
            # aka frequency of each base
            base_ratio = summary.astype(float) / np.maximum(1, base_cover)
            # update has something
            # positions with coverage <= 3 are never flagged as variant
            has_something += ( (base_ratio[local_int, np.arange(L)] < 0.9) & (base_cover>3) ) # reference ratio is less than 0.9

            ########################################################################
            # Plot variants in each region

            # plot the coverage first
            ax1.plot(base_cover, color='k', label='coverage', alpha=0.5)
            # draw lines for boundaries of event
            # and targets
            if strand == "+":
                ax2.axvline(region_buffer-0.5, color='g', lw=2)
                ax2.axvline(L-region_buffer-0.5, color='g', linestyle='--', lw=2)
            else:
                ax2.axvline(region_buffer-0.5, color='g', linestyle='--', lw=2)
                ax2.axvline(L-region_buffer-0.5, color='g', lw=2)
            for pos in targets:
                ax1.axvline(pos, color='r', lw=2)

            # Plot colored circle for each base at position vs. frequency
            for BASE_COLOR, BASE, yvals in zip(BASE_COLORS, BASES, base_ratio):
                ax2.plot(yvals, 'o', markersize=13, label=BASE, color=BASE_COLOR)
                base_filter = local_int == BASE2INT[BASE]  # same as ref base
                x = np.arange(L)
                # label ref bases with a black circle
                ax2.plot(x[base_filter], yvals[base_filter], 'o',
                         markersize=13, mfc="None", mec='black', mew=2)
            ax1.set_ylabel(sample_name, fontsize=25)

            # for non-ref sites...
            # label number of reads for non-ref base in red
            for hs_ind in np.where(has_something)[0]:
                base_counts, total = summary[:, hs_ind], base_cover[hs_ind]
                ref_base = local_int[hs_ind]
                for bind in range(4):
                    # loop through A/C/G/T, find the non-zero counts
                    if bind != ref_base and base_counts[bind] > 0:
                        # text_string = r"$\frac{%d}{%d}$" % (
                            # base_counts[bind],total)
                        # text_string = r"%d/%d" % (base_counts[bind], total)
                        text_string = r"%d" % (base_counts[bind])
                        # add the count to the plot
                        ax2.text(hs_ind + 1, base_ratio[bind, hs_ind], text_string,
                                 fontsize=20, color='red', weight='semibold')
            ax2.set_xlim(-10, 10+L)
            ax2.set_ylim(0, 1)
            ax2.legend(loc="upper center", numpoints=1, ncol=4,
                       fontsize=18, columnspacing=0.8, handletextpad=-0.2)
            ax1.legend(loc="lower center", numpoints=2,
                       fontsize=18, columnspacing=0.5)

        plt.tight_layout(rect=(0, 0, 1, 0.98))
        plt.savefig(image_filename, dpi=200, facecolor='w', edgecolor='w',
                    papertype=None, format=None,
                    transparent=False)
        plt.close()

    ########################################################################
        # Add non-ref het/homo sites to "target" list
        rlen = true_end - true_start
        new_targets = np.where(has_something)[0] - region_buffer
        if strand == "-":
            new_targets = true_end - true_start - new_targets - 1

        # new_targets>=0 (in original target region)
        new_targets = new_targets[(new_targets >= 0) * (new_targets < rlen)]
        all_targets.append(new_targets)

    ########################################################################

# Add new het/homo non-ref sites to original input file
log.info('Writing updated SNV info table')
outfile = open(updated_SNV_table, 'w')

snv_info['add-targets']=[]
snv_info['strand']=[]
snv_info['fragment-start']=[]
snv_info['fragment-end']=[]
print("Start loop")
for i, (new_targets, more_info) in enumerate(zip(all_targets, target_info)):
    log.debug(i)
    prev_targets = list(map(int, snv_info['target_locs'][i].split(";")))
    add_targets = [x for x in new_targets if x not in prev_targets]
    snv_info['add-targets'].append(";".join(list(map(str, add_targets))))
    snv_info['strand'].append(more_info[1])
    snv_info['fragment-start'].append(more_info[2])
    snv_info['fragment-end'].append(more_info[3])
print("End loop")
print(snv_info)

write_snv_table(snv_info,outfile)

########################################################################
# Close files
outfile.close()

########################################################################
# End timer
t1 = time.time()
td = (t1 - t0) / 60
log.info("Done in %0.2f minutes" % td)
import os
import sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from masq_helper_functions import tabprint
from masq_helper_functions import cluster_rollup2

########################################################################

# Start timer
t0_all = time.time()

########################################################################

# Redirect prints to log file
old_stdout = sys.stdout
log_file = open(str(snakemake.log),"w")
sys.stdout = log_file
sys.stderr = log_file

########################################################################

# Filenames and parameters
vt_counter_filename = snakemake.input.vt_counter
vt_seq_counter_filename = snakemake.input.vt_seq_counter
flip_counter_filename = snakemake.input.flip_counter

DNA_INPUT_NANOGRAMS = snakemake.params.dna_input_ng

# WHICH REGION ARE WE PROCESSING
REGION = snakemake.params.region
# Sample name
sample = snakemake.params.sample

# Output report file
outfilename = snakemake.output.tagcounts
outfile = open(outfilename,"w")

# New counters
new_vt_counter_filename = snakemake.output.vt_counter
new_vt_seq_counter_filename = snakemake.output.vt_seq_counter
new_flip_counter_filename = snakemake.output.flip_counter

########################################################################

# Load sequence data
t0 = time.time()
print("loading sequence data...")
vt_counter = pickle.load(open(vt_counter_filename, 'rb')) # only one region
vt_seq_counter = pickle.load(open(vt_seq_counter_filename, 'rb'))
flip_counter = pickle.load(open(flip_counter_filename, 'rb'))
print("data loaded in %0.2f seconds" % (time.time() - t0))

########################################################################

# Do cluster rollup 
new_vt_counter, new_vt_seq_counter, unique_list, unique_count, match_dict, new_flip_counter = cluster_rollup2(vt_counter,vt_seq_counter, flip_counter, show_progress=False)

########################################################################

# Save new counters
pickle.dump(new_vt_counter, open(new_vt_counter_filename, 'wb'), pickle.HIGHEST_PROTOCOL)
pickle.dump(new_vt_seq_counter, open(new_vt_seq_counter_filename, 'wb'), pickle.HIGHEST_PROTOCOL)
pickle.dump(new_flip_counter, open(new_flip_counter_filename, 'wb'), pickle.HIGHEST_PROTOCOL)

########################################################################

# Calculate tag counts for report
num_obs_tags = sum(vt_counter.values())
num_uniq_tags = len(vt_counter)
num_collapsed_tags = len(new_vt_counter)
########################################################################

# Report on original tags, unique tags, and rolled up tags
outfile.write(tabprint(["Region","Observed Tags", "Unique Tags", "Rolled-Up Tags",
                        "Avg Reads Per Unique Tag","Avg Reads Per Rolled-Up Tag",
                        "Fraction of Unique Tags that are Rolled-Up","Yield: Rolled-Up Tags"])+"\n")
rolledupyield=float(num_collapsed_tags)/( DNA_INPUT_NANOGRAMS/(3.59*0.001) ) 
counts = [int(REGION),num_obs_tags,num_uniq_tags,num_collapsed_tags,
          float(num_obs_tags)/max(1,num_uniq_tags), float(num_obs_tags)/max(1,num_collapsed_tags),
          float(num_collapsed_tags)/max(1,num_uniq_tags),rolledupyield]
outfile.write(tabprint(counts)+"\n")

########################################################################

# End timer
t1 = time.time()
td = (t1 - t0) / 60
print("done in %0.2f minutes" % td)

########################################################################
# Put standard out back...

sys.stdout = old_stdout
log_file.close()
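
cluster_rollup2 is defined in masq_helper_functions and is not shown here. As a rough illustration of the idea — absorbing each varietal tag into a more abundant tag within one mismatch — a greedy sketch follows; this is a simplified stand-in, not the pipeline's exact algorithm.

from collections import Counter

def rollup_tags(tag_counter, max_dist=1):
    # greedy illustration: visit tags from most to least abundant and
    # absorb each into an already-kept tag within max_dist mismatches
    def dist(a, b):
        if len(a) != len(b):
            return max(len(a), len(b))
        return sum(x != y for x, y in zip(a, b))
    rolled = Counter()
    for tag, count in tag_counter.most_common():
        parent = next((t for t in rolled if dist(t, tag) <= max_dist), None)
        rolled[parent if parent is not None else tag] += count
    return rolled

tags = Counter({"AACCGGTT": 90, "AACCGGTA": 3, "TTGGCCAA": 40})
print(rollup_tags(tags))  # Counter({'AACCGGTT': 93, 'TTGGCCAA': 40})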
import os
import sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from masq_helper_functions import tabprint

########################################################################

# Start timer
t0 = time.time()

########################################################################

# Redirect prints to log file
old_stdout = sys.stdout
log_file = open(str(snakemake.log),"w")
sys.stdout = log_file
sys.stderr = log_file

########################################################################

# Input is list of region specific report files which all have a header
# Output is combined report files with one header

input_file_list = snakemake.input.region_reports
output_file = snakemake.output.combined_report
outfile = open(output_file,"w")

counter=0
for f in input_file_list:
    counter +=1
    infile=open(f,"r")
    linecounter=1
    for line in infile:
        if linecounter==1:
            if counter==1: # first file header
                outfile.write(line)
        else:
            outfile.write(line)
        linecounter+=1
    infile.close()
outfile.close()


########################################################################

# End timer
t1 = time.time()
td = (t1 - t0) / 60
print("done in %0.2f minutes" % td)

########################################################################
# Put standard out back...

sys.stdout = old_stdout
log_file.close()
import os
import sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import dill as pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from masq_helper_functions import tabprint

########################################################################

# Start timer
t0 = time.time()

########################################################################

# Redirect prints to log file
old_stdout = sys.stdout
log_file = open(str(snakemake.log),"w")
sys.stdout = log_file
sys.stderr = log_file

########################################################################

# Input is all the pickle files with error tables
# Output is 2 table files - counts, and converted to fractions

input_file_list = snakemake.input.within_tag_error_pickles
output_file1 = snakemake.output.within_tag_table1
output_file2 = snakemake.output.within_tag_table2

outfile1 = open(output_file1,"w")
outfile2 = open(output_file2,"w")
########################################################################
NUCS = ["A", "C", "G", "T"]
REF_TRINUCS = [x+y+z for x in NUCS for y in NUCS for z in NUCS]
PAIRED_TRINUCS=[ (a , a[0]+x+a[2]) for a in REF_TRINUCS for x in NUCS if x!=a[1] ]
ERR_RANGE1 = np.arange(-0.1,1.0,0.1)

# Initialize table
WITHIN_TAG_ERRS_ALL = np.zeros(shape=(len(PAIRED_TRINUCS),len(ERR_RANGE1),2), dtype=int)

for f in input_file_list:
    # Add to current table
    WITHIN_TAG_ERRS_REGION = pickle.load(open(f,'rb'))
    WITHIN_TAG_ERRS_ALL = WITHIN_TAG_ERRS_ALL  +  WITHIN_TAG_ERRS_REGION

# Original table
for i,pt in enumerate(PAIRED_TRINUCS):
    outfile1.write(tabprint(['R1']+list(pt)+list(WITHIN_TAG_ERRS_ALL[i,:,0]))+'\n')
    outfile1.write(tabprint(['R2']+list(pt)+list(WITHIN_TAG_ERRS_ALL[i,:,1]))+'\n')

# Convert to fractions, skip the 0 error bin
for i,pt in enumerate(PAIRED_TRINUCS):
    fracR1=WITHIN_TAG_ERRS_ALL[i,1:,0]/WITHIN_TAG_ERRS_ALL[i,1:,0].sum(keepdims=True)
    fracR2=WITHIN_TAG_ERRS_ALL[i,1:,1]/WITHIN_TAG_ERRS_ALL[i,1:,1].sum(keepdims=True)
    outfile2.write(tabprint(['R1']+list(pt)+list(fracR1))+'\n')
    outfile2.write(tabprint(['R2']+list(pt)+list(fracR2))+'\n')


########################################################################

# Close output files
outfile1.close()
outfile2.close()

########################################################################

# End timer
t1 = time.time()
td = (t1 - t0) / 60
print("done in %0.2f minutes" % td)

########################################################################
# Put standard out back...

sys.stdout = old_stdout
log_file.close()
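
Note that a trinucleotide context with no observations in the nonzero error bins gives a zero denominator above, so its fractions come out as NaN. A guarded version of the normalization, as a sketch:

import numpy as np

def safe_fractions(counts):
    # normalize a count vector to fractions; all-zero input
    # returns zeros instead of NaN
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    return counts / total if total > 0 else np.zeros_like(counts)

print(safe_fractions([3, 1, 0, 0]))  # [0.75 0.25 0.   0.  ]
print(safe_fractions([0, 0, 0, 0]))  # [0. 0. 0. 0.]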
import os
import sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm


from masq_helper_functions import tabprint

########################################################################

# Start timer
t0 = time.time()

########################################################################

# Redirect prints to log file
old_stdout = sys.stdout
log_file = open(str(snakemake.log),"w")
sys.stdout = log_file
sys.stderr = log_file

########################################################################

# Extract only the variant bases to a new report
# Could add additional calculations here if necessary

input_file = snakemake.input.combined_report
output_file = snakemake.output.variant_report
outfile = open(output_file,"w")

infile=open(input_file,"r")
linecounter=1
for x in infile:
    line = x.strip().split()
    if linecounter==1: # header
        outfile.write(x)
    elif int(line[0]) == 2:
        outfile.write(x)
    linecounter+=1
infile.close()
outfile.close()


########################################################################

# End timer
t1 = time.time()
td = (t1 - t0) / 60
print("done in %0.2f minutes" % td)

########################################################################
# Put standard out back...

sys.stdout = old_stdout
log_file.close()
import os
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from masq_helper_functions import tabprint
from masq_helper_functions import setup_logger

########################################################################
# Start timer
t0 = time.time()
# Setup log file
log = setup_logger(snakemake.log,'check_loci')
log.info('Starting process')

########################################################################

# Input is list of region specific report files which all have a header
# Output is combined report files with one header
log.info('Combining reports')
header_all = []
data_all = []

with open(snakemake.input.input_snv_table,'r') as f:
    header = f.readline().strip("\n").split("\t")
    header_all.extend(header)
    for line in f:
        x = line.strip("\n").split("\t")
        data_all.append(x)

log.debug(data_all)

with open(snakemake.input.report_primers,'r') as f:
    c=0
    header1=f.readline()
    summarydata=f.readline()
    f.readline()
    header=f.readline().strip("\n").split("\t")[1:]
    header_all.extend(header)
    for line in f:
        x = line.strip("\n").split("\t")[1:]
        if not(line.startswith("NONE")):
            data_all[c].extend(x)
            c+=1

with open(snakemake.input.report_rollup,'r') as f:
    c=0
    header = f.readline().strip("\n").split("\t")[1:]
    header_all.extend(header)
    for line in f:
        x = line.strip("\n").split("\t")[1:]
        data_all[c].extend(x)
        c+=1

with open(snakemake.input.report_alignment,'r') as f:
    header = f.readline().strip("\n").split("\t")[2:]
    header_all.extend(header)
    align_data=dict()
    regct=0
    for line in f:
        x = line.strip("\n").split("\t")
        reg=int(x[0])
        if reg in align_data:
            align_data[reg] = [int(x[2]),int(x[3]),float(x[4]),int(x[5]),int(x[6]) + align_data[reg][4] ,float(x[7]) + align_data[reg][5]]
        else:
            align_data[reg] = [int(x[2]),int(x[3]),float(x[4]),int(x[5]),int(x[6]),float(x[7])]
            regct+=1
    c=0
    for reg in range(regct):
        y=align_data[reg]
        data_all[c].extend(y)
        c+=1


BASES = ["A", "C", "G", "T"]
BASE2INT = dict([x[::-1] for x in enumerate(BASES)])
with open(snakemake.input.report_variants,'r') as f:
    var_data = dict()
    header = f.readline().strip("\n").split("\t")
    header_all.extend(['strand','read_index','read_pos','template_pos', 'expected_read_base','expected_template_base','variant_read_base','variant_template_base','A1','C1','G1','T1','A2','C2','G2','T2','VarAF'])
    regct=0
    for line in f:
        x = line.strip("\n").split("\t")

        locus = x[1]
        ref_index = x[2]
        strand = x[3]
        read = x[4]
        poss = x[5:7]
        bases = x[7:11]
        onecounts = list(map(int,x[11:15]))
        twocounts = list(map(int,x[15:19]))
        if (sum(twocounts)>0) and (bases[2] in BASES):
            altbase = BASE2INT[bases[2]]
            varaf = float(twocounts[altbase])/sum(twocounts)
        else:
            varaf = 0

        if locus in var_data:
            if (var_data[locus][0] != ref_index) and (var_data[locus][2] == read): # combine counts
                prevcounts1 = var_data[locus][9:13]
                prevcounts2 = var_data[locus][13:17]
                var_data[locus] = [ref_index,strand,read]
                var_data[locus].extend(poss)
                var_data[locus].extend(bases)
                newcounts1 = [a+b for a,b in zip(onecounts,prevcounts1)]
                newcounts2 = [a+b for a,b in zip(twocounts,prevcounts2)]
                if (sum(newcounts2)>0) and (bases[2] in BASES):
                    altbase = BASE2INT[bases[2]]
                    newvaraf = float(newcounts2[altbase])/sum(newcounts2)
                else:
                    newvaraf = 0
                var_data[locus].extend(newcounts1)
                var_data[locus].extend(newcounts2)
                var_data[locus].append(newvaraf)
            # skip other read entry
        else: # first entry
            var_data[locus] = [ref_index,strand,read]
            var_data[locus].extend(poss)
            var_data[locus].extend(bases)
            var_data[locus].extend(onecounts)
            var_data[locus].extend(twocounts)
            var_data[locus].append(varaf)
            regct+=1

    c=0
    for reg in range(regct):
        if str(reg) in var_data: # if reads were too short to cover variant this will fail, so check first
            y=var_data[str(reg)][1:]
            data_all[c].extend(y)
        else: 
            log.info("Target variant missing from per base files %s" % (str(reg))) 
        c+=1


########################################################################
# Write to summary file
output_file = snakemake.output.combined_report
outfile = open(output_file,"w")

outfile.write(tabprint(header_all)+"\n")
for data in data_all:
    outfile.write(tabprint(data)+"\n")

outfile.close()

########################################################################
# End timer
t1 = time.time()
td = (t1 - t0) / 60
log.info("Done in %0.2f minutes" % td)
suppressMessages(library(ggplot2))
suppressMessages(library(cowplot))
library(scales)

# Collect list of report files to plot 
args = commandArgs(trailingOnly=TRUE)
outplot=args[1]
outtxt=args[2]
badlocifile=args[3]
infiles=args[-c(1,2,3)]

# Number of files
N=length(args)-3

# Load files and get sample name
load_table = function(tablefile){
  D=read.table(tablefile,sep="\t",header=T,stringsAsFactors=F)
  D$gt2templates=apply(D[,c("A2","C2","G2","T2")],1,sum)
  return(D)
}
get_samplename = function(tablefile){
  Dname=sub('\\.final_report.txt$', '', basename(tablefile))
  return(Dname)
}

# Function to get plots for 1 file
qc_plots_1file = function(D,Dname,qc_fail){
  qcindex = which(D$loc %in% qc_fail)
  # Number of Templates
  p1 = ggplot(data=D)+
    geom_bar(aes(x=factor(loc),y=Rolled.Up.Tags+1),stat="identity",fill="lightpink") +
    xlab("Locus") + 
    ylab("Total Templates") + 
    scale_y_continuous(trans = "log10", breaks=10^(0:7),labels=math_format(expr=10^.x)(0:7)) +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust=0.5, size=8),
          axis.text.y = element_text(size=8),
          axis.title.x = element_text(size=10),
          axis.title.y = element_text(size=10),
          plot.title = element_text(size=10)) +
    ggtitle(paste0(Dname,"\n",
        sprintf("%0.2f",mean(D$Rolled.Up.Tags))))+
 #       "\t",
#        sprintf("%0.2f",mean(D$Yield:.Rolled-Up.Tags))))+ 
    geom_hline(yintercept=0.10*median(D$Rolled.Up.Tags),lty=3,col='black')
  p2 = ggplot(data=D)+
    geom_bar(aes(x=factor(loc),y=gt2templates),stat="identity",fill="gold2") +
    xlab("Locus") + 
    ylab("Aligned >=2 RPT Templates") + 
    scale_y_continuous(trans = "log10", breaks=10^(0:7),labels=math_format(expr=10^.x)(0:7)) +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust=0.5, size=8),
          axis.text.y = element_text(size=8),
          axis.title.x = element_text(size=10),
          axis.title.y = element_text(size=10),
          plot.title = element_text(size=10)) +
    ggtitle(paste0(Dname,"\n",sprintf("%0.2f",mean(D$gt2templates))))+
    geom_hline(yintercept=0.10*median(D$gt2templates),lty=3,col='red')
  # Reads Per Template
  p3 = ggplot(data=D)+ 
    geom_bar(aes(x=factor(loc),y=Avg.Reads.Per.Rolled.Up.Tag),stat="identity",fill="green") +
    xlab("Locus") + 
    ylab("Average Reads Per Tags") + 
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust=0.5, size=8),
          axis.text.y = element_text(size=8),
          axis.title.x = element_text(size=10),
          axis.title.y = element_text(size=10),
          plot.title = element_text(size=10)) +
    ggtitle(paste0(Dname,"\n",sprintf("%0.2f",mean(D$Avg.Reads.Per.Rolled.Up.Tag))))+
    geom_hline(yintercept=5,lty=3,col='black')
  # Primer Alignment Rate
  p4 = ggplot(data=D)+ 
    geom_bar(aes(x=factor(loc),y=Percent.of.Assigned.that.Pass),stat="identity",fill="orange") +
    xlab("Locus") + 
    ylab("Primer Alignment Rate") + 
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust=0.5, size=8),
          axis.text.y = element_text(size=8),
          axis.title.x = element_text(size=10),
          axis.title.y = element_text(size=10),
          plot.title = element_text(size=10)) +
    ggtitle(paste0(Dname,"\n",sprintf("%0.2f",mean(D$Percent.of.Assigned.that.Pass))))+
    ylim(c(0,1.05))+
    geom_hline(yintercept=0.75,lty=3,col='red')
  # Sequence Alignment Rate
  p5 = ggplot(data=D)+ 
    geom_bar(aes(x=factor(loc),y=Fraction.Aligned.Reads),stat="identity",fill="lightblue") +
    xlab("Locus") + 
    ylab("Sequence Alignment Rate") + 
    theme(axis.text.x = element_text(angle = 90, hjust=1, vjust=0.5, size=6),
          axis.text.y = element_text(size=8),
          axis.title.x = element_text(size=10),
          axis.title.y = element_text(size=10),
          plot.title = element_text(size=10)) +
    ggtitle(paste0(Dname,"\n",sprintf("%0.2f",mean(D$Fraction.Aligned.Reads))))+
    ylim(c(0,1.05))+
    geom_hline(yintercept=0.75,lty=3,col='red')
  # Uncorrected VAF
  p6 = ggplot(data=D)+ 
    geom_bar(aes(x=factor(loc),y=VarAF),stat="identity",fill="darkorchid3") +
    xlab("Locus") + 
    ylab("Uncorrected Variant AF") + 
    theme(axis.text.x = element_text(angle = 90, hjust=1, vjust=0.5, size=6),
          axis.text.y = element_text(size=8),
          axis.title.x = element_text(size=10),
          axis.title.y = element_text(size=10),
          plot.title = element_text(size=10)) +
    ggtitle(paste0(Dname,"\n",sprintf("%0.2f",mean(D$VarAF))))

  # IF THERE ARE QC FAIL LOCI, ANNOTATE THEM
    if(length(qcindex)>0){
        p1 = p1 + annotate('point', x=qcindex, y=1.05*max(D$Rolled.Up.Tags), size=5, color='red',shape='*')
        p2 = p2 + annotate("point", x=qcindex, y = 1.05*max(D$gt2templates), size=5, color='red',shape='*')
        p3 = p3 + annotate("point", x=qcindex, y = 1.05*max(D$Avg.Reads.Per.Rolled.Up.Tag), size=5, color='red',shape='*')
        p4 = p4 + annotate("point", x=qcindex, y = 1.02, size=5, color='red',shape='*')
        p5 = p5 + annotate("point", x=qcindex, y = 1.02, size=5, color='red',shape='*')
        p6 = p6 + annotate("point", x=qcindex, y = 1.05*max(D$VarAF), size=5, color='red',shape='*')

    }

  # Save all plots 
  onesamp_plots=list(p1,p2,p3,p4,p5,p6)
  return(onesamp_plots)
}

# QC filters...
qc_filters = function(alltables){
  locusnames=alltables[[1]]$loc

  get_primer_rate=function(D){
    return(D$Percent.of.Assigned.that.Pass)
  }
  combinedcols=do.call(cbind,lapply(alltables,get_primer_rate))
  max_primer_rate=apply(combinedcols,1,max)

  get_align_rate=function(D){
    return(D$Fraction.Aligned.Reads)
  }
  combinedcols=do.call(cbind,lapply(alltables,get_align_rate))
  max_align_rate=apply(combinedcols,1,max)
  min_align_rate=apply(combinedcols,1,min)

  get_total_templates=function(D){
    return(D$gt2templates)
  }
  combinedcols=do.call(cbind,lapply(alltables,get_total_templates))
  medcounts=apply(combinedcols,2,median)
  medcountmat=matrix(medcounts,nrow=nrow(combinedcols),ncol=length(medcounts),byrow=TRUE)
  n_okcount=apply(combinedcols>(medcountmat*0.1),1,sum)

  # return(locusnames[max_primer_rate<0.75 | max_align_rate<0.75 ])
  return(locusnames[max_primer_rate<0.75 | min_align_rate<0.5 | max_align_rate<0.75 | n_okcount==0])

}




# Load all files
alltables = lapply(infiles, load_table)
allnames = unlist(lapply(infiles, get_samplename))

# QC fail loci
qc_fail=qc_filters(alltables)

# Add bad_loci
bl = read.table(badlocifile,header=F,col.names=c("Locus"))
if (nrow(bl)>0){
    qc_fail = unique(c(qc_fail,bl$Locus))
}

# Write to file
write.table(qc_fail,file=outtxt,quote=F,sep="\t",col.names=F,row.names=F)

# Plot all samples, with annotated QC fails
allplots = mapply(qc_plots_1file, alltables, allnames, MoreArgs=list(qc_fail=qc_fail))

# Plot grid
# Rows - samples
# Columns - metrics
grDevices::pdf(NULL)
p=plot_grid(plotlist=allplots,nrow=N,ncol=6) 
# generates Rplot.pdf, if null device is not opened (in some version of ggplot / R / cowplot)
save_plot(outplot, p, base_height = 3*N, base_width = 24)
grDevices::dev.off()
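
For reference, the QC rules in qc_filters flag a locus when, across samples, the best primer alignment rate is below 0.75, the best sequence alignment rate is below 0.75, the worst sequence alignment rate is below 0.5, or no sample reaches 10% of its median >=2-read template count. The same rules restated in Python, as a sketch with hypothetical names:

import numpy as np

def qc_fail_mask(primer_rate, align_rate, templates):
    # each argument: array of shape (n_loci, n_samples)
    primer_rate = np.asarray(primer_rate)
    align_rate = np.asarray(align_rate)
    templates = np.asarray(templates)
    low_primer = primer_rate.max(axis=1) < 0.75
    low_align = (align_rate.max(axis=1) < 0.75) | (align_rate.min(axis=1) < 0.5)
    # per-sample medians across loci, matching the R apply(..., 2, median)
    ok_count = templates > 0.1 * np.median(templates, axis=0)
    low_count = ok_count.sum(axis=1) == 0
    return low_primer | low_align | low_count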
import os
import sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from masq_helper_functions import tabprint

########################################################################

# Start timer
t0 = time.time()

########################################################################

# Redirect prints to log file
old_stdout = sys.stdout
log_file = open(str(snakemake.log),"w")
sys.stdout = log_file
sys.stderr = log_file

########################################################################
input_file = snakemake.input.combined_report
output_file = snakemake.output.plot1
sample = snakemake.params.sample
########################################################################

# Process report file to get counts
num_obs_tags = []
num_uniq_tags = []
num_collapsed_tags = []

infile = open(input_file,"r")
linecount=1
regioncount = 0
for x in infile:
    line = x.strip().split()
    if linecount>1:
        region = line[0]
        regioncount+=1
        num_obs_tags.append(int(line[1]))
        num_uniq_tags.append(int(line[2]))
        num_collapsed_tags.append(int(line[3]))
    linecount+=1
infile.close()
print(num_obs_tags)
print(num_uniq_tags)
print(num_collapsed_tags)

########################################################################

# Graph for each region, total tags, unique tags, collapsed tags
N = regioncount
fig = plt.figure(figsize=(50,10))
fig.suptitle(sample+"\n"+
             "Results of Tag Rollup - 1 error allowed", fontsize=16)
x=range(N)
width=0.25
ax = plt.subplot(111)
ax.bar(x,num_obs_tags,width,alpha=0.5,color='purple',label='Total')
ax.bar([p + width for p in x],num_uniq_tags,width,alpha=0.5,color='green',label='Unique')
ax.bar([p + width*2 for p in x],num_collapsed_tags,width,alpha=0.5,color='blue',label='Collapsed')
ax.legend(['Total','Unique','Collapsed'], loc='upper left')
ax.ticklabel_format(style='plain')
ax.get_yaxis().set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))

plt.xticks([p + width for p in x],range(N))
plt.xlim([min(x)-width,max(x)+width*4])
plt.xlabel("Region",fontsize=14)
plt.ylabel("Tag Counts",fontsize=14)
plt.savefig(output_file, dpi=200, facecolor='w', edgecolor='w',
            papertype=None, format=None,
            transparent=False)
plt.close()


########################################################################

# End timer
t1 = time.time()
td = (t1 - t0) / 60
print("done in %0.2f minutes" % td)

########################################################################
# Put standard out back...

sys.stdout = old_stdout
log_file.close()
import sys
import pickle
from primer_design_functions import tabprint
from primer_design_functions import reverseComplement

primer_table=snakemake.input.oldtable

# sequence dictionary
seq_pickle=snakemake.params.ref_genome
ref_dic = pickle.load(open(seq_pickle, 'rb'))


c=0
loci=0
with open(snakemake.output.newtable,'w') as fout:
    with open(primer_table,'r') as f:
        for line in f:
            if c==0:
                header=line.strip().split("\t")
                fout.write(tabprint(["loc","chr","posi","specific-primer-1","specific-primer-2","trimmed-target-seq","target_locs","ref-alt_allele"])+"\n")
            else:
                X=line.strip().split("\t")

                #chrom=X[header.index("chrom")][3:]
                chrom=X[header.index("chrom")]
                pos=int(X[header.index("pos")])
                primer1=X[header.index("downstream_primerseq")]
                primer2=X[header.index("cutadj_primerseq")]
                strand=X[header.index("strand")]
                refbase=X[header.index("ref_trinuc")][1]
                altbase=X[header.index("alt_trinuc")][1]
                ref_alt=refbase+"_"+altbase

                # get target sequence coordinates based on primer coordinates
                # pull sequence from dictionary
                # reverse complement depending on strand
                # target position is relative to target seq
                primercoords1=X[header.index("cutadj_primer_coordinates")].split(':')[1].split('-')
                primercoords1.extend(X[header.index("downstream_primer_coordinates")].split(':')[1].split('-'))
                primercoords = [int(x) for x in primercoords1]
                primercoords.sort()

                target_start = primercoords[1]
                target_end = primercoords[2]-1

                target_seq_ref=ref_dic[chrom][target_start:target_end]

                if strand=="bottom":
                    target_seq_stranded = reverseComplement(target_seq_ref)
                    relative_pos = target_end - pos
                elif strand=="top":
                    target_seq_stranded = target_seq_ref
                    relative_pos = pos - target_start - 1

                fout.write(tabprint([loci,chrom,pos,primer1,primer2,target_seq_stranded,relative_pos,ref_alt])+"\n")

                loci=loci+1
            c=c+1
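
The relative variant position depends on the reported strand: on the top strand it is the offset from the target start, and on the bottom strand it is the offset from the target end, since the sequence is reverse complemented. The "- 1" implies pos is 1-based against 0-based slice coordinates. A minimal check of that arithmetic, under those assumptions (function name hypothetical):

def relative_position(pos, target_start, target_end, strand):
    # pos: 1-based genomic position; target_start/target_end: 0-based
    # half-open slice coordinates of the trimmed target sequence
    if strand == "top":
        return pos - target_start - 1
    if strand == "bottom":
        # reverse complement flips the coordinate axis
        return target_end - pos
    raise ValueError("strand must be 'top' or 'bottom'")

# slice [100:110) covers 1-based positions 101..110
print(relative_position(101, 100, 110, "top"))     # 0
print(relative_position(101, 100, 110, "bottom"))  # 9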
import os
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from masq_helper_functions import tabprint
from masq_helper_functions import setup_logger
########################################################################
# Start timer
t0 = time.time()
# Setup log file
log = setup_logger(snakemake.log,'filter_loci')
log.info('Starting process')

########################################################################

# Input is combined base report file 
# Output is same file, with QC filtered loci removed
log.info('Filtering report')

qc_fail_loci=[]
with open(snakemake.input.qcfail,'r') as f:
    for line in f:
        qc_fail_loci.append(line.strip())

c=0
with open(snakemake.input.base_report,'r') as f:
    with open(snakemake.output.filtered_base_report,'w') as fout:
        for line in f:
            if (c==0):
                fout.write(line)
            else:
                x=line.split()
                if (x[1] not in qc_fail_loci):
                    fout.write(line)
            c=c+1

########################################################################
# End timer
t1 = time.time()
td = (t1 - t0) / 60
log.info("Done in %0.2f minutes" % td)
import os,sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import editdistance
import re
import logging
from masq_helper_functions import tabprint, hamming, double_counter
from masq_helper_functions import readWalker, convert_quality_score
from masq_helper_functions import hamming_withNs, check_tag_structure
from masq_helper_functions import setup_logger, load_snv_table

########################################################################
# Start timer
t0 = time.time()
# Setup log file
log = setup_logger(snakemake.log,'sort_data_by_loci')
log.info('Starting process')

########################################################################
# INPUT FILES AND PARAMETERS
log.info('Getting input files and parameters from snakemake object')
# Input FASTQs
fastq1 = snakemake.input.fastq1
fastq2 = snakemake.input.fastq2

# Input SNV table
SNV_table = open(snakemake.input.SNV_table,'r')

# Sequence and hamming parameters
MAX_HAMMING_UP2 = snakemake.params.UP2_hamming
MAX_HAMMING_SS = snakemake.params.SS_sum_hamming
MIN_LEN = snakemake.params.min_len
TRIM_LEN = snakemake.params.trim_len
TAG_TEMPLATE = snakemake.params.tag
UP2 = snakemake.params.UP2
PROTOCOL = snakemake.params.protocol
USE_EDIT_DISTANCE = snakemake.params.use_edit_distance

## compute some lengths
UPL   = len(UP2)
TAGL  = len(TAG_TEMPLATE)
UPVTL = UPL + TAGL

# Masking low quality bases
MASK_LOWQUAL = snakemake.params.mask_lowqual_bases
QUAL_CUTOFF = snakemake.params.qual_cutoff
MAX_N_RATIO = snakemake.params.max_N_ratio

# Quick run parameters
QUICK_RUN = snakemake.params.quick_run
QUICK_RUN_READS = snakemake.params.quick_run_reads

########################################################################
# OUTPUT FILES
log.info('Getting output files from snakemake object')

# Main report
outfile = open(snakemake.output.counter_report, 'w')

# Troubleshooting reports
# Unmatched UP2
up2file = open(snakemake.output.up2_unmatched_report, 'w')
# Unmatched SS1 and SS2
ss1ss2file = open(snakemake.output.ss1ss2_unmatched_report, 'w')
# Good and bad tag structure
goodtagfile = open(snakemake.output.goodtag_report, 'w')
badtagfile = open(snakemake.output.badtag_report, 'w')

# Plots
# Hamming distance of SS1 and SS2
hamming_SS_figure = snakemake.output.ss_hamming_plot

# Seq data pickles
# Lists of files (one for each region)
vt_counters_filelist = snakemake.output.vt_counters
vt_seq_counters_filelist = snakemake.output.vt_seq_counters
flip_counters_filelist = snakemake.output.flip_counters



########################################################################
# Load the input SNV table
log.info('Loading SNV table')
snv_info = load_snv_table(SNV_table)
for key,value in snv_info.items():
    log.debug(key)
    log.debug(tabprint(value))
log.info('Done loading SNV table')

########################################################################
# Get sequence specific primer info from the SNV_table
log.info('Getting sequence specific primer info from table')
SP1 = snv_info['specific-primer-1']
SP2 = snv_info['specific-primer-2']
## keep track of the lengths of the primers
## and the longest seen of each
SP1_len = [len(x) for x in SP1]
SP2_len = [len(x) for x in SP2]
MAX_SP1 = max(SP1_len)
MAX_SP2 = max(SP2_len)

########################################################################
# Setup counters for processing reads
log.info('Initializing counters')
# Single counters
counter      = 0
good_counter = 0
up2_counter  = 0
flip_counter = 0
oklength_counter = 0

# Counters by region
NREG = len(SP1)
good_counter_list = [0 for _ in range(NREG+1)] # extra one for garbage
assigned_counter_list = [0 for _ in range(NREG+1)] # extra one for garbage

# Counter for SS1 and SS2 hamming distances
hamming_counter = Counter()

# Counters with compressed sequence information
## keep a counters for each primer pair
## the counters are:
## how many reads with each varietal tag
## for each varietal tag, a Counter for the associated read pairs
## for each varietal tag and read-pair sequence,
##    how often it was not-flipped (so UP2 is on read2) and flipped.
# Lists of counters...
vt_counters     = [Counter() for _ in SP1]
vt_seq_counters = [defaultdict(Counter) for _ in SP1]
flip_counters   = [defaultdict(double_counter) for _ in SP1]

########################################################################
# Build read walker from 2 FASTQ files
log.info('Building read walker for 2 fastqs')
rwalker = readWalker(fastq1, fastq2)

########################################################################
# Process reads
# Check for:
# 1 - OK length (then trim)
# 2 - correct UP2 sequence (then remove it and tag)
# 3 - correct SS1 and SS2 sequences (then group reads and remove primers)
# for regular PCR data, skip the UP2 step, check both orientations for SS1/SS2
log.info('Begin processing reads')
counter = 0

log.debug("PROTOCOL: %s" % PROTOCOL)
if not(PROTOCOL=="standard PCR"):
    log.info('Assuming UP2 and VT are present (Not Standard PCR)')

for read_pair in rwalker:
    ## record if we flip it around so UP2 is on read2 (or SP1 is R1, SP2 is R2)
    flipped = False
    counter += 1

    # For a test run with only a small number of reads processed
    if QUICK_RUN and counter > QUICK_RUN_READS:
        log.warning('Stopped processing reads - reached quick run cutoff')
        break

    # Write update to log file
    if counter % 10000 == 0:
        for r in range(NREG+1):
            log.debug(tabprint(
               [r,
                assigned_counter_list[r],
                good_counter_list[r],
                assigned_counter_list[r] / np.float64(counter),
                good_counter_list[r] / np.float64(counter),
                good_counter_list[r] / np.float64(assigned_counter_list[r]) ] ))
        log.debug(' ')
        log.debug(tabprint([counter, oklength_counter, up2_counter, good_counter, oklength_counter / np.float64(counter), up2_counter / np.float64(counter), good_counter / np.float64(counter)]))

    # Get sequences from fastq
    seq1, seq2 = read_pair[1], read_pair[3]
    qual1, qual2 = read_pair[2], read_pair[4]

    # Check length
    ## if too short, skip this read pair
    if len(seq1) < MIN_LEN or len(seq2) < MIN_LEN:
        log.debug('Skipping read for length of read pair. Sequence1: %s' % seq1)
        log.debug('Skipping read for length of read pair. Sequence2: %s' % seq2)
        continue
    oklength_counter += 1

    ## otherwise, trim to length
    # Works even if TRIM_LEN is greater than length of sequence
    seq1 = seq1[:TRIM_LEN]
    seq2 = seq2[:TRIM_LEN]
    qual1 = qual1[:TRIM_LEN]
    qual2 = qual2[:TRIM_LEN]

    # check for UP2 if MASQ (not for standard PCR)
    if not(PROTOCOL=="standard PCR"):
        # log.info('Assuming UP2 and VT are present (Not Standard PCR)')
        ## compute hamming distance for UP2 to each read
        # check beginning of each read for match to UP2
        d1 = hamming(seq1, UP2, use_edit_distance=USE_EDIT_DISTANCE)
        d2 = hamming(seq2, UP2, use_edit_distance=USE_EDIT_DISTANCE)

        ## if UP2 matches read1, flip it
        if d1 < d2:
            seq1, seq2 = seq2, seq1
            qual1, qual2 = qual2, qual1
            flipped = True
            flip_counter += 1

        ## pull the varietal tag and the rest of seq2 (after trimming tag and UP2)
        vt         = seq2[UPL:UPVTL]
        seq2_rest  = seq2[UPVTL:]
        qual2_rest = qual2[UPVTL:]
    else: # STANDARD PCR PROTOCOL
        # fake the up2 distances so they pass check
        d1=0;  d2=100
        # fake the trimming of seq2
        seq2_rest = seq2; qual2_rest = qual2
        # fake the varietal tag to be unique for each read
        vt = str(counter)

    ## if the match is close enough
    # if UP2 is found, move on with this read pair
    # if protocol has no UP2, move on
    if (min(d1, d2) <= MAX_HAMMING_UP2):
        # keep track of how many correct UP2s are found
        up2_counter += 1

        ## check hamming distance of seq1 to list of SP1 and seq2 to list of SP2
        ham_array1 = [hamming(seq1     , sp, use_edit_distance=USE_EDIT_DISTANCE) for sp in SP1]
        ham_array2 = [hamming(seq2_rest, sp, use_edit_distance=USE_EDIT_DISTANCE) for sp in SP2]
        ## find the "best" index for each
        ham_ind1   = np.argmin(ham_array1)
        ham_ind2   = np.argmin(ham_array2)
        ## and store the distance
        ham_dist1  = ham_array1[ham_ind1]
        ham_dist2  = ham_array2[ham_ind2]

        if PROTOCOL!="standard PCR": # only check one orientation
            log.debug("Regular QSD: only check one orientation")
            ## pick the index with the smaller hamming
            best_ind   = ham_ind1 if ham_dist1 < ham_dist2 else ham_ind2

            ## store the value of the match into the hamming counter array
            ## for showing size of off targetness
            hamming_counter[(ham_array1[best_ind], ham_array2[best_ind])] += 1

            # to continue with this read - best matches must agree and have low distances
            hamming_sum = ham_dist1 + ham_dist2
            hamming_match = ham_ind1 == ham_ind2
            hamming_min = min(ham_dist1,ham_dist2)

        # For standard PCR, also check the flipped version
        else:
            log.debug("Standard PCR: checking flipped orientation for match")
            ham_array3 = [hamming(seq1     , sp, use_edit_distance=USE_EDIT_DISTANCE) for sp in SP2]
            ham_array4 = [hamming(seq2_rest, sp, use_edit_distance=USE_EDIT_DISTANCE) for sp in SP1]
            ## find the "best" index for each
            ham_ind3   = np.argmin(ham_array3)
            ham_ind4   = np.argmin(ham_array4)
            ## and store the distance
            ham_dist3  = ham_array3[ham_ind3]
            ham_dist4  = ham_array4[ham_ind4]

            ## pick the index with the smaller hamming
            min_dist=100
            for i,ind,dist in zip([1,2,3,4],[ham_ind1,ham_ind2,ham_ind3,ham_ind4],[ham_dist1,ham_dist2,ham_dist3,ham_dist4]):
                if dist<min_dist:
                    min_dist=dist
                    best_ind = ind
                    best_orientation = i
            if best_orientation<3: # original orientation is good
                log.debug("Original Orientation")    
                hamming_counter[(ham_array1[best_ind], ham_array2[best_ind])] += 1
                hamming_sum = ham_dist1 + ham_dist2
                hamming_match = ham_ind1 == ham_ind2
                hamming_min = min(ham_dist1,ham_dist2)
            else: # flip the reads and use the second 2 hamming comparisons
                log.debug("Flipped Orientation")
                hamming_counter[(ham_array3[best_ind], ham_array4[best_ind])] += 1
                hamming_sum = ham_dist3 + ham_dist4
                hamming_match = ham_ind3 == ham_ind4
                hamming_min = min(ham_dist3,ham_dist4)

                seq1, seq2_rest = seq2_rest, seq1
                qual1, qual2_rest = qual2_rest, qual1
                flipped = True
                flip_counter += 1
                log.debug("Flip counter: %d" % flip_counter)


        ## if the sum of the distances is less than MAX Hamming and the two reads agree on the best match
        # If SP1 and SP2 look ok, move on with this read pair
        if hamming_sum <= MAX_HAMMING_SS and hamming_match:
            ## get the rest of the reads and store the info into the counters
            # did this read pair pass all the filters
            good_counter += 1
            good_counter_list[best_ind] +=1
            # remove SP1 and SP2, take rest of read
            ror1 = seq1     [SP1_len[best_ind]:]
            ror2 = seq2_rest[SP2_len[best_ind]:]

            # If mask low qual parameter setting is true, replace low quality bases with N's
            if MASK_LOWQUAL:
                qual1_target = qual1     [SP1_len[best_ind]:]
                qual2_target = qual2_rest[SP2_len[best_ind]:]
                qual1_converted = convert_quality_score(qual1_target)
                qual2_converted = convert_quality_score(qual2_target)
                ror1 = ''.join([ror1[x] if (qual1_converted[x]>QUAL_CUTOFF) else ('N') for x in range(len(ror1))])
                ror2 = ''.join([ror2[x] if (qual2_converted[x]>QUAL_CUTOFF) else ('N') for x in range(len(ror2))])

            # counter for each varietal tag, per locus
            vt_counters[best_ind][vt] += 1
            # store the sequences for each varietal tag, per locus
            vt_seq_counters[best_ind][vt][(ror1, ror2)] += 1
            # store flip status for each vt/r1/r2, per locus
            flip_counters[best_ind][(vt, ror1, ror2)][int(flipped)] += 1
            # this is a good read pair -  record if its tag structure is correct in "good tag" file
            goodtagfile.write(tabprint([vt,check_tag_structure(vt,TAG_TEMPLATE)])+"\n")
        else:
            if hamming_min<4: # ss1 and ss2 did not pass but came close
                # Save the info for these failed reads to file
                read_ss1= seq1[:len(SP1[best_ind])]
                read_ss2= seq2_rest[:len(SP2[best_ind])]
                read_ss1_ham = hamming(read_ss1,SP1[best_ind])
                read_ss2_ham = hamming(read_ss2,SP2[best_ind])
                read_ss1_edit = editdistance.eval(read_ss1,SP1[best_ind])
                read_ss2_edit = editdistance.eval(read_ss2,SP2[best_ind])
                ror2 = seq2_rest[SP2_len[best_ind]:]
                ror1 = seq1[SP1_len[best_ind]:]
                ss1ss2file.write(tabprint([best_ind,read_ss1,read_ss1_ham,read_ss1_edit,read_ss2,read_ss2_ham,read_ss2_edit,ror1,ror2,check_tag_structure(vt,TAG_TEMPLATE)])+"\n")
                # this is a bad read pair - record if its tag structure is correct in "bad tag" file
                badtagfile.write(tabprint([vt,check_tag_structure(vt,TAG_TEMPLATE)])+"\n")
            else: # not even close to any SS's by hamming distance...
                ss1ss2file.write(tabprint([None,seq1[:20],None,None,seq2_rest[:20],None,None,seq1[20:],seq2_rest[20:],check_tag_structure(vt,TAG_TEMPLATE)])+"\n")
                # this is a bad read pair - record if its tag structure is correct in "bad tag" file
                badtagfile.write(tabprint([vt,check_tag_structure(vt,TAG_TEMPLATE)])+"\n")

        # for reporting how many are assigned to each value
        if hamming_min<4: # hard coded cutoff for garbage assignments
            assigned_counter_list[best_ind] +=1
        else:
            assigned_counter_list[-1] += 1 # garbage bin
    else:  # UP2 did not match well on either side
        s1, s2 = read_pair[1], read_pair[3]
        up2_seq1 = s1[:UPL]
        up2_seq2 = s2[:UPL]
        dedit1 = editdistance.eval(up2_seq1,UP2)
        dedit2 = editdistance.eval(up2_seq2,UP2)
        up2file.write(tabprint([up2_seq1,d1,dedit1,up2_seq2,d2,dedit2])+"\n")

log.info('Done processing reads')

########################################################################

# Write to output files - counter report
log.info('Writing read assignment report')

header=["Read Pairs Processed", "Pass Length Filter", "Pass Length+UP2 Filters","Pass All Filters", "Flipped Read Pair", "Fraction Length Good", "Fraction Length+UP2 Good", "Fraction All Good",  "Fraction Flipped"]
counters=[counter, oklength_counter, up2_counter, good_counter, flip_counter, oklength_counter / np.float64(counter), up2_counter / np.float64(counter), good_counter / np.float64(counter), flip_counter/ np.float64(counter)]
outfile.write(tabprint(header) + "\n")
outfile.write(tabprint(counters) + "\n\n")
outfile.write(tabprint(["Region","Assigned Reads","Pass Reads","Percent Assigned", "Percent Pass", "Percent of Assigned that Pass"])+"\n")
for r in range(NREG+1):
    region = r if r<NREG else "NONE"
    outfile.write(tabprint(
          [region,
           assigned_counter_list[r],
           good_counter_list[r],
           assigned_counter_list[r] / np.float64(counter),
           good_counter_list[r] / np.float64(counter),
           good_counter_list[r] / np.float64(assigned_counter_list[r]) ]))
    outfile.write("\n")

########################################################################
# Write condensed sequence information in counter objects to python pickle files
# Separate data by region
log.info("Saving vt counters")
for r in range(NREG):
    log.info("Saving region %d" % r)
    pickle.dump(vt_counters[r], open(vt_counters_filelist[r], 'wb'), pickle.HIGHEST_PROTOCOL)
    pickle.dump(vt_seq_counters[r], open(vt_seq_counters_filelist[r], 'wb'), pickle.HIGHEST_PROTOCOL)
    pickle.dump(flip_counters[r], open(flip_counters_filelist[r], 'wb'), pickle.HIGHEST_PROTOCOL)


########################################################################
# Close output files
log.info('Closing output files')
outfile.close()
up2file.close()
ss1ss2file.close()
goodtagfile.close()
badtagfile.close()

########################################################################
# FIGURE - Hamming distances from SS1 and SS2
log.info('Plotting hamming distances from SP1 and SP2')
N = len(SP1) # number of regions
matrix = np.zeros(shape=(N, N), dtype=int)
print("\n"+"Hamming distances for SS1, SS2, and counts of reads"+ "\n")
for x in range(N):
    for y in range(N):
        print(x, y, hamming_counter[(x, y)])
        matrix[x, y] = hamming_counter[(x, y)]

fig = plt.figure(figsize=(20, 18))
fig.suptitle("Hamming Distance from Sequence Specific Primers", fontsize=20)
cax = plt.imshow(np.log10(matrix), interpolation="nearest", aspect="auto")
cbar = plt.colorbar(cax)
cbar.set_ticks(np.arange(7))
cbar.set_ticklabels(["$10^%d$" % d for d in np.arange(7)])
cbar.ax.tick_params(labelsize=16)
plt.xlabel("hamming distance from SP1 (internal)", fontsize=18)
plt.ylabel("hamming distance from SP2 (cut-site)", fontsize=18)
plt.tick_params(labelsize=16)
plt.savefig(hamming_SS_figure, dpi=200, facecolor='w', edgecolor='w',
            transparent=False)
plt.close()

########################################################################
# End timer
t1 = time.time()
td = (t1 - t0) / 60
log.info("Done in %0.2f minutes" % td)
import os
import sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm


from masq_helper_functions import tabprint
from masq_helper_functions import plot_number_of_reads_per_tag
from masq_helper_functions import plot_at_least_x_reads_per_tag
########################################################################

# Start timer
t0_all = time.time()

########################################################################

# Redirect prints to log file
old_stdout = sys.stdout
log_file = open(str(snakemake.log),"w")
sys.stdout = log_file
sys.stderr = log_file

########################################################################

# Filenames and parameters
vt_counter_filenames = snakemake.input.vt_counters
NREG = len(vt_counter_filenames)

# Sample name
sample = snakemake.params.sample

# Output report file
outfilename = snakemake.output.tagcounts
outfile = open(outfilename,"w")
########################################################################

# Load sequence data
t0 = time.time()
print("loading sequence data...")
vt_counters = [Counter() for _ in range(NREG)]
for r in range(NREG):
    vt_counters[r]=pickle.load(open(vt_counter_filenames[r], 'rb'))
print("data loaded in %0.2f seconds" % (time.time() - t0))

########################################################################

# tabulate distribution of reads per tag
reads_per_tag = Counter()
Nreads_total = 0
Ntags_total = 0

for r in range(NREG):
    vt_counter = vt_counters[r]
    Ntags = len(vt_counter)
    Ntags_total += Ntags
    for tag, tag_count in vt_counter.items():
        reads_per_tag[tag_count]+=1
        Nreads_total += tag_count

numreads=np.array(list(reads_per_tag.keys()))
numtags=np.array(list(reads_per_tag.values()))

########################################################################

# Plot reads per tag
plot_number_of_reads_per_tag(
                                numreads,
                                numtags,
                                Nreads_total,
                                Ntags_total,
                                sample=sample,
                                filename=snakemake.output.plot1,
                                logscale=True)
plot_at_least_x_reads_per_tag(
                                numreads,
                                numtags,
                                Nreads_total,
                                Ntags_total,
                                sample=sample,
                                filename=snakemake.output.plot2,
                                logscale=True,
                                maxcount=100)
########################################################################

# Write bar graph counts out to file
header = ["Region","Reads Per Tag","Number of Tags"]
outfile.write(tabprint(header)+"\n")
for i in range(len(numreads)):
    outfile.write(tabprint(["Combined",numreads[i],numtags[i]])+"\n")

########################################################################

# End timer
t1 = time.time()
td = (t1 - t0_all) / 60
print("done in %0.2f minutes" % td)

########################################################################
# Put standard out back...

sys.stdout = old_stdout
log_file.close()
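
To make the reads-per-tag tabulation concrete, the same inversion logic applied to a toy counter (hypothetical data, independent of any pipeline output) looks like this:

from collections import Counter
import numpy as np

# Toy varietal-tag counter: tag -> number of read pairs observed with that tag.
vt_counter = Counter({"AACGT": 3, "TTAGC": 3, "GGATC": 1})

# Invert it: reads-per-tag value -> number of tags seen that many times.
reads_per_tag = Counter()
for tag, tag_count in vt_counter.items():
    reads_per_tag[tag_count] += 1

numreads = np.array(list(reads_per_tag.keys()))   # array([3, 1])
numtags = np.array(list(reads_per_tag.values()))  # array([2, 1])

Two tags were seen 3 times each and one tag was seen once, which is exactly the shape of data the bar-graph plots and the report table consume.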
import os
import sys
import numpy as np
import gzip
from collections import Counter, defaultdict
import fileinput
import operator
import pickle
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from masq_helper_functions import tabprint
from masq_helper_functions import plot_number_of_reads_per_tag
from masq_helper_functions import plot_at_least_x_reads_per_tag
########################################################################

# Start timer
t0_all = time.time()

########################################################################

# Redirect prints to log file
old_stdout = sys.stdout
log_file = open(str(snakemake.log),"w")
sys.stdout = log_file
sys.stderr = log_file

########################################################################

# Filenames and parameters
vt_counter_filename = snakemake.input.vt_counter

# WHICH REGION ARE WE PROCESSING
REGION = snakemake.params.region
# Sample name
sample = snakemake.params.sample

# Output report file
outfilename = snakemake.output.tagcounts
outfile = open(outfilename,"w")
########################################################################

# Load sequence data
t0 = time.time()
print("loading sequence data...")
vt_counter = pickle.load(open(vt_counter_filename, 'rb')) # only one region
print("data loaded in %0.2f seconds" % (time.time() - t0))

########################################################################

# tabulate distribution of reads per tag
reads_per_tag = Counter()
Nreads_total = 0
Ntags_total = len(vt_counter)
for tag, tag_count in vt_counter.items():
    reads_per_tag[tag_count]+=1
    Nreads_total += tag_count
numreads=np.array(list(reads_per_tag.keys()))
numtags=np.array(list(reads_per_tag.values()))

########################################################################

# Plot reads per tag
plot_number_of_reads_per_tag(
                                numreads,
                                numtags,
                                Nreads_total,
                                Ntags_total,
                                sample=sample,
                                filename=snakemake.output.plot1,
                                logscale=True)
plot_at_least_x_reads_per_tag(
                                numreads,
                                numtags,
                                Nreads_total,
                                Ntags_total,
                                sample=sample,
                                filename=snakemake.output.plot2,
                                logscale=True,
                                maxcount=100)
########################################################################

# Write bar graph counts out to file
header = ["Region","Reads Per Tag","Number of Tags"]
outfile.write(tabprint(header)+"\n")
for i in range(len(numreads)):
    outfile.write(tabprint([int(REGION),numreads[i],numtags[i]])+"\n")

########################################################################

# End timer
t1 = time.time()
td = (t1 - t0_all) / 60
print("done in %0.2f minutes" % td)

########################################################################
# Put standard out back...

sys.stdout = old_stdout
log_file.close()
threads=$1
outdir=$2
fastq1=$3
fastq2=$4
sample_bc_list=$5

barcode_program="scripts/trimBarcodeFragments"

"$barcode_program" -nThreads="$threads" -dirPrefix="$outdir" -inputFiles="$fastq1,$fastq2" "$sample_bc_list"
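
The wrapper takes five positional arguments in the order shown above. A hypothetical invocation from Python, with placeholder paths (not files shipped with the repository):

import subprocess

# All arguments below are illustrative placeholders.
subprocess.run(
    [
        "bash", "scripts/trim_bcs.sh",
        "8",                        # threads
        "results/trimmed",          # output directory prefix
        "sample_R1.fastq.gz",       # read 1 FASTQ
        "sample_R2.fastq.gz",       # read 2 FASTQ
        "sample_bc_list.txt",       # sample barcode list
    ],
    check=True,
)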
From line 120 of master/Snakefile:

script:
    "scripts/primer_table_to_sd_table.py"

From line 135 of master/Snakefile:

shell:
    """
    scripts/trim_bcs.sh {params.threads} {params.output_dir} {input.presplit_fastq1} {input.presplit_fastq2} {params.sample_bc_list}
    """

From line 153 of master/Snakefile:

shell:
    "(fastqc --nogroup -t {threads} {input} -o {wildcards.sample}/fastqc) >& {log}"

From line 169 of master/Snakefile:

script:
    "scripts/check_loci_plot_and_extend.py"

From line 205 of master/Snakefile:

script:
    "scripts/sort_data_by_tag_and_locus.py"

From line 238 of master/Snakefile:

script:
    "scripts/all_base_report.py"

From line 250 of master/Snakefile:

script:
    "scripts/combine_withintagerr.py"

From line 267 of master/Snakefile:

script:
    "scripts/tag_count_graphs.py"

From line 283 of master/Snakefile:

script:
    "scripts/tag_count_graphs_allregions.py"

From line 304 of master/Snakefile:

script:
    "scripts/collapse_tags.py"

From lines 316, 329, 341, 353, and 366 of master/Snakefile (five rules share this directive):

script:
    "scripts/combine_reports.py"

From line 378 of master/Snakefile:

script:
    "scripts/extract_variant_info.py"

From line 392 of master/Snakefile:

script:
    "scripts/plot_rollup_results.py"

From line 408 of master/Snakefile:

script:
    "scripts/final_report.py"

From line 414 of master/Snakefile:

run:
    with open(output[0],'w') as f:
        if "bad_loci" in config:
            for loc in config["bad_loci"]:
                f.write(str(loc)+"\n")

From line 427 of master/Snakefile:

shell:
    "R_LIBS=""; Rscript scripts/masq_QC_plots.R {output.plot} {output.qcfail} {input.bad_loci} {input.reports}"

From line 440 of master/Snakefile:

script:
    "scripts/qcfilter_report.py"