ChIP-seq analysis pipeline used in Bragdon et al. 2022.

Public · 1 year ago · Version: v1 · 0 bookmarks

Snakemake workflow used to analyze ChIP-seq data for the 2022 publication "Cooperative assembly confers regulatory specificity and long-term genetic circuit stability".

Code Snippets

22
23
24
25
shell: """
    (sed 's/>/>{params.exp_name}_/g' {input.experimental} | \
    cat - <(sed 's/>/>{params.si_name}_/g' {input.spikein}) > {output}) &> {log}
    """
39
40
41
shell: """
    (bowtie2-build {input} {params.idx_path}/{wildcards.basename}) &> {log}
    """
69
70
71
72
73
74
75
76
77
shell: """
    (bowtie2 --minins {params.min_fraglength} --maxins {params.max_fraglength} --fr --no-mixed --no-discordant --al-conc-gz fastq/aligned/{wildcards.sample}_{FACTOR}-chipseq-aligned.fastq.gz --un-conc-gz fastq/unaligned/{wildcards.sample}_{FACTOR}-chipseq-unaligned.fastq.gz -p {threads} -x {params.idx_path}/{basename} -1 {input.r1} -2 {input.r2}  | \
     samtools view -buh -q {params.minmapq} - | \
     samtools sort -T .{wildcards.sample} -@ {threads} -o {output.bam} -) &> {output.log}
    mv fastq/aligned/{wildcards.sample}_{FACTOR}-chipseq-aligned.fastq.1.gz {output.aligned_r1}
    mv fastq/aligned/{wildcards.sample}_{FACTOR}-chipseq-aligned.fastq.2.gz {output.aligned_r2}
    mv fastq/unaligned/{wildcards.sample}_{FACTOR}-chipseq-unaligned.fastq.1.gz {output.unaligned_r1}
    mv fastq/unaligned/{wildcards.sample}_{FACTOR}-chipseq-unaligned.fastq.2.gz {output.unaligned_r2}
    """
91
92
93
94
95
96
shell: """
    (samtools collate -O -u --threads {threads} {input} | \
            samtools fixmate -m -u --threads {threads} - - | \
            samtools sort -u -T .remove_duplicates_sort_{wildcards.sample} -@ {threads} | \
            samtools markdup -r -f {output.markdup_log} -d 100 -m t -T .remove_duplicates_markdup_{wildcards.sample} --threads {threads} --write-index - {output.bam}) &> {log}
    """
113
114
115
116
117
118
119
120
shell: """
    (samtools view -h -@ {threads} {input.bam} $(faidx {input.fasta} -i chromsizes | \
                                                 grep {params.prefix}_ | \
                                                 awk 'BEGIN{{FS="\t"; ORS=" "}}{{print $1}}') | \
     grep -v -e 'SN:{params.filterprefix}_' | \
     sed 's/{params.prefix}_//g' | \
     samtools view -bh -@ {threads} --write-index -o {output.bam} -) &> {log}
    """
21
22
23
shell: """
    (cutadapt --cut={params.cut_5prime} -U {params.cut_5prime} --adapter=AGATCGGAAGAGCACACGTCTGAACTCCAGTCA -A AGATCGGAAGAGCGTCGTGTAGGGAAAGAGTGT --trim-n --cores={threads} --nextseq-trim={params.qual_cutoff} --minimum-length=6 --output={output.r1} --paired-output={output.r2} {input.r1} {input.r2}) &> {output.log}
    """
29
30
31
32
33
34
35
run:
    if FIGURES[wildcards.figure]["parameters"]["type"]=="absolute":
        shell("""(computeMatrix reference-point -R {input.annotation} -S {input.bw} --referencePoint {params.refpoint} -out {output.dtfile} --outFileNameMatrix {output.matrix} -b {params.upstream} -a {params.dnstream} {params.nan_afterend} --binSize {params.binsize} --averageTypeBins {params.binstat} -p {threads}) &> {log}""")
    else:
        shell("""(computeMatrix scale-regions -R {input.annotation} -S {input.bw} -out {output.dtfile} --outFileNameMatrix {output.matrix} -m {params.scaled_length} -b {params.upstream} -a {params.dnstream} --binSize {params.binsize} --averageTypeBins {params.binstat} -p {threads}) &> {log}""")
    melt_upstream = params.upstream-params.binsize
    shell("""(Rscript scripts/melt_matrix_chipseq.R -i {output.matrix} -r {params.refpoint} --group {params.group} -s {wildcards.sample} -t {wildcards.sampletype} -a {params.anno_label} -b {params.binsize} -u {melt_upstream} -o {output.melted}) &>> {log}""")
47
48
49
shell: """
    (cat {input} > {output}) &> {log}
    """
95
96
script:
    "../scripts/plot_chipseq_figures.R"
14
15
16
17
18
shell: """
    bedtools makewindows -g <(faidx {input.fasta} -i chromsizes) -w {wildcards.windowsize} | \
    awk 'BEGIN{{FS=OFS="\t"}}{{print $1, $2, $3, ".", 0, "."}}' | \
    LC_COLLATE=C sort -k1,1 -k2,2n > {output}
    """
30
31
32
33
34
shell: """
    (cut -f1-6 {input.bed} | \
     LC_COLLATE=C sort -k1,1 -k2,2n | \
     bedtools map -a stdin -b {input.bg} -c 4 -o sum > {output}) &> {log}
    """
46
47
48
49
50
shell: """
    (paste {input} | \
     cut -f$(paste -d, <(echo "1-6") <(seq -s, 7 7 {params.n})) | \
     cat <(echo -e "chrom\tstart\tend\tname\tscore\tstrand\t{params.names}" ) - > {output}) &> {log}
    """
82
83
script:
    "../scripts/differential_binding_chipseq.R"
105
106
107
shell: """
    (python scripts/chipseq_diffbind_results_to_narrowpeak.py -i {input.condition_coverage} -j {input.control_coverage} -d {input.diffbind_results} -n {output.narrowpeak} -b {output.summit_bed}) &> {log}
    """
20
21
22
23
24
shell: """
    (mkdir -p qual_ctrl/fastqc/{wildcards.fqtype}
    fastqc --adapters <(echo -e "adapter\t{params.adapter}") --nogroup --noextract -t {threads} -o qual_ctrl/fastqc/{wildcards.fqtype} {input.fastq}
    unzip -p qual_ctrl/fastqc/{wildcards.fqtype}/{params.fname}_fastqc.zip {params.fname}_fastqc/fastqc_data.txt > {output}) &> {log}
    """
81
82
83
84
85
86
87
88
89
90
91
92
93
run:
    shell("rm -f {output}")
    for fastqc_metric, out_path in output.items():
        title = fastqc_dict[fastqc_metric]["title"]
        fields = fastqc_dict[fastqc_metric]["fields"]
        for read_status, read_status_data in input.items():
            sample_id_list = ["_".join(x) for x in itertools.product((["unmatched"]if config["unmatched"]["r1"] and config["unmatched"]["r2"] else []) + list(SAMPLES.keys()), ["r1", "r2"])] if read_status=="raw" else ["_".join(x) for x in itertools.product(SAMPLES.keys(), ["r1", "r2"])]
            for sample_id, fastqc_data in zip(sample_id_list, read_status_data):
                if sample_id in ["unmatched_r1", "unmatched_r2"] and title=="Adapter Content":
                    shell("""awk 'BEGIN{{FS=OFS="\t"}} /{title}/{{flag=1;next}}/>>END_MODULE/{{flag=0}} flag {{m=$2;for(i=2;i<=NF-2;i++)if($i>m)m=$i; print $1, m, "{sample_id}", "{read_status}"}}' {fastqc_data} | tail -n +2 >> {out_path}""")
                else:
                    shell("""awk 'BEGIN{{FS=OFS="\t"}} /{title}/{{flag=1;next}}/>>END_MODULE/{{flag=0}} flag {{print $0, "{sample_id}", "{read_status}"}}' {fastqc_data} | tail -n +2 >> {out_path}""")
        shell("""sed -i "1i {fields}" {out_path}""")
117
118
script:
    "../scripts/fastqc_summary.R"
Snakemake: from line 117 of rules/fastqc.smk
19
20
21
22
23
shell: """
    rm -f .{wildcards.sample}_{wildcards.species}*.bam
    (samtools sort -n -T .get_fragments_{wildcards.sample}_{wildcards.species} -@ {threads} {input.bam} | \
     bedtools bamtobed -bedpe -i stdin > {output}) &> {log}
    """
33
34
35
36
37
38
shell: """
    (awk 'BEGIN{{FS=OFS="\t"}} {{width=$6-$2}} {{(width % 2 != 0) ? (mid=(width+1)/2+$2) : ((rand()<0.5)? (mid=width/2+$2) : (mid=width/2+$2+1))}} width>0 {{print $1, mid, mid+1, $7}}' {input.bedpe} | \
     sort -k1,1 -k2,2n | \
     bedtools genomecov -i stdin -g <(faidx {input.fasta} -i chromsizes) -bga | \
     LC_COLLATE=C sort -k1,1 -k2,2n > {output}) &> {log}
    """
47
48
49
50
shell: """
    (bedtools genomecov -ibam {input.bam} -bga -pc | \
     LC_COLLATE=C sort -k1,1 -k2,2n > {output}) &> {log}
    """
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
run:
    if wildcards.norm=="libsizenorm" or wildcards.sample in INPUTS:
        shell("""
              (awk -v norm_factor=$(samtools view -c {input.bam_experimental} | \
                                    paste -d "" - <(echo "/1000000") | bc -l) \
               'BEGIN{{FS=OFS="\t"}}{{$4=$4/norm_factor; print $0}}' {input.counts} > {output.normalized}) &> {log}
              """)
    else:
        shell("""
              (awk -v norm_factor=$(paste -d "" \
                      <(samtools view -c {input.bam_spikein}) <(echo "*") \
                      <(samtools view -c {input.input_bam_experimental}) <(echo "/") \
                      <(samtools view -c {input.input_bam_spikein}) <(echo "/1000000") | bc -l) \
                      'BEGIN{{FS=OFS="\t"}}{{$4=$4/norm_factor; print $0}}' {input.counts} > {output.normalized}) &> {log}
              """)
93
94
95
shell: """
    (python scripts/smooth_midpoint_coverage.py -b {params.bandwidth} -i {input} -o {output}) &> {log}
    """
108
109
110
shell: """
    (python scripts/make_ratio_bigwig.py -c {input.ip_sample} -i {input.input_sample} -o {output}) &> {log}
    """
120
121
122
shell: """
    (bedGraphToBigWig {input.bedgraph} <(faidx {input.fasta} -i chromsizes) {output}) &> {log}
    """
18
19
20
21
22
23
run:
    # Build a fragment-length x sample count table from the TLEN field
    # (SAM column 9, sign stripped) of each input BAM: one row per fragment
    # length, one count column per BAM in input order, header from
    # params.header prepended at the end.
    # join(1) requires both inputs sorted in the same *lexicographic* order,
    # so lengths are text-sorted under LC_ALL=C (not numerically) and join
    # runs under the same locale; order checking is left on so a mismatch
    # fails loudly instead of silently dropping matches.
    # `-e 0` only fills missing fields when an output format is given, hence
    # `-o auto`: lengths seen in only some BAMs get explicit zero counts
    # instead of producing rows with missing columns.
    bam = input[0]
    shell("""samtools view {bam} | cut -f9 | sed 's/-//g' | LC_ALL=C sort -k1,1 -S 80% --parallel {threads} | uniq -c | awk 'BEGIN{{OFS="\t"}}{{print $2, $1}}' > {output}""")
    for bam in input[1:]:
        shell("""LC_ALL=C join -1 1 -2 2 -t $'\t' -e 0 -o auto -a 1 -a 2 {output} <(samtools view {bam} | cut -f9 | sed 's/-//g' | LC_ALL=C sort -k1,1 -S 80% --parallel {threads} | uniq -c | awk 'BEGIN{{OFS="\t"}}{{print $1, $2}}') > qual_ctrl/fragment_length_distributions/.frag_length.temp; mv qual_ctrl/fragment_length_distributions/.frag_length.temp {output}""")
    # restore numeric ordering of fragment lengths once all joins are done
    shell("""sort -k1,1n -o {output} {output}""")
    shell("""sed -i "1i {params.header}" {output}""")
32
33
script:
    "../scripts/paired_end_fragment_length.R"
44
45
46
47
48
49
50
51
52
53
run:
    shell("""(echo -e "sample\traw\tcleaned\tmapped\tunique_map\tno_dups" > {output}) &> {log}""")
    for sample, adapter, align, markdup in zip(SAMPLES.keys(), input.adapter, input.align, input.markdup):
        shell("""
              (grep -e "Total read pairs processed:" -e "Pairs written" {adapter} | cut -d: -f2 | sed 's/,//g' | awk 'BEGIN{{ORS="\t"; print "{sample}"}}{{print $1}}' >> {output}
               grep -e "1 time" {align} | awk 'BEGIN{{sum=0; ORS="\t"}} {{sum+=$1}} END{{print sum}}' >> {output}
               grep -e "READ:" -e "WRITTEN:" {markdup} | cut -d ' ' -f2 | awk 'BEGIN{{ORS="\t"}} {{print $1/2}} END{{ORS="\\n"; print ""}}' >> {output}) &>> {log}
               """)
               # grep -e "exactly 1 time" {align} | awk 'BEGIN{{sum=0; ORS="\t"}} {{sum+=$1}} END{{print sum}}' >> {output}
               # grep -e "concordantly exactly 1 time" {align} | awk '{{print $1}}' >> {output}) &> {log}
64
65
script:
    "../scripts/processing_summary.R"
79
80
81
82
83
84
85
86
87
88
89
90
run:
    shell("""(echo -e "sample\tgroup\ttotal_counts_input\texperimental_counts_input\tspikein_counts_input\ttotal_counts_IP\texperimental_counts_IP\tspikein_counts_IP" > {output}) &> {log} """)
    for sample, group, input_exp, input_si ,ip_exp, ip_si in zip(get_samples(spikein=True, paired=True).keys(), params.groups,
                                                                 input.input_bam_experimental, input.input_bam_spikein,
                                                                 input.ip_bam_experimental, input.ip_bam_spikein):
        shell("""(paste <(echo -e "{sample}\t{group}\t") \
                    <(samtools view -c {input_exp}) \
                    <(samtools view -c {input_si}) \
                    <(echo "") \
                    <(samtools view -c {ip_exp}) \
                    <(samtools view -c {ip_si}) | \
                    awk 'BEGIN{{FS=OFS="\t"}} {{$3=$4+$5; $6=$7+$8; print $0}}'>> {output}) &>> {log} """)
104
105
script:
    "../scripts/spikein_abundance_chipseq.R"
18
19
20
21
22
23
24
25
26
27
shell: """
    (bedtools slop -b {params.search_dist} -i {input.peaks} -g <(faidx {input.fasta} -i chromsizes) | \
     sort -k1,1 -k2,2n | \
    bedtools cluster -d 0 -i stdin | \
    bedtools groupby -g 7 -c 5 -o max -full -i stdin | \
    sort -k4,4V | \
    bedtools getfasta -name+ -fi {input.fasta} -bed stdin | \
    awk 'BEGIN{{FS=":|-"}} {{if ($1 ~ />/) {{print $1"::"$3":"$4+1"-"$5+1}} else {{print $0}}}}' \
    > {output}) &> {log}
    """
42
43
44
shell: """
    (meme-chip -oc motifs/{wildcards.annotation}/{wildcards.condition}-v-{wildcards.control}/{wildcards.norm}/{wildcards.condition}-v-{wildcards.control}_{wildcards.factor}-chipseq-{wildcards.norm}-{wildcards.annotation}-diffbind-results-{wildcards.direction}-meme_chip {params.db_command} {input.dbs} -bfile <(fasta-get-markov {input.genome_fasta} -m 1) -order 1 -meme-mod {params.meme_mode} -meme-nmotifs {params.meme_nmotifs} -meme-p 1 -meme-norand -centrimo-local {input.seq}) &> {log}
    """
24
25
26
27
28
shell: """
    (macs2 callpeak --treatment {input.chip_bam} --control {input.input_bam} --format BAMPE --name peakcalling/sample_peaks/{wildcards.sample}_{wildcards.species}-{wildcards.factor}-chipseq --SPMR --gsize $(faidx {input.fasta} -i chromsizes | awk '{{sum += $2}} END {{print sum}}') --slocal {params.slocal} --llocal {params.llocal} --keep-dup auto --bdg --call-summits --max-gap {params.maxgap} -q 1) &> {log}
    (sed -i -e 's/peakcalling\/sample_peaks\///g' {output.peaks}) &>> {log}
    (sed -i -e 's/peakcalling\/sample_peaks\///g' {output.summits}) &>> {log}
    """
46
47
48
49
50
51
52
53
54
shell: """
    # The whole command's stdout/stderr is redirected into the log below, so
    # idr is not also given -l pointing at that same file: two independent
    # writers on one file overwrite each other's output.
    # NOTE(review): idr's -l/--log-output-file is expected to default to
    # stderr - confirm for the installed idr version.
    (idr -s {input} --input-file-type narrowPeak --rank q.value -o {output.allpeaks} --plot --peak-merge-method max) &> {log}
    (awk '$5>{params.idr} || $9=="inf"' {output.allpeaks} | \
     LC_COLLATE=C sort -k1,1 -k2,2n | \
     tee {output.filtered} | \
     awk 'BEGIN{{FS=OFS="\t"}}{{print $1, $2, $3, $4, $5, $6, $7, $11, $12, $10}}' | \
     tee {output.narrowpeak} | \
     awk 'BEGIN{{FS=OFS="\t"}}{{start=$2+$10; print $1, start, start+1, $4, $5, $6}}' > {output.summits}) &>> {log}
    """
64
65
66
67
68
69
shell: """
    (bedtools multiinter -i {input} | \
     bedtools merge -i stdin | \
     awk 'BEGIN{{FS=OFS="\t"}}{{print $1, $2, $3, ".", 0, "."}}' | \
     sort -k1,1 -k2,2n > {output}) &> {log}
    """
11
12
13
14
15
shell: """
    (bedtools makewindows -g <(faidx {input.fasta} -i chromsizes) -w {wildcards.windowsize} | \
     LC_COLLATE=C sort -k1,1 -k2,2n | \
     bedtools map -a stdin -b {input.bg} -c 4 -o sum > {output}) &> {log}
    """
26
27
28
shell: """
    (bedtools unionbedg -i {input} -header -names {params.names} | bash scripts/cleanUnionbedg.sh | pigz -f > {output}) &> {log}
    """
43
44
script:
    "../scripts/plot_scatter_plots.R"
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
import argparse
import numpy as np
import pandas as pd
import pyBigWig as pybw

def average_bigwigs(coverage_paths):
    """Average per-base coverage across replicate bigWig files.

    Args:
        coverage_paths: paths to bigWig files, one per replicate.

    Returns:
        dict mapping chromosome name -> numpy array of per-base values,
        averaged (arithmetic mean) over all files. Chromosome names and
        lengths come from each file's header; every later file is expected
        to contain the chromosomes of the first file (a chromosome absent
        from the first file raises KeyError). An empty path list yields an
        empty dict.
    """
    coverage = {}
    for index, path in enumerate(coverage_paths):
        bw = pybw.open(path)
        try:
            for chrom, length in bw.chroms().items():
                values = bw.values(chrom, 0, length, numpy=True)
                if index == 0:
                    coverage[chrom] = values
                else:
                    coverage[chrom] = np.add(values, coverage[chrom])
        finally:
            # release the file handle even if reading a chromosome fails
            bw.close()
    # Divide once, after every replicate has been accumulated, so the result
    # is a true mean regardless of how many replicates a group has.
    n_replicates = len(coverage_paths)
    for chrom in coverage:
        coverage[chrom] = np.divide(coverage[chrom], n_replicates)
    return coverage

def get_summit(row, coverage):
    """Return the 0-based offset of the coverage summit inside a region.

    ``row`` supplies 'chrom', 'start' and 'end' ("unstranded" coordinates);
    ``coverage`` maps chromosome names to per-base coverage arrays. The slice
    start:end is scanned for its largest finite value; when several positions
    share that value, the mean of their offsets (truncated to int) is
    returned. A slice with no finite values returns int(len / 2).
    """
    window = coverage[row['chrom']][row['start']:row['end']]
    finite_mask = np.isfinite(window)
    if finite_mask.any():
        peak_value = window[finite_mask].max()
        tied_offsets = np.flatnonzero(window == peak_value)
        return int(tied_offsets.mean())
    return int(len(window) / 2)

def main(condition_paths,
        control_paths,
        diffexp_path,
        narrowpeak_out,
        bed_out):
    """Attach summit positions to differential-binding results.

    Reads the tab-separated results table at ``diffexp_path``. For each
    region, it finds the summit of the combined condition + control coverage
    (bigWigs in ``condition_paths`` / ``control_paths``). It then writes a
    10-column narrowPeak file to ``narrowpeak_out`` and a 6-column BED of
    1-bp summit positions to ``bed_out``. An empty results table produces
    empty outputs.
    """
    # condition and control coverage are imported separately and combined
    # per group via average_bigwigs, in case the number of samples in each
    # group is different; the two group tracks are then summed position-wise
    condition_coverage = average_bigwigs(condition_paths)
    coverage = {chrom: np.add(control_track, condition_coverage[chrom])
                for chrom, control_track in average_bigwigs(control_paths).items()}

    # only start and end are used as integers; every other column is read as
    # a string so its text is not reformatted on output
    column_dtypes = dict.fromkeys(['chrom', 'name', 'score', 'strand',
                                   'log2FC_enrichment', 'lfc_SE', 'stat',
                                   'log10_pval', 'log10_padj', 'mean_counts',
                                   'condition_enrichment', 'condition_enrichment_SE',
                                   'control_enrichment', 'control_enrichment_SE'],
                                  str)
    column_dtypes.update(start=np.uint32, end=np.uint32)
    diffexp_df = pd.read_csv(diffexp_path, sep="\t", dtype=column_dtypes)

    has_rows = len(diffexp_df) > 0
    if has_rows:
        diffexp_df['summit'] = diffexp_df.apply(get_summit, coverage=coverage, axis=1)
        diffexp_df = diffexp_df.assign(
            summit_start=diffexp_df['start'] + diffexp_df['summit'],
            summit_end=lambda frame: frame['summit_start'] + 1)

    narrowpeak_columns = ['chrom', 'start', 'end', 'name', 'score', 'strand',
                          'log2FC_enrichment', 'log10_pval', 'log10_padj', 'summit']
    summit_columns = ['chrom', 'summit_start', 'summit_end', 'name', 'score', 'strand']
    # NOTE: NAs (found in pvalue and score columns) are written as zero for
    # narrowpeak compatibility
    writer_options = dict(sep="\t",
                          header=False,
                          index=False,
                          float_format="%.3f",
                          encoding='utf-8',
                          na_rep="0")
    diffexp_df.to_csv(narrowpeak_out,
                      columns=narrowpeak_columns if has_rows else [],
                      **writer_options)
    diffexp_df.to_csv(bed_out,
                      columns=summit_columns if has_rows else [],
                      **writer_options)

if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='Add back summit information to ChIP-seq differential binding results.')
    # (flag, destination, help text, accepts one-or-more values)
    cli_options = [
        ('-i', 'condition_paths', 'BigWigs for all condition samples', True),
        ('-j', 'control_paths', 'BigWigs for all control samples', True),
        ('-d', 'diffexp_path', 'differential binding results file', False),
        ('-n', 'narrowpeak_out', 'output path for narrowPeak file', False),
        ('-b', 'bed_out', 'output path for BED file of summit positions', False),
    ]
    for flag, dest, help_text, many in cli_options:
        extra = {'nargs': '+'} if many else {}
        cli.add_argument(flag, dest=dest, type=str, help=help_text, **extra)
    # every destination above matches a parameter name of main()
    main(**vars(cli.parse_args()))
3
# Reformat `bedtools unionbedg -header` output read from stdin (tab-separated):
#   - header row: chrom/start/end columns are replaced by a single "name" column,
#     sample-name columns (fields 4..NF) are kept;
#   - data rows: chrom/start/end are collapsed to "chrom-start-end", and rows
#     whose values summed over fields 4..NF are not positive are dropped.
# `next` stops the header from also being evaluated by the data rule; without
# it, sample names that coerce to a positive number (e.g. a leading digit)
# would make the header print a second time as a data row.
awk 'BEGIN{FS=OFS="\t"} NR==1{ORS="\t"; print "name"; for(k=4;k<NF;k++) print $k; ORS="\n"; print $NF; next} {ORS="\t"; sum=0; for(i=4;i<=NF;i++) sum+=$i} sum>0{print $1"-"$2"-"$3; for(j=4;j<NF;j++) print $j; ORS="\n"; print $NF}'
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
library(tidyverse)
library(magrittr)
library(DESeq2)
library(gridExtra)

# Read a tab-separated count table and return the requested sample columns
# as a data.frame suitable for DESeqDataSetFromMatrix().
#   path    : path to the count table (read with readr::read_tsv)
#   samples : character vector of column names to keep, in this order
# Row names are the 1-based row positions in the file, so they can be matched
# back to the annotation `index` column. Rows whose total count across the
# selected samples is <= 1 are dropped.
get_countdata = function(path, samples){
    df = read_tsv(path) %>%
        # all_of(): `samples` is an external character vector, not a column
        select(all_of(samples)) %>%
        rowid_to_column(var="index") %>%
        column_to_rownames(var="index") %>%
        as.data.frame()
    # drop = FALSE keeps a single-sample table a data.frame instead of a vector
    df = df[rowSums(df) > 1, , drop = FALSE]
    return(df)
}

# Build a DESeqDataSet with an input/ChIP x control/condition interaction design.
#   data_path    : count table path (passed to get_countdata)
#   samples      : sample (column) names, also used as colData row names
#   conditions   : per-sample condition labels
#   sample_type  : per-sample labels, "input" or "ChIP"
#   condition_id : label of the condition group
#   control_id   : label of the control group (reference level)
# Returns the unfitted DESeqDataSet.
initialize_dds = function(data_path,
                          samples,
                          conditions,
                          sample_type,
                          condition_id,
                          control_id){
    count_table = get_countdata(data_path, samples)
    # reference levels: control for condition, input for sample_type
    sample_info = data.frame(
        condition = factor(conditions, levels = c(control_id, condition_id)),
        sample_type = factor(sample_type, levels = c("input", "ChIP")),
        row.names = samples)
    DESeqDataSetFromMatrix(countData = count_table,
                           colData = sample_info,
                           design = ~ sample_type + condition + sample_type:condition)
}

# Return size-factor-normalized counts from a DESeqDataSet as a tibble, with
# the feature row names moved into an `index` column.
extract_normalized_counts = function(dds){
    normalized = as.data.frame(counts(dds, normalized=TRUE))
    as_tibble(rownames_to_column(normalized, var="index"))
}

# Return regularized-log (rlog, blind = FALSE) counts from a DESeqDataSet as
# a tibble, with the feature row names moved into an `index` column.
extract_rlog_counts = function(dds){
    rlog_matrix = assay(rlog(dds, blind=FALSE))
    rlog_df = rownames_to_column(as.data.frame(rlog_matrix), var="index")
    as_tibble(rlog_df)
}

# Per-feature mean and SD of normTransform() counts (log2(x + pseudocount)),
# plus `rank` of the mean (1 = highest mean). One row per feature, keyed by
# integer row position `index`.
build_mean_sd_df_pre = function(dds){
    log_counts = assay(normTransform(dds))
    log_counts %>%
        as_tibble() %>%
        rowid_to_column(var="index") %>%
        pivot_longer(-index, names_to="sample", values_to="signal") %>%
        group_by(index) %>%
        summarise(mean = mean(signal),
                  sd = sd(signal)) %>%
        mutate(rank = min_rank(dplyr::desc(mean)))
}

# Per-feature mean and SD of an already-transformed count table (an `index`
# column plus one column per sample), plus `rank` of the mean (1 = highest).
build_mean_sd_df_post = function(counts){
    long_counts = pivot_longer(counts, -index,
                               names_to="sample", values_to="signal")
    long_counts %>%
        group_by(index) %>%
        summarise(mean = mean(signal),
                  sd = sd(signal)) %>%
        mutate(rank = min_rank(dplyr::desc(mean)))
}

# Build a reversed log-scale transformation for ggplot2 axes: x is mapped to
# -log_base(x), so larger values are drawn further left/down.
#   base : logarithm base (default e)
# Returns a scales transformation with log-spaced breaks on the domain
# [1e-100, Inf).
reverselog_trans <- function(base = exp(1)) {
    trans <- function(x) -log(x, base)
    inv <- function(x) base^(-x)
    # Every argument after `inverse` is bound by name: the positional slot
    # following `inverse` is not guaranteed to be `breaks` across scales
    # releases (NOTE(review): newer versions add derivative arguments there -
    # confirm against the installed scales).
    scales::trans_new(name = paste0("reverselog-", format(base)),
                      transform = trans,
                      inverse = inv,
                      breaks = scales::log_breaks(base = base),
                      domain = c(1e-100, Inf))
}

# Hexbin plot of per-feature SD against the rank of the per-feature mean
# (reversed log10 x axis), with a smoothed trend line.
#   df    : data frame with `rank` and `sd` columns (build_mean_sd_df_* output)
#   ymax  : upper y-axis limit (shared across plots so they are comparable)
#   title : plot title (string or plotmath expression)
# Returns a ggplot object.
mean_sd_plot = function(df, ymax, title){
    ggplot(data = df, aes(x=rank, y=sd)) +
        geom_hex(aes(fill=..count.., color=..count..), bins=100, size=0) +
        geom_smooth(color="#4292c6") +
        # guide = "none" replaces guide = FALSE, deprecated in ggplot2 3.3.4
        scale_fill_viridis_c(option="inferno", name=expression(log[10](count)), guide="none") +
        scale_color_viridis_c(option="inferno", guide="none") +
        scale_x_continuous(trans = reverselog_trans(10),
                           name="rank(mean enrichment)",
                           expand = c(0,0)) +
        scale_y_continuous(limits = c(NA, ymax),
                           name = "SD") +
        theme_light() +
        ggtitle(title) +
        theme(text = element_text(size=8))
}

# Run the DESeq2 Wald tests and assemble one tidy results row per annotation.
#   dds         : DESeqDataSet after nbinomWaldTest(), built with the design
#                 ~ sample_type + condition + sample_type:condition
#   annotations : table with `index`, chrom, start, end, name, score, strand
#   alpha       : FDR target passed to results()
#   lfc         : log2 fold-change threshold for the "greaterAbs" test
# The main test uses results()'s default coefficient. The numeric contrasts
# c(0,1,0,0) / c(0,1,0,1) give ChIP-vs-input enrichment in the control /
# condition group, assuming coefficient order (Intercept, sample_type,
# condition, interaction) - NOTE(review): confirm with resultsNames(dds).
# pvalue/padj are returned as -log10 values; all doubles are rounded to 3 dp.
extract_deseq_results = function(dds,
                                 annotations,
                                 alpha,
                                 lfc){
    # log2FoldChange and lfcSE for one contrast, renamed <label>, <label>_SE
    enrichment_for = function(contrast_vector, label){
        results(dds, contrast=contrast_vector, tidy=TRUE) %>%
            as_tibble() %>%
            select(row, log2FoldChange, lfcSE) %>%
            setNames(c("row", label, paste0(label, "_SE")))
    }
    control_enrichment = enrichment_for(c(0,1,0,0), "control_enrichment")
    condition_enrichment = enrichment_for(c(0,1,0,1), "condition_enrichment")

    test_results = results(dds,
                           alpha=alpha,
                           lfcThreshold=lfc,
                           altHypothesis="greaterAbs",
                           tidy=TRUE) %>%
        as_tibble() %>%
        left_join(control_enrichment, by="row") %>%
        left_join(condition_enrichment, by="row")

    annotations %>%
        left_join(test_results, by=c("index"="row")) %>%
        arrange(padj) %>%
        # unnamed regions get a name from their padj rank; score is
        # -125*log2(padj) capped at 1000, truncated to integer
        mutate(name = if_else(name==".",
                              paste0("peak_", row_number()),
                              name),
               score = as.integer(pmin(-125*log2(padj), 1000))) %>%
        mutate(across(c(pvalue, padj), ~ -log10(.x))) %>%
        mutate(across(where(is.double), ~ round(.x, 3))) %>%
        select(index, chrom, start, end, name, score, strand,
               log2FC_enrichment=log2FoldChange, lfc_SE=lfcSE,
               stat, log10_pval=pvalue, log10_padj=padj, mean_counts=baseMean,
               condition_enrichment, condition_enrichment_SE,
               control_enrichment, control_enrichment_SE)
}

# Write a per-annotation count table: the first seven result columns
# (index, chrom, start, end, name, score, strand), extended to every
# annotation row, plus one count column per sample. The `index` key is
# dropped from the file.
#   results_df  : output of extract_deseq_results()
#   annotations : annotation table keyed by `index`
#   counts_df   : counts with an `index` column (normalized or rlog)
#   output_path : destination TSV path
# Returns what write_tsv() returns (its input).
write_counts_table = function(results_df,
                              annotations,
                              counts_df,
                              output_path){
    result_keys = select(results_df, 1:7)
    annotation_keys = select(annotations, -c(name, score))
    all_regions = right_join(result_keys, annotation_keys,
                             by = c("index", "chrom", "start", "end", "strand"))
    all_regions %>%
        left_join(counts_df, by="index") %>%
        select(-index) %>%
        write_tsv(output_path)
}

# MA plot: mean of normalized counts (log10 x axis) versus log2 fold change
# in enrichment, non-significant regions in black and significant in red.
# Points are hex-binned, with point alpha encoding the count per bin.
#   df_sig, df_nonsig : significant / non-significant results
#   xvar, yvar        : columns to plot (tidy-eval, unquoted)
#   lfc               : fold-change threshold drawn as grey dashed lines
#   condition, control: group labels used in the y-axis label
# Returns a ggplot object.
plot_ma = function(df_sig = results_df_filtered_significant,
                   df_nonsig = results_df_filtered_nonsignificant,
                   xvar = mean_expr,
                   yvar = log2_enrichment,
                   lfc,
                   condition,
                   control){
    x_quo = enquo(xvar)
    y_quo = enquo(yvar)
    # one hex-binned point layer in a given colour
    binned_points = function(data, point_color){
        stat_bin_hex(data = data,
                     geom = "point",
                     aes(x = !!x_quo, y = !!y_quo, alpha = ..count..),
                     binwidth = c(.01, 0.01),
                     color = point_color, stroke = 0, size = 0.7)
    }
    y_label = bquote(log[2]~frac("enrichment in" ~ .(condition),
                                 "enrichment in" ~ .(control)))
    ggplot() +
        geom_hline(yintercept = 0, color = "black", linetype = "dashed") +
        geom_hline(yintercept = c(-lfc, lfc), color = "grey70", linetype = "dashed") +
        binned_points(df_nonsig, "black") +
        binned_points(df_sig, "red") +
        scale_x_log10(name = "mean of normalized counts") +
        scale_alpha_continuous(range = c(0.5, 1)) +
        ylab(y_label) +
        theme_light() +
        theme(text = element_text(size = 8, color = "black"),
              axis.text = element_text(color = "black"),
              axis.title.y = element_text(angle = 0, hjust = 1, vjust = 0.5),
              legend.position = "none")
}

# Volcano plot: log2 fold change in enrichment versus -log10 FDR, hex-binned
# into points coloured by log10(count per bin).
#   df                : results to plot
#   xvar, yvar        : columns to plot (tidy-eval, unquoted)
#   lfc               : fold-change threshold drawn as grey dashed lines
#   alpha             : FDR cutoff drawn as a red dashed line at -log10(alpha)
#   condition, control: group labels used in the x-axis label
# Returns a ggplot object.
plot_volcano = function(df = results_df_filtered,
                        xvar = log2_enrichment,
                        yvar = log10_padj,
                        lfc,
                        alpha,
                        condition,
                        control){
    x_quo = enquo(xvar)
    y_quo = enquo(yvar)
    fold_change_label = bquote(log[2] ~ frac("enrichment in" ~ .(condition),
                                             "enrichment in" ~ .(control)))
    # layers in drawing order: reference lines, binned points, FDR cutoff
    ggplot() +
        geom_vline(xintercept = 0, color = "black", linetype = "dashed") +
        geom_vline(xintercept = c(-lfc, lfc), color = "grey70", linetype = "dashed") +
        stat_bin_hex(data = df,
                     geom = "point",
                     aes(x = !!x_quo, y = !!y_quo, color = log10(..count..)),
                     binwidth = c(0.01, 0.1),
                     alpha = 0.8, stroke = 0, size = 0.7) +
        geom_hline(yintercept = -log10(alpha), color = "red", linetype = "dashed") +
        xlab(fold_change_label) +
        ylab(expression(-log[10] ~ FDR)) +
        scale_color_viridis_c(option = "inferno") +
        theme_light() +
        theme(text = element_text(size = 8),
              axis.title.y = element_text(angle = 0, hjust = 1, vjust = 0.5),
              legend.position = "none")
}

main = function(exp_table="depleted-v-non-depleted_allsamples-experimental-Rpb1-chipseq-counts-verified-coding-genes.tsv.gz",
                spike_table="depleted-v-non-depleted_allsamples-spikein-Rpb1-chipseq-counts-peaks.tsv.gz",
                samples=read_tsv(exp_table) %>% select(-c(1:6)) %>% names(),
                conditions=rep(c(rep("non-depleted",4), rep("depleted",4)), 2),
                sample_type=c(rep("input",8), rep("ChIP", 8)),
                # batches = rep(c(rep(1,2), rep(2,2)), 4),
                norm="spikenorm",
                condition="depleted",
                control="non-depleted",
                alpha=0.1,
                lfc=0,
                counts_norm_out="counts_norm.tsv",
                counts_rlog_out="counts_rlog.tsv",
                results_all_out="results_all.tsv",
                results_up_out="results_up.tsv",
                results_down_out="results_down.tsv",
                results_unchanged_out="results_unch.tsv",
                # bed_all_out="all.bed",
                # bed_up_out="up.bed",
                # bed_down_out="down.bed",
                # bed_unchanged_out="nonsignificant.bed",
                qc_plots_out="qcplots.png"){

    annotations = read_tsv(exp_table) %>%
        select(1:6) %>%
        rownames_to_column(var="index") %>%
        mutate(chrom = str_replace(chrom, "-minus$|-plus$", ""))

    dds = initialize_dds(data_path = exp_table,
                         samples = samples,
                         conditions = conditions,
                         sample_type = sample_type,
                         condition_id = condition,
                         control_id = control)

    if (norm=="spikenorm"){
        dds_spike = initialize_dds(data_path = spike_table,
                                   samples = samples,
                                   conditions = conditions,
                                   sample_type = sample_type,
                                   condition_id = condition,
                                   control_id = control) %>%
            estimateSizeFactors()
        sizeFactors(dds) = sizeFactors(dds_spike)
    } else {
        dds %<>% estimateSizeFactors()
    }
    dds %<>% estimateDispersions() %>% nbinomWaldTest()

    #extract normalized counts and write to file
    counts_norm = extract_normalized_counts(dds = dds)
    counts_rlog = extract_rlog_counts(dds = dds)

    mean_sd_df_pre = build_mean_sd_df_pre(dds)
    mean_sd_df_post = build_mean_sd_df_post(counts_rlog)

    sd_max = max(c(mean_sd_df_pre[["sd"]],
                   mean_sd_df_post[["sd"]]),
                 na.rm=TRUE)*1.01

    mean_sd_plot_pre = mean_sd_plot(df = mean_sd_df_pre,
                                    ymax = sd_max,
                                    title = expression(log[2] ~ "counts," ~ "pre-shrinkage"))
    mean_sd_plot_post = mean_sd_plot(df = mean_sd_df_post,
                                     ymax = sd_max,
                                     title = expression(regularized ~ log[2] ~ "counts"))

    results_df = extract_deseq_results(dds = dds,
                                       annotations = annotations,
                                       alpha = alpha,
                                       lfc = lfc) %>%
        mutate(chrom = str_replace(chrom, "-minus$|-plus$", ""))

    write_counts_table(results_df = results_df,
                       annotations = annotations,
                       counts_df = counts_norm,
                       output_path = counts_norm_out)
    write_counts_table(results_df = results_df,
                       annotations = annotations,
                       counts_df = counts_rlog,
                       output_path = counts_rlog_out)

    results_df %<>%
        select(-index) %>%
        write_tsv(results_all_out)
    # results_df %>%
    #     select(1:6) %>%
    #     write_tsv(bed_all_out, col_names=FALSE)

    results_df_significant = results_df %>%
        filter(log10_padj > -log10(alpha))
    results_df_nonsignificant = results_df %>%
        filter(log10_padj <= -log10(alpha)) %>%
        write_tsv(results_unchanged_out)
    # results_df_nonsignificant %>%
    #     select(1:6) %>%
    #     write_tsv(bed_unchanged_out, col_names=FALSE)

    results_df_significant %>%
        filter(log2FC_enrichment >= 0) %>%
        write_tsv(results_up_out)
        # write_tsv(results_up_out) %>%
        # select(1:6) %>%
        # write_tsv(bed_up_out, col_names=FALSE)

    results_df_significant %>%
        filter(log2FC_enrichment < 0) %>%
        write_tsv(results_down_out)
        # write_tsv(results_down_out) %>%
        # select(1:6) %>%
        # write_tsv(bed_down_out, col_names=FALSE)

    maplot = plot_ma(df_sig = results_df_significant,
                     df_nonsig = results_df_nonsignificant,
                     xvar = mean_counts,
                     yvar = log2FC_enrichment,
                     lfc = lfc,
                     condition = condition,
                     control = control)

    volcano = plot_volcano(df = results_df,
                           xvar = log2FC_enrichment,
                           yvar = log10_padj,
                           lfc = lfc,
                           alpha = alpha,
                           condition = condition,
                           control = control)

    qc_plots = arrangeGrob(mean_sd_plot_pre,
                           mean_sd_plot_post,
                           maplot,
                           volcano,
                           ncol=2)

    ggsave(qc_plots_out,
           plot = qc_plots,
           width = 16*1.5,
           height = 9*1.5,
           units="cm")
}

# Script entry point: every argument is read from the `snakemake` object
# (inputs, params, wildcards, outputs of the calling rule). The normalization
# mode and the condition/control pair come from rule wildcards; significance
# thresholds come from rule params. The BED outputs are currently disabled,
# matching the commented-out arguments of main().
main(exp_table = snakemake@input[["exp_counts"]],
     spike_table = snakemake@input[["spike_counts"]],
     samples = snakemake@params[["samples"]],
     conditions = snakemake@params[["conditions"]],
     sample_type = snakemake@params[["sampletypes"]],
     norm = snakemake@wildcards[["norm"]],
     condition = snakemake@wildcards[["condition"]],
     control = snakemake@wildcards[["control"]],
     alpha = snakemake@params[["alpha"]],
     lfc = snakemake@params[["lfc"]],
     counts_norm_out = snakemake@output[["counts_norm"]],
     counts_rlog_out = snakemake@output[["counts_rlog"]],
     results_all_out = snakemake@output[["results_all"]],
     results_up_out = snakemake@output[["results_up"]],
     results_down_out = snakemake@output[["results_down"]],
     results_unchanged_out = snakemake@output[["results_nonsig"]],
     # bed_all_out = snakemake@output[["bed_all"]],
     # bed_up_out = snakemake@output[["bed_up"]],
     # bed_down_out = snakemake@output[["bed_down"]],
     # bed_unchanged_out = snakemake@output[["bed_nonsig"]],
     qc_plots_out = snakemake@output[["qc_plots"]])
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
library(tidyverse)
library(forcats)
library(viridis)
library(ggthemes)
library(ggrepel)

# Read a TSV table and convert its `sample` and `status` columns to ordered
# factors whose levels follow their order of first appearance in the file,
# so downstream plots keep the file's ordering.
#
# path: path to a tab-separated file containing `sample` and `status` columns.
# Returns the tibble produced by read_tsv() with those two columns converted.
import = function(path){
    # Explicit mutate() replaces the deprecated mutate_at()/funs() pair and
    # works on both old and current dplyr releases.
    read_tsv(path) %>%
        mutate(sample = fct_inorder(sample, ordered=TRUE),
               status = fct_inorder(status, ordered=TRUE))
}

# Render quality-control figures from FastQC-derived summary tables.
#
# Each *_in argument is the path to a TSV read through import(), so every
# table must carry `sample` and `status` columns (used as facet rows and
# columns). Each *_out argument is the file that ggsave() writes the matching
# figure to. Figure heights grow with the number of samples found in the
# read-length table. Called for its side effects (the saved figures).
main = function(seq_len_dist_in, per_tile_in, per_base_qual_in,
                per_base_seq_in, per_base_n_in, per_seq_gc_in,
                per_seq_qual_in, adapter_content_in, seq_dup_in,
                seq_len_dist_out, per_tile_out, per_base_qual_out,
                per_base_seq_out, per_seq_gc_out, per_seq_qual_out,
                adapter_content_out, seq_dup_out){
    # FastQC reports some read lengths as binned ranges such as "35-39".
    # Split each range into its two endpoints and give each endpoint half of
    # the bin's count; single (unbinned) lengths keep their full count.
    length_distribution = import(seq_len_dist_in) %>%
        separate(length, into=c('a','b'), sep="-", fill="right", convert=TRUE) %>%
        mutate(count = if_else(is.na(b), count, count/2)) %>%
        gather(key, length, c(a,b)) %>%
        filter(!is.na(length)) %>%
        select(-key)

    nsamples = n_distinct(length_distribution$sample)

    per_tile_quality = import(per_tile_in) %>%
        mutate(tile = fct_inorder(as.character(tile), ordered=TRUE))

    # Within each sample/status, `n` is the total read count minus the reads
    # whose length ended at an earlier position (assumes rows are ordered by
    # position), rescaled so its maximum is 1. It sets the tile height in the
    # per-base quality plot.
    per_base_qual = import(per_base_qual_in) %>%
        left_join(length_distribution, by=c("base"="length", "sample", "status")) %>%
        group_by(sample, status) %>%
        mutate(count = if_else(is.na(count), 0, count)) %>%
        mutate(n = lag(sum(count)-cumsum(count), default=sum(count))) %>%
        mutate(n = n/max(n)) %>%
        ungroup()

    adapter_content = import(adapter_content_in)

    # Long format with one row per position and base. The joined N-count
    # column is renamed to `n`, so it becomes base "N" after upper-casing.
    per_base_seq = import(per_base_seq_in) %>%
        left_join(import(per_base_n_in), by=c("base", "sample","status")) %>%
        rename(position=base, n=n_count) %>%
        gather(base, pct, -c(position, sample, status)) %>%
        mutate(base = fct_inorder(toupper(base), ordered=TRUE))

    per_seq_gc = import(per_seq_gc_in) %>%
        filter(count > 0) %>%
        group_by(sample, status) %>%
        mutate(norm_count = count/sum(count)) %>%
        ungroup()

    per_seq_qual = import(per_seq_qual_in) %>%
        filter(count > 0) %>%
        group_by(sample, status) %>%
        mutate(norm_count = count/sum(count)) %>%
        ungroup()

    duplication_levels = import(seq_dup_in) %>%
        mutate(duplication_level = fct_inorder(duplication_level, ordered=TRUE))

    #kmer_content = import(kmer_in)

    theme_standard = theme_light() +
        theme(text = element_text(size=12, color="black", face="bold"),
              axis.text = element_text(size=12, color="black"),
              axis.title = element_text(size=12, color="black", face="bold"),
              strip.placement = "outside",
              strip.background = element_blank(),
              strip.text = element_text(size=12, color="black", face="bold"),
              strip.text.y = element_text(angle=-180, hjust=1))

    length_dist_plot = ggplot(data = length_distribution %>%
                                     group_by(sample, status) %>%
                                     mutate(normcount = count/max(count)),
                                 aes(x=length, y=normcount)) +
        geom_col(fill="#114477") +
        scale_x_continuous(breaks=scales::pretty_breaks(n=6), name="read length (nt)") +
        scale_y_continuous(breaks=scales::pretty_breaks(n=2), name="normalized counts") +
        facet_grid(sample~status, scales="free_y", switch="y") +
        ggtitle("read length distributions") +
        theme_standard +
        theme(axis.text.y = element_text(size=10, face="plain"))

    ggsave(seq_len_dist_out, plot=length_dist_plot, width=26, height=2+2*nsamples, units="cm", limitsize=FALSE)

    tile_quality_plot = ggplot(data = per_tile_quality %>% filter(status=="raw"),
                               aes(x=base, y=tile, fill=mean)) +
        geom_raster() +
        scale_fill_viridis(direction=-1, guide=guide_colorbar(title="mean\nquality\nscore", barheight=10)) +
        scale_x_continuous(expand=c(0,0), name="cycle number", breaks=scales::pretty_breaks(n=6)) +
        ylab("flow cell tile") +
        ggtitle("per tile sequencing quality") +
        facet_grid(sample~status, scales="free_y", switch="y") +
        theme_standard +
        theme(axis.text.y = element_blank(),
              axis.ticks.y = element_blank(),
              panel.grid.major.y = element_blank(),
              strip.text.x = element_blank())

    ggsave(per_tile_out, plot=tile_quality_plot, width=24, height=2+2*nsamples, units="cm", limitsize=FALSE)

    # guide="none" replaces guide=FALSE, deprecated since ggplot2 3.3.4.
    per_base_qual_plot = ggplot(data=per_base_qual, aes(x=base, y=fct_rev(sample),
                                               height=n, fill=mean, color=mean)) +
        geom_tile() +
        scale_color_viridis(guide="none", direction=-1) +
        scale_fill_viridis(guide=guide_colorbar(title="mean quality", barwidth=12,
                                                barheight=1, title.position = "top",
                                                title.hjust=0.5),
                           direction=-1) +
        scale_x_continuous(expand=c(0,0), name="read length (nt)", breaks=scales::pretty_breaks(n=6)) +
        facet_grid(sample~status, scales="free_y", switch="y") +
        ggtitle("per base sequencing quality",
                subtitle = expression("bar height " %prop% " fraction of reads")) +
        theme_standard +
        theme(axis.text.y = element_blank(),
              axis.title.y = element_blank(),
              plot.subtitle = element_text(size=12),
              legend.position="top",
              legend.margin = margin(0,0,0,0))

    ggsave(per_base_qual_out, plot=per_base_qual_plot, width=26, height=2.5+1.25*nsamples, units="cm", limitsize=FALSE)

    adapter_plot = ggplot(data = adapter_content, aes(x=position, y=0, fill=pct, color=pct)) +
        geom_raster() +
        scale_color_viridis(guide="none") +
        scale_fill_viridis(guide=guide_colorbar(title="% reads with adapter", barwidth=12,
                                                barheight=1, title.position = "top",
                                                title.hjust=0.5)) +
        scale_x_continuous(expand=c(0,0), name="read length (nt)") +
        scale_y_continuous(expand=c(0,0), name=NULL, breaks=0, labels=NULL) +
        facet_grid(sample~status, switch="y") +
        ggtitle("adapter content") +
        theme_standard +
        theme(legend.position="top",
              legend.margin = margin(0,0,0,0))

    ggsave(adapter_content_out, plot=adapter_plot, width=32, height=2+1.25*nsamples, units="cm", limitsize=FALSE)

    per_base_seq_plot = ggplot(data = per_base_seq, aes(x=position, y=pct, color=base)) +
        geom_line() +
        scale_color_ptol(guide=guide_legend(label.position="top", label.hjust=0.5,
                                            keyheight=0.2)) +
        scale_x_continuous(expand=c(0,1), name="position in read", breaks=scales::pretty_breaks(n=6)) +
        scale_y_continuous(name="% of reads", breaks=scales::pretty_breaks(n=2)) +
        facet_grid(sample~status, switch="y") +
        ggtitle("per base sequence content") +
        theme_standard +
        theme(legend.position="top",
              legend.title = element_blank(),
              legend.margin = margin(0,0,0,0),
              legend.key.size = unit(1, "cm"),
              legend.text = element_text(size=12, face="bold"),
              axis.text.y = element_text(size=10, face="plain"))

    ggsave(per_base_seq_out, plot=per_base_seq_plot, width=32, height=2+2.25*nsamples, units="cm", limitsize=FALSE)

    per_seq_gc_plot = ggplot(data = per_seq_gc, aes(x=gc_content, y=norm_count)) +
        geom_line(color="#114477") +
        scale_x_continuous(expand=c(0,0), name="GC%") +
        scale_y_continuous(breaks=scales::pretty_breaks(n=2), name="normalized counts") +
        facet_grid(sample~status, switch="y") +
        ggtitle("per sequence GC content") +
        theme_standard +
        theme(axis.text.y = element_text(size=10, face="plain"),
              panel.spacing.x = unit(1, "cm"),
              plot.margin = margin(5.5, 12, 5.5, 5.5, unit="pt"))

    ggsave(per_seq_gc_out, plot=per_seq_gc_plot, width=26, height=2+2*nsamples, units="cm", limitsize=FALSE)

    per_seq_qual_plot = ggplot(data = per_seq_qual, aes(x=quality, y=norm_count)) +
        geom_col(fill="#114477") +
        scale_x_continuous(breaks=scales::pretty_breaks(n=5), name="quality score") +
        scale_y_continuous(breaks=scales::pretty_breaks(n=2), name="normalized counts") +
        facet_grid(sample~status, switch="y") +
        ggtitle("per sequence quality scores") +
        theme_standard +
        theme(axis.text.y = element_text(size=10, face="plain"))

    ggsave(per_seq_qual_out, plot=per_seq_qual_plot, width=26, height=2+1.5*nsamples, units="cm", limitsize=FALSE)

    dup_level_plot = ggplot(data = duplication_levels, aes(x=duplication_level, y=pct_of_total)) +
        geom_col(fill="#114477") +
        xlab("duplication level") +
        ylab("% of total reads") +
        facet_grid(sample~status, switch="y") +
        ggtitle("sequence duplication levels") +
        theme_standard +
        theme(axis.text.x = element_text(size=10, face="plain", angle=60, hjust=1),
              axis.text.y = element_text(size=10, face="plain"))

    ggsave(seq_dup_out, plot=dup_level_plot, width=26, height=2+1.5*nsamples, units="cm", limitsize=FALSE)

    # k-mer content plot is disabled; kept for reference.
    #kmer_content_plot = ggplot(data = kmer_content, aes(x=max_position, y=log2(obs_over_exp_max), label=sequence)) +
    #    geom_point(shape=16, stroke=0, size=1, alpha=0.5) +
    #    geom_label_repel(size=2, label.size=unit(0.05, "pt"), label.padding=unit(0.1, "pt"), label.r=unit(0,"pt"), segment.size=0.1,
    #                     box.padding=unit(0.05,"pt"), segment.alpha=0.4) +
    #    xlab("position in read") +
    #    ylab(expression(bold(log[2]~ frac("observed", "expected")))) +
    #    ggtitle("k-mer content",
    #            subtitle = "top 20 overrepresented k-mers") +
    #    facet_grid(sample~status, switch="y", scales="free_y") +
    #    theme_standard + theme(plot.subtitle = element_text(size=12, face="plain"))
    #
    #ggsave(kmer_out, plot=kmer_content_plot, width=35, height=2+5*nsamples, units="cm", limitsize=FALSE)
}

# Script entry point: each FastQC summary table path comes from the calling
# Snakemake rule's inputs, and each figure path from its outputs (same key on
# both sides). The k-mer plot is disabled, hence the commented-out kmer paths.
main(seq_len_dist_in = snakemake@input[["seq_length_dist"]],
     per_tile_in = snakemake@input[["per_tile_qual"]],
     per_base_qual_in = snakemake@input[["per_base_qual"]],
     per_base_seq_in = snakemake@input[["per_base_seq_content"]],
     per_base_n_in = snakemake@input[["per_base_n"]],
     per_seq_gc_in = snakemake@input[["per_seq_gc"]],
     per_seq_qual_in = snakemake@input[["per_seq_qual"]],
     adapter_content_in = snakemake@input[["adapter_content"]],
     seq_dup_in = snakemake@input[["seq_duplication"]],
     #kmer_in = snakemake@input[["kmer"]],
     seq_len_dist_out = snakemake@output[["seq_length_dist"]],
     per_tile_out = snakemake@output[["per_tile_qual"]],
     per_base_qual_out = snakemake@output[["per_base_qual"]],
     per_base_seq_out = snakemake@output[["per_base_seq_content"]],
     per_seq_gc_out = snakemake@output[["per_seq_gc"]],
     per_seq_qual_out = snakemake@output[["per_seq_qual"]],
     adapter_content_out = snakemake@output[["adapter_content"]],
     seq_dup_out = snakemake@output[["seq_duplication"]])
     #kmer_out = snakemake@output[["kmer"]])
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import numpy as np
import pyBigWig as pybw

def main(chip_in="non-depleted-Rpb1-IP-3_Rpb1-chipseq-spikenorm-midpoints_smoothed.bw",
         input_in="non-depleted-untagged-input-3_Rpb1-chipseq-spikenorm-midpoints_smoothed.bw",
         ratio_out="ratio.bw"):
    """Write a per-base bigWig of log2(ChIP / input).

    Parameters
    ----------
    chip_in : str
        Path to the numerator (ChIP) bigWig.
    input_in : str
        Path to the denominator (input) bigWig.
    ratio_out : str
        Path of the bigWig to create. It is opened only after the inputs have
        been validated, so a chromosome mismatch leaves no stray output file.

    Raises
    ------
    ValueError
        If the two inputs do not declare identical chromosomes (names and
        lengths, as returned by ``chroms()``).

    The ratio follows numpy elementwise semantics, so zeros in either track
    yield infinite or NaN values (with a RuntimeWarning). These are written
    unchanged.
    """
    chip = pybw.open(chip_in)
    try:
        # Named `control` rather than `input` to avoid shadowing the builtin.
        control = pybw.open(input_in)
        try:
            chroms = chip.chroms()
            # Explicit check instead of `assert`, which `python -O` strips.
            if chroms != control.chroms():
                raise ValueError("ChIP and input bigWig chromosomes don't match.")
            ratio = pybw.open(ratio_out, "w")
            try:
                ratio.addHeader(list(chroms.items()))
                for chrom, length in chroms.items():
                    chip_values = chip.values(chrom, 0, length, numpy=True)
                    input_values = control.values(chrom, 0, length, numpy=True)
                    ratio.addEntries(chrom, 0,
                                     values=np.log2(np.divide(chip_values, input_values)),
                                     span=1, step=1)
            finally:
                ratio.close()
        finally:
            control.close()
    finally:
        chip.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Given two bigWig coverage files, generate a coverage file of their log2 ratio.')
    # All three paths are mandatory. Without required=True, an omitted flag
    # would pass None into main() and on to pyBigWig.open(), instead of
    # producing a clear usage error here.
    parser.add_argument('-c', dest='chip_in', type=str, required=True, help='Path to numerator (ChIP) bigWig.')
    parser.add_argument('-i', dest='input_in', type=str, required=True, help='Path to denominator (input) bigWig.')
    parser.add_argument('-o', dest='ratio_out', type=str, required=True, help='Path to output bigWig.')
    args = parser.parse_args()
    main(chip_in=args.chip_in,
         input_in=args.input_in,
         ratio_out=args.ratio_out)
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
library(argparse)
library(tidyverse)
library(magrittr)

# Command-line interface for melt(). Options declared with nargs="+" accept
# several words; the melt() call at the end of the script joins
# group/sample/annotation with single spaces.
parser = ArgumentParser()
parser$add_argument('-i', '--input', type='character')  # wide matrix; first 3 lines are skipped when read
parser$add_argument('-r', '--refpt', type='character', nargs="+")  # reference point label; compared with "TES" when binsize <= 1
parser$add_argument('-g', '--group', type='character', nargs="+")
parser$add_argument('-s', '--sample', type='character', nargs="+")
parser$add_argument('-t', '--sampletype', type='character', nargs=1)
parser$add_argument('-a', '--annotation', type='character', nargs="+")
parser$add_argument('-b', '--binsize', type='integer')  # > 1 selects the binned position conversion
parser$add_argument('-u', '--upstream', type='integer')  # upstream extent used to re-centre positions
parser$add_argument('-o', '--output', type='character')  # header-less long-format TSV to write

args = parser$parse_args()

# Reshape a wide signal matrix (one row per region, one column per position
# or bin) into a long table and write it as a header-less TSV.
#
# inmatrix:   path to the matrix; the first 3 lines are skipped and no header
#             row is expected.
# refpt:      reference point label; when binsize <= 1, "TES" selects a
#             1-column offset and anything else a 2-column offset (TODO
#             confirm the reason for the different offsets).
# group, sample, sampletype, annotation: labels copied onto every row.
# binsize:    bin width; values > 1 use the binned position conversion.
# upstream:   upstream extent subtracted when re-centring positions.
# outpath:    destination TSV.
# Columns written: group, sample, sampletype, annotation, index (row number of
# the region), position (re-centred and divided by 1000, presumably nt -> kb),
# and cpm (the matrix value). Returns the long table invisibly.
melt = function(inmatrix, refpt, group, sample, sampletype,
                annotation, binsize, upstream, outpath){
    # --refpt is declared with nargs="+", so it can arrive as a character
    # vector. Collapse it the way the caller collapses group/sample/annotation,
    # so the scalar `if` below never receives a length > 1 condition (an error
    # in R >= 4.2).
    refpt = paste(refpt, collapse=" ")

    raw = read_tsv(inmatrix, skip=3, col_names=FALSE)
    names(raw) = seq(ncol(raw))

    df = raw %>%
          rownames_to_column(var="index") %>%
          gather(key = variable, value=value, -index, convert=TRUE) %>%
          filter(!is.na(value)) %>%
          transmute(group = group, sample = sample,
                    sampletype = sampletype, annotation = annotation,
                    index = as.integer(index),
                    position = variable,
                    cpm = as.numeric(value))
    if(binsize>1){
        df %<>% mutate(position = (as.numeric(position)*binsize-(upstream+1.5*binsize))/1000)
    } else if (refpt=="TES"){
        df %<>% mutate(position = (as.numeric(position)-(1+upstream))/1000)
    } else {
        df %<>% mutate(position = (as.numeric(position)-(2+upstream))/1000)
    }
    # Output path passed positionally: readr >= 1.4 deprecated `path=` in
    # favour of `file=`, and the positional form works with both.
    write_tsv(df, outpath, col_names=FALSE)
    # invisible() keeps Rscript from printing the whole tibble to stdout when
    # melt() is called at the top level.
    invisible(df)
}

# Script entry point. group/sample/annotation are declared with nargs="+", so
# their words are joined back into single space-separated labels here.
melt(inmatrix = args$input,
     refpt = args$refpt,
     group = paste(args$group, collapse=" "),
     sample = paste(args$sample, collapse=" "),
     sampletype = args$sampletype,
     annotation = paste(args$annotation, collapse=" "),
     binsize = args$binsize,
     upstream = args$upstream,
     outpath = args$output)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
library(tidyverse)
library(forcats)

# Draw each sample's fragment-size distribution as a filled density curve,
# stacked in one facet row per sample, and save the figure.
#
# in_table: TSV with a `fragsize` column plus one count column per sample.
# out_path: image file written by ggsave(); its height grows with the number
#           of samples.
main = function(in_table, out_path){
    wide_counts = read_tsv(in_table)

    # One row per (fragment size, sample); samples keep their column order.
    long_counts = gather(wide_counts, key=sample, value=count, -fragsize)
    long_counts = mutate(long_counts, sample = fct_inorder(sample, ordered=TRUE))

    # Scale each sample's counts so they sum to one.
    density_df = long_counts %>%
        group_by(sample) %>%
        mutate(density = count/sum(count, na.rm=TRUE))

    frag_theme = theme_light() +
        theme(text = element_text(size=12, color="black", face="bold"),
              axis.text = element_text(color="black"),
              axis.text.x = element_text(size=12),
              axis.text.y = element_text(face="plain"),
              strip.background = element_blank(),
              strip.text = element_text(color="black", size=12),
              strip.placement = "outside",
              strip.text.y = element_text(angle=-180, hjust=1))

    frag_plot = ggplot(density_df, aes(x=fragsize, y=density)) +
        geom_area(fill="#114477", color="black") +
        facet_grid(sample~., switch="y") +
        scale_y_continuous(breaks = scales::pretty_breaks(n=2)) +
        xlab("fragment size (bp)") +
        frag_theme

    n_samples = n_distinct(density_df[["sample"]])
    ggsave(out_path, plot=frag_plot,
           width=24, height=2+1.5*n_samples,
           units="cm", limitsize=FALSE)
}

# Script entry point: input table and output plot paths come from the calling
# Snakemake rule.
main(in_table = snakemake@input[["table"]],
     out_path = snakemake@output[["plot"]])
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
library(tidyverse)
library(magrittr)
library(viridis)
library(psych)
library(seriation)
library(ggthemes)
library(gtable)

# Choose y-axis break positions for heatmap panels from the axis limits.
#
# limits: numeric axis limits, in either order.
# Returns break positions inside the limits:
#   - every 500 units, inset by 500, when the span is at least 2000;
#   - every 100 units, inset by 100, when the span is at least 200;
#   - otherwise a single break at the centre of the limits.
hmap_ybreaks = function(limits){
    lo = min(limits)
    hi = max(limits)
    span = hi - lo

    if (span >= 2000){
        return(seq(lo+500, hi-500, 500))
    }
    # A plain lower bound avoids the old between(span, 200, 1999) gap, which
    # sent spans in (1999, 2000) to the single-break fallback.
    if (span >= 200){
        return(seq(lo+100, hi-100, 100))
    }
    # Centre of the limits. The previous version returned half the span, which
    # lies inside the axis only when min(limits) is close to 0.
    round((lo + hi)/2)
}

main = function(in_path, samplelist, anno_paths, ptype, readtype, upstream, dnstream, scaled_length, pct_cutoff, log_transform, pcount,
                trim_pct, factorlabel, spread_type, refptlabel, endlabel, cmap, sortmethod, cluster_scale, cluster_samples, cluster_five, cluster_three, k, assay,
                heatmap_sample_out, heatmap_group_out, meta_sample_out, meta_sample_overlay_out, meta_group_out, meta_sampleanno_out, meta_groupanno_out, meta_sampleclust_out, meta_groupclust_out,
                anno_out, cluster_out){

    hmap = function(df, flimit, readtype="midpoint", logtxn=log_transform){

        heatmap_base = ggplot(data = df) +
            geom_vline(xintercept = 0, size=1.5)

        if (ptype=="scaled"){
            heatmap_base = heatmap_base +
                geom_vline(xintercept = scaled_length/1000, size=1.5)
        }

        if (logtxn){
            heatmap_base = heatmap_base +
                geom_raster(aes(x=position, y=new_index, fill=log2(cpm+pcount)), interpolate=FALSE) +
                scale_fill_viridis(option = cmap,
                                   name = bquote(log[2] ~ .(factorlabel) ~ "ChIP-seq" ~ .(readtype)),
                                   limits = c(NA, flimit), oob=scales::squish,
                                   guide=guide_colorbar(title.position="top",
                                                        barwidth=20,
                                                        barheight=1,
                                                        title.hjust=0.5))
        } else {
            heatmap_base = heatmap_base +
                geom_raster(aes(x=position, y=new_index, fill=cpm), interpolate=FALSE) +
                scale_fill_viridis(option = cmap,
                                   name = paste(factorlabel, "ChIP-seq", readtype),
                                   limits = c(NA, flimit), oob=scales::squish,
                                   guide=guide_colorbar(title.position="top",
                                                        barwidth=20, barheight=1, title.hjust=0.5))
        }

        heatmap_base = heatmap_base +
            scale_y_reverse(expand=c(0.005,5), breaks=hmap_ybreaks) +
            theme_minimal() +
            theme(text = element_text(size=16, face="plain", color="black"),
                  legend.position = "top",
                  legend.title = element_text(size=16, face="plain", color="black"),
                  legend.text = element_text(size=12, face="plain"),
                  legend.margin = margin(0,0,0,0),
                  legend.box.margin = margin(0,0,0,0),
                  strip.text.x = element_text(size=16, face="plain", color="black"),
                  axis.ticks.length = unit(0.125, "cm"),
                  axis.ticks = element_line(size=1.5),
                  axis.ticks.y = element_blank(),
                  axis.text.y = element_blank(),
                  axis.text.x = element_text(size=16, face="plain", color="black", margin = unit(c(3,0,0,0),"pt")),
                  axis.title.x = element_text(size=12, face="plain"),
                  axis.title.y = element_blank(),
                  panel.grid.major.x = element_line(color="black", size=1.5),
                  panel.grid.minor.x = element_line(color="black"),
                  panel.grid.major.y = element_line(color="black"),
                  panel.grid.minor.y = element_blank(),
                  panel.spacing.x = unit(.8, "cm"))

        if (ptype=="absolute"){
            heatmap_base = heatmap_base +
                scale_x_continuous(breaks=scales::pretty_breaks(n=3),
                                   labels= function(x){if_else(x==0, refptlabel,
                                                               if(upstream>500 | dnstream>500){as.character(x)}
                                                               else {as.character(x*1000)})},
                                   name=paste("distance from", refptlabel, if(upstream>500 | dnstream>500){"(kb)"}
                                              else {"(nt)"}),
                                   limits = c(-upstream/1000, dnstream/1000),
                                   expand=c(0,0.025))
        } else {
            heatmap_base = heatmap_base +
                scale_x_continuous(breaks=c(0, (scaled_length/2)/1000, scaled_length/1000),
                                   labels=c(refptlabel, "", endlabel),
                                   name="scaled distance",
                                   limits = c(-upstream/1000, (scaled_length+dnstream)/1000),
                                   expand=c(0,0.025))
        }
        return(heatmap_base)
    }

    # Build a metagene (summarised signal profile) ggplot from summarised data.
    #
    # df:       summarised signal; needs columns position, mid, low, high and
    #           sampletype, plus whichever of sample, group, cluster and
    #           annotation the chosen `groupvar` maps.
    # groupvar: line grouping / colouring scheme: "sample", "group",
    #           "sampleclust", "groupclust", "sampleanno" or "groupanno".
    #           Any other value is an error.
    # strand:   accepted for compatibility with existing callers; unused.
    # Also reads ptype, readtype, scaled_length, upstream, dnstream,
    # refptlabel, endlabel and factorlabel from the enclosing main() scope.
    # Returns an unfaceted ggplot object.
    meta = function(df, groupvar="sample", strand="protection"){
        # annotation-coloured variants put the legend at the bottom with labels
        # to the right of the keys; all other variants use a top legend with
        # centred labels above the keys
        anno_layout = groupvar %in% c("sampleanno", "groupanno")
        label_position = if (anno_layout) "right" else "top"
        label_hjust = if (anno_layout) 0 else 0.5
        # a fresh legend guide per scale, all sharing the same label layout
        legend_guide = function(){
            guide_legend(label.position=label_position, label.hjust=label_hjust)
        }

        if (groupvar=="groupanno"){
            df = df %>% mutate(coloring = interaction(annotation, cluster))
        }
        mapping = switch(groupvar,
                         sample = aes(x=position,
                                      group=interaction(sample, sampletype),
                                      color=group, fill=group),
                         group = aes(x=position,
                                     group=interaction(group, sampletype),
                                     color=group, fill=group),
                         sampleclust = aes(x=position,
                                           group=interaction(sample, sampletype, cluster),
                                           color=factor(cluster), fill=factor(cluster)),
                         groupclust = aes(x=position,
                                          group=interaction(group, sampletype, cluster),
                                          color=factor(cluster), fill=factor(cluster)),
                         sampleanno = aes(x=position,
                                          group=interaction(sample, sampletype, annotation, cluster),
                                          color=interaction(annotation, cluster),
                                          fill=interaction(annotation, cluster)),
                         groupanno = aes(x=position, color=coloring, fill=coloring),
                         stop("meta(): unknown groupvar '", groupvar, "'", call.=FALSE))

        # reference-point marker (and scaled-region end marker below)
        metagene = ggplot(data = df, mapping) +
            geom_vline(xintercept = 0, size=1, color="grey65")

        if (ptype=="scaled"){
            metagene = metagene +
                geom_vline(xintercept = scaled_length/1000, size=1, color="grey65")
        }

        if (readtype=="enrichment"){
            metagene = metagene +
                geom_ribbon(aes(ymin=low,
                                ymax=high),
                            size=0, alpha=0.2) +
                geom_line(aes(y=mid))
        } else {
            # input vs ChIP are distinguished by line type and ribbon alpha
            metagene = metagene +
                geom_ribbon(aes(ymin=low,
                                ymax=high,
                                alpha=sampletype),
                            size=0) +
                geom_line(aes(y=mid,
                              linetype=sampletype)) +
                scale_linetype_manual(values = c("dashed", "solid"),
                                      guide=legend_guide()) +
                scale_alpha_manual(values=c(0.05, 0.2),
                                   guide=legend_guide())
        }

        y_name = if (readtype=="enrichment") {
            expression(textstyle(frac("IP", "input")))
        } else {
            "normalized counts"
        }

        metagene = metagene +
            scale_y_continuous(limits = c(NA, NA),
                               name=y_name) +
            scale_color_ptol(guide=legend_guide()) +
            scale_fill_ptol() +
            ggtitle(paste(factorlabel, "ChIP-seq", readtype)) +
            theme_light() +
            theme(text = element_text(size=12, color="black", face="plain"),
                  axis.text = element_text(size=12, color="black"),
                  axis.text.y = element_text(size=10, face="plain"),
                  axis.title = element_text(size=10, face="plain"),
                  strip.placement="outside",
                  strip.background = element_blank(),
                  strip.text = element_text(size=12, color="black", face="plain"),
                  legend.text = element_text(size=12),
                  legend.title = element_blank(),
                  legend.position = if (anno_layout) "bottom" else "top",
                  legend.key.width = unit(3, "cm"),
                  plot.title = element_text(size=12),
                  plot.subtitle = element_text(size=10, face="plain"),
                  panel.spacing.x = unit(0.8, "cm"))

        if (ptype=="absolute"){
            # windows wider than 500 on either side are labelled in kb,
            # otherwise positions are converted back to nt
            use_kb = upstream>500 || dnstream>500
            metagene = metagene +
                scale_x_continuous(breaks=scales::pretty_breaks(n=3),
                                   labels= function(x){if_else(x==0, refptlabel,
                                                               if(use_kb){as.character(x)}
                                                               else {as.character(x*1000)})},
                                   name=paste("distance from", refptlabel,
                                              if (use_kb) "(kb)" else "(nt)"),
                                   limits = c(-upstream/1000, dnstream/1000),
                                   expand=c(0,0))
        } else {
            metagene = metagene +
                scale_x_continuous(breaks=c(0, (scaled_length/2)/1000, scaled_length/1000),
                                   labels=c(refptlabel, "", endlabel),
                                   name="scaled distance",
                                   limits = c(-upstream/1000, (scaled_length+dnstream)/1000),
                                   expand=c(0,0))
        }

        # cluster-coloured variants replace the ptol palettes
        if (groupvar %in% c("sampleclust", "groupclust")){
            metagene = metagene +
                scale_color_colorblind(guide=guide_legend(label.position="top", label.hjust=0.5)) +
                scale_fill_colorblind()
        }
        return(metagene)
    }

    # Overlay spanning ("nested") labels on the right-hand facet strips of a
    # row-faceted ggplot, so each outer facet value is drawn once across all
    # the rows it covers instead of being repeated on every row.
    #
    # ggp:   ggplot faceted by 2 or 3 row variables, outermost first
    #        (e.g. replicate + annotation + cluster).
    # level: number of row-facet variables; 3 means replicate/annotation/cluster.
    # outer: for level 2, the outer row variable: "replicate" or "annotation".
    # inner: for level 2 with outer="replicate": "cluster" means each replicate
    #        spans sum(k) facet rows, anything else means it spans n_anno rows.
    # Reads k, n_anno and max_reps from the enclosing main() scope; returns a
    # gtable. Facet row j's strip is taken to sit on row 2j-1 of the strip
    # region (strips alternate with panel-spacing rows), matching the
    # (number of strips)*2-1 rows allocated for the overlay below.
    nest_right_facets = function(ggp, level=2, outer="replicate", inner="annotation"){
        og_grob = ggplotGrob(ggp)
        strip_loc = grep("strip-r", og_grob[["layout"]][["name"]])
        strip = gtable_filter(og_grob, "strip-r", trim=FALSE)
        strip_heights = gtable_filter(og_grob, "strip-r")[["heights"]]

        strip_top = min(strip[["layout"]][["t"]])
        strip_bot = max(strip[["layout"]][["b"]])
        strip_x = strip[["layout"]][["r"]][1]

        # empty overlay: one column per facet level, one row per strip/spacer
        mat = matrix(vector("list", length=(length(strip)*2-1)*level), ncol=level)
        mat[] = list(zeroGrob())

        facet_grob = gtable_matrix("rightcol", grobs=mat,
                                   widths=unit(rep(1,level), "null"),
                                   heights=strip_heights)

        # Label facet rows first_row+1 .. first_row+n_rows (first_row is
        # 0-based) with sub-grob `part` of the first covered strip, drawn in
        # overlay column `col`.
        add_span = function(fg, first_row, n_rows, part, col){
            gtable_add_grob(fg,
                            grobs = og_grob$grobs[[strip_loc[first_row+1]]]$grobs[[part]],
                            t = 2*first_row+1,
                            b = 2*(first_row+n_rows)-1,
                            l = col, r = col)
        }

        # facet rows preceding each annotation inside one replicate block;
        # cumulative sums keep spans correct when k differs between annotations
        k_before = cumsum(k) - k
        rows_per_rep = sum(k)

        if(level==3){
            for (rep_idx in seq_len(max_reps)){
                rep_first = rows_per_rep*(rep_idx-1)
                #add replicate facet label spanning all of its rows
                facet_grob = add_span(facet_grob, rep_first, rows_per_rep, level, level)
                #for each annotation within each replicate, span its clusters
                for (anno_idx in seq_len(n_anno)){
                    facet_grob = add_span(facet_grob, rep_first + k_before[anno_idx],
                                          k[anno_idx], 2, 2)
                }
            }
        } else if(level==2){
            if (outer=="annotation"){
                first_rows = k_before[seq_len(n_anno)]
                span_rows = k[seq_len(n_anno)]
            } else if (outer=="replicate"){
                per_rep = if (inner=="cluster") rows_per_rep else n_anno
                first_rows = per_rep*(seq_len(max_reps)-1)
                span_rows = rep(per_rep, max_reps)
            } else {
                stop("nest_right_facets(): unknown outer '", outer, "'", call.=FALSE)
            }
            for (idx in seq_along(first_rows)){
                facet_grob = add_span(facet_grob, first_rows[idx], span_rows[idx], 2, 2)
            }
        }
        new_grob = gtable_add_grob(og_grob, facet_grob, t=strip_top, r=strip_x, l=strip_x, b=strip_bot, name='rstrip')
        return(new_grob)
    }

    # Overlay spanning labels on the top facet strips of a column-faceted plot
    # so each outer column value is drawn once across the columns it covers.
    #
    # ggp:    a ggplot (intype="gg") or an already-built gtable
    #         (intype="gtable"), e.g. the result of nest_right_facets().
    # level:  number of rows in the overlay; the outer label goes in row 1.
    # inner:  "cluster": one label per annotation spanning its k clusters;
    #         "strand": one label per group spanning its two sampletype columns.
    # Reads k, n_anno and n_groups from the enclosing main() scope; returns a
    # gtable. Facet column j's strip is taken to sit on column 2j-1 of the
    # strip region (strips alternate with panel-spacing columns).
    nest_top_facets = function(ggp, level=2, inner="cluster", intype="gg"){
        if (intype=="gg"){
            og_grob = ggplotGrob(ggp)
        } else if (intype=="gtable"){
            og_grob = ggp
        } else {
            stop("nest_top_facets(): unknown intype '", intype, "'", call.=FALSE)
        }

        strip_loc = grep("strip-t", og_grob[["layout"]][["name"]])
        strip = gtable_filter(og_grob, "strip-t", trim=FALSE)
        strip_widths = gtable_filter(og_grob, "strip-t")[["widths"]]

        strip_l = min(strip[["layout"]][["l"]])
        strip_r = max(strip[["layout"]][["r"]])
        strip_y = strip[["layout"]][["t"]][1]

        # empty overlay: one row per facet level, one column per strip/spacer
        mat = matrix(vector("list", length=(length(strip)*2-1)*level), nrow=level)
        mat[] = list(zeroGrob())

        facet_grob = gtable_matrix("toprow", grobs=mat,
                                   heights=unit(rep(1,level), "null"),
                                   widths=strip_widths)

        # 0-based index of the first facet column of each outer value and the
        # number of facet columns it spans; cumulative sums keep the spans
        # correct when k differs between annotations
        if (inner=="cluster"){
            first_cols = (cumsum(k) - k)[seq_len(n_anno)]
            span_cols = k[seq_len(n_anno)]
        } else if (inner=="strand"){
            first_cols = 2*(seq_len(n_groups)-1)
            span_cols = rep(2, n_groups)
        } else {
            stop("nest_top_facets(): unknown inner '", inner, "'", call.=FALSE)
        }

        for (idx in seq_along(first_cols)){
            facet_grob = gtable_add_grob(facet_grob,
                                         grobs = og_grob$grobs[[strip_loc[first_cols[idx]+1]]]$grobs[[1]],
                                         l = 2*first_cols[idx]+1,
                                         r = 2*(first_cols[idx]+span_cols[idx])-1,
                                         t = 1, b = 1)
        }
        new_grob = gtable_add_grob(og_grob, facet_grob, t=strip_y, r=strip_r, l=strip_l, b=strip_y, name='rstrip')
        return(new_grob)
    }

    # Load the signal matrix (columns: group, sample, sampletype, annotation,
    # index, position, cpm), keeping only rows for requested or clustering
    # samples and dropping missing signal values.
    df = read_tsv(in_path, col_names = c("group", "sample", "sampletype", "annotation", "index", "position", "cpm")) %>%
        filter((sample %in% samplelist | sample %in% cluster_samples) & !is.na(cpm)) %>%
        group_by(annotation) %>%
        # prefix each annotation name with its number of distinct indices
        mutate(annotation_labeled = paste(n_distinct(index), annotation),
               sampletype = ordered(sampletype, levels=c("input", "ChIP"))) %>%
        ungroup() %>%
        mutate(annotation = annotation_labeled) %>%
        select(-annotation_labeled)%>%
        # freeze first-appearance order of groups, samples and annotations
        mutate_at(vars(group, sample, annotation), ~(fct_inorder(., ordered=TRUE)))

    #get replicate info for sample facetting
    # replicate = position of each sample within its group, in order of
    # first appearance in the matrix
    repl_df = df %>%
        select(group, sample) %>%
        distinct() %>%
        group_by(group) %>%
        mutate(replicate=row_number()) %>%
        ungroup() %>%
        select(-group)
    max_reps = max(repl_df[["replicate"]])

    df %<>% left_join(repl_df, by="sample")

    n_anno = n_distinct(df[["annotation"]])

    #import annotation information
    # NOTE(review): anno_paths[i] is paired with the i-th annotation in
    # first-appearance order, so anno_paths is assumed to be listed in the same
    # order as annotations appear in the matrix — confirm against the rule.
    # `index` is the 1-based row number within each annotation file.
    annotations = df %>%
        distinct(annotation) %>%
        pull(annotation)
    bed = tibble()
    for (i in 1:n_anno){
        bed = read_tsv(anno_paths[i], col_names=c('chrom','start','end','name','score','strand')) %>%
            mutate(annotation=annotations[i]) %>%
            rowid_to_column(var="index") %>%
            bind_rows(bed, .)
    }

    n_samples = length(samplelist)
    n_groups = n_distinct(df[["group"]])

    #clustering, length sorting, or no sorting
    # Every branch leaves `df` with `new_index` (row order in the heatmaps) and
    # `cluster` columns, and writes the (re)ordered regions to `anno_out`.
    if (sortmethod=="cluster"){
        reorder = tibble()

        #cluster for each annotation
        for (i in 1:length(annotations)){
            # filter samples and positions to cluster on,
            # using mean of samples in a group
            rr = df %>%
                filter(annotation==annotations[i] & sample %in% cluster_samples &
                       between(position, cluster_five/1000, cluster_three/1000)) %>%
                group_by(group, sampletype, annotation, index, position) %>%
                summarise(cpm=mean(cpm))
            # if specified, rescale data for each index 0 to 1
            if (cluster_scale){
                rr %<>%
                    group_by(group, annotation, index) %>%
                    mutate(cpm = scales::rescale(cpm))
            }
            # reshape to one row per index and one column per
            # group~sampletype~position combination; missing cells become 0
            rr %<>%
                ungroup() %>%
                select(-annotation) %>%
                unite(cid, c(group, sampletype, position), sep="~") %>%
                spread(cid, cpm, fill=0) %>%
                select(-index)

            d = dist(rr, method="euclidean")
            # NOTE(review): kmeans() receives the dist object, which it turns
            # into the full distance matrix, so regions are clustered on their
            # rows of pairwise distances rather than on the signal itself. No
            # seed is set in this block, so assignments may vary between runs.
            l = kmeans(d, k[i])[["cluster"]]

            # dissimilarity plots for this annotation go to cluster_out[i]
            pdf(file=cluster_out[i], width=6, height=6)
            unsorted = dissplot(d, method=NA, newpage=TRUE,
                                main=paste0(annotations[i], "\nEuclidean distances, unsorted"),
                                options=list(silhouettes=FALSE, col=viridis(100, direction=-1)))
            if (k[i] > 1) {
                seriated = dissplot(d, labels=l, method="OLO", newpage=TRUE,
                                    main=paste0(annotations[i], "\nEuclidean distances, ",
                                                k[i], "-means clustered,\nOLO inter- and intracluster sorting"),
                                    options=list(silhouettes=TRUE, col=viridis(100, direction=-1)))
                dev.off()

                # og_index is a row position in rr; it matches `index` only if
                # indices run 1..n without gaps — TODO confirm
                sub_reorder = tibble(annotation = annotations[i],
                                     cluster = seriated[["labels"]],
                                     og_index = seriated[["order"]]) %>%
                    mutate(new_index = row_number())
            } else if (k[i]==1) {
                seriated = seriate(d, method="OLO")
                dev.off()
                sub_reorder = tibble(annotation = annotations[i],
                                     cluster = as.integer(1),
                                     og_index = get_order(seriated)) %>%
                    mutate(new_index = row_number())
            }
            # NOTE(review): if k[i] is neither 1 nor >1 (e.g. 0) the pdf device
            # stays open and sub_reorder is not (re)assigned.

            reorder %<>% bind_rows(sub_reorder)

            # write one region file per cluster, in the new order
            sorted = sub_reorder %>%
                left_join(bed, by=c("annotation", "og_index"="index")) %>%
                select(-c(annotation, og_index, new_index))
            for (j in 1:k[i]){
                sorted %>% filter(cluster==j) %>%
                    select(-cluster) %>%
                    write_tsv(anno_out[sum(k[0:(i-1)])+j], col_names=FALSE)
            }
        }

        # attach cluster and order; renumber new_index from 1 within each cluster
        df %<>%
            left_join(reorder, by=c("annotation", "index"="og_index")) %>%
            group_by(annotation, cluster) %>%
            mutate(new_index = as.integer(new_index+1-min(new_index))) %>%
            ungroup() %>%
            arrange(annotation, cluster, new_index)
    } else if (sortmethod=="length"){
        # order regions by length (end - start), shortest first, and number
        # them from 1 within each annotation
        sorted = bed %>%
            group_by(annotation) %>%
            arrange(end-start, .by_group=TRUE) %>%
            rowid_to_column(var= "new_index") %>%
            mutate(new_index = as.integer(new_index+1-min(new_index))) %>%
            ungroup()

        for (i in 1:n_anno){
            sorted %>% filter(annotation==annotations[i]) %>%
                select(-c(new_index, index, annotation)) %>%
                write_tsv(path=anno_out[i], col_names=FALSE)
        }

        df = sorted %>%
            select(index, new_index, annotation) %>%
            right_join(df, by=c("annotation", "index")) %>%
            mutate(cluster=as.integer(1))
    } else {
        # no sorting: keep the annotation file order, single cluster
        df %<>% mutate(new_index = index,
                           cluster = as.integer(1))

        for (i in 1:n_anno){
            bed %>% filter(annotation==annotations[i]) %>%
                select(-c(index, annotation)) %>%
                write_tsv(path=anno_out[i], col_names=FALSE)
        }
    }

    # Heatmap inputs.
    # Sample level: every sample, with replicate and cluster relabelled as
    # ordered "replicate N" / "cluster N" facet factors.
    # Group level: cpm averaged over the samples of each group.
    # Each colour cutoff is the pct_cutoff quantile of the strictly positive
    # values.
    positive_quantile = function(values){
        quantile(values[values > 0], probs=pct_cutoff, na.rm=TRUE)
    }
    as_label_factor = function(prefix, values){
        fct_inorder(paste(prefix, values), ordered=TRUE)
    }

    df_sample = mutate(df,
                       replicate = as_label_factor("replicate", replicate),
                       cluster = as_label_factor("cluster", cluster))
    sample_cutoff = positive_quantile(df_sample[["cpm"]])

    df_group = df %>%
        group_by(group, annotation, position, cluster, new_index, sampletype) %>%
        summarise(cpm = mean(cpm)) %>%
        ungroup() %>%
        mutate(cluster = as_label_factor("cluster", cluster))
    group_cutoff = positive_quantile(df_group[["cpm"]])

    # Heatmaps need every (new_index, position) cell: unless rows are
    # length-sorted, fill absent cells with the table's minimum cpm. The
    # metagene summaries are built from `df`, so they are not affected.
    if (sortmethod != "length"){
        sample_floor = min(df_sample[["cpm"]])
        df_sample = df_sample %>%
            group_by(group, sample, annotation, sampletype, replicate, cluster) %>%
            complete(new_index, position, fill=list(cpm=sample_floor)) %>%
            ungroup()

        group_floor = min(df_group[["cpm"]])
        df_group = df_group %>%
            group_by(group, annotation, cluster, sampletype) %>%
            complete(new_index, position, fill=list(cpm=group_floor)) %>%
            ungroup()
    }

    # Build the sample- and group-level heatmaps, facet them by replicate /
    # annotation / cluster depending on how many annotations and clusters
    # exist, then merge repeated strip labels with the nest_*_facets helpers.
    # For non-enrichment reads each group is split into sampletype columns,
    # and nest_top_facets(inner="strand") expects two top strips per group.
    heatmap_sample = hmap(df_sample, sample_cutoff, readtype=readtype, logtxn=log_transform)
    heatmap_group = hmap(df_group, group_cutoff, readtype=readtype, logtxn=log_transform)

    if (n_anno==1 && max(k)==1){
        heatmap_sample = heatmap_sample +
                ylab(annotations[1]) +
                theme(axis.title.y = element_text(size=16, face="plain", color="black", angle=90),
                      strip.text.y = element_text(size=16, face="plain", color="black"),
                      strip.background = element_rect(fill="white", size=0))

        heatmap_group = heatmap_group +
            ylab(annotations[1]) +
            theme(axis.title.y = element_text(size=16, face="plain", color="black", angle=90),
                  strip.background = element_rect(fill="white", size=0))
        if (readtype=="enrichment"){
            heatmap_sample = heatmap_sample +
                facet_grid(replicate ~ group, scales="free_y", space="free_y")

            heatmap_group = heatmap_group +
                facet_grid(. ~ group)
        } else {
            heatmap_sample = heatmap_sample +
                facet_grid(replicate ~ group + sampletype, scales="free_y", space="free_y")
            heatmap_sample %<>% nest_top_facets(inner="strand")

            heatmap_group = heatmap_group +
                facet_grid(. ~ group + sampletype)
            heatmap_group %<>% nest_top_facets(inner="strand")
        }

    } else if (n_anno==1 && max(k)>1){
        heatmap_sample = heatmap_sample +
            ylab(annotations[1]) +
            theme(axis.title.y = element_text(size=16, face="plain", color="black", angle=90),
                  strip.text.y = element_text(size=12, face="plain", color="black"),
                  strip.background = element_rect(fill="white", size=0))

        heatmap_group = heatmap_group +
            ylab(annotations[1]) +
            theme(axis.title.y = element_text(size=16, face="plain", color="black", angle=90),
                  strip.background = element_rect(fill="white", size=0))

        if (readtype=="enrichment"){
            heatmap_sample = heatmap_sample +
                facet_grid(replicate + cluster ~ group, scales="free_y", space="free_y")
            heatmap_sample %<>%
                nest_right_facets(level=2, outer="replicate", inner="cluster")

            heatmap_group = heatmap_group +
                facet_grid(cluster ~ group, scales="free_y", space="free_y")
        } else {
            heatmap_sample = heatmap_sample +
                facet_grid(replicate + cluster ~ group + sampletype, scales="free_y", space="free_y")
            heatmap_sample %<>%
                nest_right_facets(level=2, outer="replicate", inner="cluster") %>%
                nest_top_facets(inner="strand", intype="gtable")

            heatmap_group = heatmap_group +
                facet_grid(cluster ~ group + sampletype, scales="free_y", space="free_y")
            heatmap_group %<>%
                nest_top_facets(inner="strand")
        }
    } else if (n_anno>1 && max(k)==1){
        heatmap_sample = heatmap_sample +
            theme(strip.text.y = element_text(size=12, face="plain", color="black"),
                  strip.background = element_rect(fill="white", size=0))

        heatmap_group = heatmap_group +
            theme(strip.background = element_rect(fill="white", size=0))

        if (readtype=="enrichment"){
            heatmap_sample = heatmap_sample +
                facet_grid(replicate + annotation ~ group, scales="free_y", space="free_y")
            heatmap_sample %<>%
                nest_right_facets(level=2, outer="replicate")

            heatmap_group = heatmap_group +
                facet_grid(annotation ~ group, scales="free_y", space="free_y")

        } else {
            heatmap_sample = heatmap_sample +
                facet_grid(replicate + annotation ~ group + sampletype, scales="free_y", space="free_y")
            heatmap_sample %<>%
                nest_right_facets(level=2, outer="replicate") %>%
                nest_top_facets(inner="strand", intype="gtable")

            heatmap_group = heatmap_group +
                facet_grid(annotation ~ group + sampletype, scales="free_y", space="free_y")
            heatmap_group %<>%
                nest_top_facets(inner="strand")
        }

    } else if (n_anno>1 && max(k)>1){
        heatmap_sample = heatmap_sample +
            theme(strip.text.y = element_text(size=12, face="plain", color="black"),
                  strip.background = element_rect(fill="white", size=0))

        heatmap_group = heatmap_group +
            theme(strip.text.y = element_text(size=16, face="plain", color="black"),
                  strip.background = element_rect(fill="white", size=0))

        if (readtype=="enrichment"){
            heatmap_sample = heatmap_sample +
                facet_grid(replicate + annotation + cluster ~ group,
                           scales="free_y", space="free_y")
            heatmap_sample %<>%
                nest_right_facets(level=3)

            # one column per group (df_group has no `strand` column)
            heatmap_group = heatmap_group +
                facet_grid(annotation + cluster ~ group, scales="free_y", space="free_y")
            heatmap_group %<>%
                nest_right_facets(level=2, outer="annotation")
        } else {
            heatmap_sample = heatmap_sample +
                facet_grid(replicate + annotation + cluster ~ group + sampletype,
                           scales="free_y", space="free_y")
            heatmap_sample %<>%
                nest_right_facets(level=3) %>%
                nest_top_facets(inner="strand", intype="gtable")

            # split each group into sampletype columns so that
            # nest_top_facets(inner="strand") finds two top strips per group
            heatmap_group = heatmap_group +
                facet_grid(annotation + cluster ~ group + sampletype, scales="free_y", space="free_y")
            heatmap_group %<>%
                nest_right_facets(level=2, outer="annotation") %>%
                nest_top_facets(inner="strand", intype="gtable")
        }
    }
    ggsave(heatmap_sample_out, plot=heatmap_sample, width=2+18*n_groups, height=10+10*max_reps, units="cm", limitsize=FALSE)
    ggsave(heatmap_group_out, plot=heatmap_group, width=2+18*n_groups, height=30, units="cm", limitsize=FALSE)

    # Metagene summaries: for every position, summarise signal across regions
    # per sample (metadf_sample) and per group (metadf_group). Both end up with
    # mid/low/high columns, which meta() draws as line and ribbon.
    metadf_sample = df %>%
        group_by(group, sample, sampletype,
                 annotation, position, cluster, replicate)

    if (spread_type=="conf_int"){
        # per sample: winsorized mean +/- winsorized SD across regions
        metadf_sample %<>%
            summarise(mid = winsor.mean(cpm, trim=trim_pct),
                      sd = winsor.sd(cpm, trim=trim_pct)) %>%
            mutate(low = mid-sd,
                   high = mid+sd)

        # per group: mean of the sample means +/- standard error across
        # replicates, with SD correction for small sample sizes
        # (Gurland and Tripathi 1971). With a single replicate the SEM is
        # undefined (NA/NaN), so no ribbon is drawn.
        metadf_group = metadf_sample %>%
            group_by(group, sampletype, annotation, position, cluster) %>%
            summarise(sd = sd(mid),
                      n = n_distinct(replicate),
                      mid = mean(mid)) %>%
            mutate(sem = sqrt((n-1)/2)*gamma((n-1)/2)/gamma(n/2)*sd/sqrt(n),
                   low = mid-sem,
                   high = mid+sem)
    } else if (spread_type=="quantile") {
        # median with [trim_pct, 1-trim_pct] quantile band
        metadf_sample %<>%
            summarise(mid = median(cpm),
                      low = quantile(cpm, probs=trim_pct),
                      high = quantile(cpm, probs=(1-trim_pct)))

        metadf_group = df %>%
            group_by(group, sampletype, annotation, position, cluster) %>%
            summarise(mid = median(cpm),
                      low = quantile(cpm, probs=trim_pct),
                      high = quantile(cpm, probs=(1-trim_pct)))
    } else {
        stop("spread_type must be 'conf_int' or 'quantile', not '",
             spread_type, "'", call.=FALSE)
    }

    # ordered "replicate N" / "cluster N" facet labels
    metadf_sample %<>%
        ungroup() %>% arrange(replicate) %>%
        mutate(replicate = fct_inorder(paste("replicate", replicate), ordered=TRUE)) %>%
        arrange(cluster) %>%
        mutate(cluster = fct_inorder(paste("cluster", cluster), ordered=TRUE))

    metadf_group %<>%
        ungroup() %>%
        arrange(cluster) %>%
        mutate(cluster = fct_inorder(paste("cluster", cluster), ordered=TRUE))

    # Base metagene plots; faceting, titles and output sizes are added below
    # depending on how many annotations (n_anno) and clusters (k) exist.
    meta_sample = meta(metadf_sample, strand=readtype)
    meta_group = meta(metadf_group, groupvar="group", strand=readtype)
    meta_sampleclust = meta(metadf_sample, groupvar="sampleclust", strand=readtype)
    meta_groupclust = meta(metadf_group, groupvar="groupclust", strand=readtype)

    # annotation-coloured variants only exist with >1 annotation or cluster
    if(max(k) > 1 | n_anno > 1){
        meta_sampleanno = meta(metadf_sample, groupvar="sampleanno", strand=readtype)
        meta_groupanno = meta(metadf_group, groupvar="groupanno", strand=readtype)
    }

    if (n_anno==1 && max(k)==1){
        # single annotation, single cluster: one colour per sample panel
        meta_sample  = meta_sample +
            scale_color_manual(values=rep("#4477AA", 100)) +
            scale_fill_manual(values=rep("#4477AA", 100)) +
            facet_grid(replicate~group) +
            ggtitle(paste(factorlabel, "ChIP-seq", readtype),
                    subtitle = annotations[1]) +
            theme(legend.position="none")

        meta_sample_overlay = meta(metadf_sample) +
            ggtitle(paste(factorlabel, "ChIP-seq", readtype),
                    subtitle = annotations[1]) +
            theme(legend.position="right",
                  legend.key.width=unit(0.8, "cm"))

        meta_group = meta_group +
            ggtitle(paste(factorlabel, "ChIP-seq", readtype),
                    subtitle = annotations[1]) +
            theme(legend.position="right",
                  legend.key.width=unit(0.8, "cm"))

        meta_sampleclust = meta_sampleclust +
            facet_grid(. ~ group) +
            ggtitle(paste(factorlabel, "ChIP-seq", readtype),
                    subtitle = annotations[1]) +
            theme(legend.position="right",
                  legend.key.width=unit(1, "cm"))

        meta_groupclust = meta_groupclust +
            facet_grid(. ~ group) +
            ggtitle(paste(factorlabel, "ChIP-seq", readtype),
                    subtitle = annotations[1]) +
            theme(legend.position="right",
                  legend.key.width=unit(1, "cm"))

        # no annotation-coloured plots exist here, so the *anno outputs
        # receive copies of the overlay and group plots
        ggsave(meta_sample_out, plot = meta_sample, width=3+7*n_groups, height=2+5*max_reps, units="cm", limitsize=FALSE)
        ggsave(meta_sample_overlay_out, plot = meta_sample_overlay, width=16, height=9, units="cm", limitsize=FALSE)
        ggsave(meta_sampleanno_out, plot = meta_sample_overlay, width=16, height=9, units="cm", limitsize=FALSE)
        ggsave(meta_group_out, plot = meta_group, width=16, height=9, units="cm", limitsize=FALSE)
        ggsave(meta_groupanno_out, plot = meta_group, width=16, height=9, units="cm", limitsize=FALSE)
        ggsave(meta_sampleclust_out, plot = meta_sampleclust, width=6+7*n_groups, height=10, units="cm", limitsize=FALSE)
        ggsave(meta_groupclust_out, plot = meta_groupclust, width=6+7*n_groups, height=10, units="cm", limitsize=FALSE)
    } else if (n_anno>1 && max(k)==1){
        # several annotations, single cluster: facet by annotation
        meta_sample = meta_sample +
            facet_grid(replicate ~ annotation)

        # adding a new facet_grid replaces the replicate ~ annotation facets
        meta_sample_overlay = meta_sample +
            facet_grid(.~annotation)

        meta_group = meta_group +
            facet_grid(.~annotation)

        meta_sampleanno = meta_sampleanno +
            facet_grid(.~group) +
            theme(legend.direction="vertical")

        meta_groupanno = meta_groupanno +
            facet_grid(.~group) +
            theme(legend.direction="vertical")

        meta_sampleclust = meta_sampleclust +
            facet_grid(annotation ~ group)

        meta_groupclust = meta_groupclust +
            facet_grid(annotation ~ group)

    } else if (max(k)>1){
        # clustered: facet by annotation and cluster; the per-sample plot gets
        # annotation labels nested over their cluster columns
        meta_sample = meta_sample +
            facet_grid(replicate ~ annotation + cluster) +
            theme(strip.background = element_rect(fill="white", size=0))
        meta_sample %<>% nest_top_facets(level=2)

        meta_sample_overlay = meta(metadf_sample) +
            facet_grid(cluster ~ annotation)

        meta_group = meta_group +
            facet_grid(cluster ~ annotation)

        meta_sampleanno = meta_sampleanno +
            facet_grid(.~group) +
            theme(legend.direction="vertical")

        meta_groupanno = meta_groupanno +
            facet_grid(.~group) +
            theme(legend.direction="vertical")

        meta_sampleclust = meta_sampleclust +
            facet_grid(annotation ~ group) +
            theme(legend.key.width=unit(2, "cm"))

        meta_groupclust = meta_groupclust +
            facet_grid(annotation ~ group) +
            theme(legend.key.width=unit(2, "cm"))
    }

    # shared output step for every case except single annotation/cluster,
    # which saved its plots above; sizes scale with facet counts
    if (!(n_anno==1 && max(k)==1)){
        ggsave(meta_sample_out, plot = meta_sample, width=3+6*sum(k), height=2+5*max_reps, units="cm", limitsize=FALSE)
        ggsave(meta_sample_overlay_out, plot = meta_sample_overlay, width=3+7*n_anno, height=2+5*max(k), units="cm", limitsize=FALSE)
        ggsave(meta_group_out, plot = meta_group, width=3+7*n_anno, height=2+5*max(k), units="cm", limitsize=FALSE)
        ggsave(meta_sampleanno_out, plot = meta_sampleanno, width=3+7*n_groups, height=9+.75*sum(k), units="cm", limitsize=FALSE)
        ggsave(meta_groupanno_out, plot = meta_groupanno, width=3+7*n_groups, height=9+.75*sum(k), units="cm", limitsize=FALSE)
        ggsave(meta_sampleclust_out, plot = meta_sampleclust, width=3+7*n_groups, height=3+6*n_anno, units="cm", limitsize=FALSE)
        ggsave(meta_groupclust_out, plot = meta_groupclust, width=3+7*n_groups, height=3+6*n_anno, units="cm", limitsize=FALSE)
    }
}

# Script entry point: every argument of main() is read from the `snakemake`
# object supplied by the workflow. Arguments are grouped by their source;
# R matches them by name, so the grouping does not change the call.
main(
     # rule inputs
     in_path = snakemake@input[["matrix"]],
     anno_paths = snakemake@input[["annotations"]],
     # wildcards
     factorlabel = snakemake@wildcards[["factor"]],
     # samples, plot geometry and signal summary parameters
     samplelist = snakemake@params[["samplelist"]],
     ptype = snakemake@params[["plottype"]],
     readtype = snakemake@params[["readtype"]],
     upstream = snakemake@params[["upstream"]],
     dnstream = snakemake@params[["dnstream"]],
     scaled_length = snakemake@params[["scaled_length"]],
     pct_cutoff = snakemake@params[["pct_cutoff"]],
     log_transform = snakemake@params[["log_transform"]],
     pcount = snakemake@params[["pcount"]],
     spread_type = snakemake@params[["spread_type"]],
     trim_pct = snakemake@params[["trim_pct"]],
     refptlabel = snakemake@params[["refpointlabel"]],
     endlabel = snakemake@params[["endlabel"]],
     cmap = snakemake@params[["cmap"]],
     # region sorting / clustering parameters
     sortmethod = snakemake@params[["sortmethod"]],
     cluster_scale = snakemake@params[["cluster_scale"]],
     cluster_samples = snakemake@params[["cluster_samples"]],
     cluster_five = snakemake@params[["cluster_five"]],
     cluster_three = snakemake@params[["cluster_three"]],
     k = snakemake@params[["k"]],
     # figure outputs
     heatmap_sample_out = snakemake@output[["heatmap_sample"]],
     heatmap_group_out = snakemake@output[["heatmap_group"]],
     meta_sample_out = snakemake@output[["meta_sample"]],
     meta_sample_overlay_out = snakemake@output[["meta_sample_overlay"]],
     meta_sampleanno_out = snakemake@output[["meta_sampleanno"]],
     meta_groupanno_out = snakemake@output[["meta_groupanno"]],
     meta_group_out = snakemake@output[["meta_group"]],
     meta_sampleclust_out = snakemake@output[["meta_sampleclust"]],
     meta_groupclust_out = snakemake@output[["meta_groupclust"]],
     # sorted-region and cluster-plot paths (passed as params)
     anno_out = snakemake@params[["annotations_out"]],
     cluster_out = snakemake@params[["clusters_out"]])
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
library(tidyverse)
library(GGally)
library(viridis)

# Scatterplot matrix of binned ChIP-seq signal for the selected samples.
#
# intable:    path to a TSV with a `name` column plus one signal column per
#             sample
# factor:     ChIP target label, used in the plot title
# binsize:    bin size in bp, used in the plot title
# pcount:     pseudocount added to signal before the log-scaled axes
# samplelist: sample columns to keep
# outpath:    output image path passed to ggsave
#
# Panels: upper triangle = Pearson correlation of log10 signal (zero bins
# excluded per pair), diagonal = per-sample signal density, lower triangle =
# hex-binned pairwise scatter.
main = function(intable, factor, binsize, pcount, samplelist, outpath){
    # long -> keep requested samples -> wide: one column per sample, one row
    # per bin; fct_inorder fixes column order to order of first appearance
    df = intable %>% read_tsv() %>%
        gather(key=sample, value=signal, -name) %>%
        filter(sample %in% samplelist) %>%
        mutate_at(vars(sample), ~(fct_inorder(., ordered=TRUE))) %>%
        spread(sample, signal) %>%
        select(-name)

    # drop bins with no signal in any sample
    df = df[which(rowSums(df)>0),]

    # zeros -> NA so log10 gives NA (dropped pairwise) instead of -Inf.
    # na_if() is applied column by column because it is documented for
    # vectors, not data frames.
    cor_matrix = df %>%
        mutate_all(~ na_if(., 0)) %>%
        log10() %>%
        cor(method="pearson", use="pairwise.complete.obs")

    n_samples = ncol(df)
    maxsignal = max(df) + pcount
    # lower limit of the correlation fill scale: 2% below the smallest
    # correlation. Scaling by abs() keeps the limit below the data when the
    # smallest correlation is negative (a plain *0.98 would move it above).
    mincor = min(cor_matrix, na.rm=TRUE)
    mincor = mincor - 0.02*abs(mincor)
    plots = list()

    #for each row
    for (i in seq_len(n_samples)){
        #for each column
        for (j in seq_len(n_samples)){
            # panels are enumerated row by row (assumes ggmatrix's default
            # byrow filling)
            idx = n_samples*(i-1)+j
            if (i < j){
                #upper right (correlation)
                cor_value = cor_matrix[i,j]
                plot = ggplot(data = tibble(x=c(0,1), y=c(0,1), corr=cor_value)) +
                        geom_rect(aes(fill=corr), xmin=0, ymin=0, xmax=1, ymax=1) +
                        annotate("text", x=0.5, y=0.5, label=sprintf("%.2f",round(cor_value,2)), size=10*abs(cor_value)) +
                        scale_x_continuous(breaks=NULL) +
                        scale_y_continuous(breaks=NULL) +
                        scale_fill_distiller(palette="Blues", limits = c(mincor,1), direction=1)
                plots[[idx]] = plot
            } else if (i == j){
                #top left to bot right diag (density)
                subdf = df %>% select(i) %>% gather(sample, value)
                plot = ggplot(data = subdf, aes(x=(value+pcount))) +
                        geom_density(aes(y=..scaled..), fill="#114477", size=0.8) +
                        scale_y_continuous(breaks=c(0,.5,1)) +
                        scale_x_log10(limit = c(pcount, maxsignal)) +
                        annotate("text", x=.90*maxsignal, y=0.5, hjust=1,
                                 label=unique(subdf$sample), size=2, fontface="bold")
                plots[[idx]] = plot
            } else {
                #bottom left (scatter): x = column sample j, y = row sample i
                #filtering is an optional hack to avoid the (0,0) bin taking up
                #all of the colorspace
                subdf = df %>% select(i,j) %>% gather(xsample, xvalue, -1) %>%
                            gather(ysample, yvalue, -c(2:3)) #%>%
                            # filter(!(xvalue < 6*pcount & yvalue < 6*pcount))
                plot = ggplot(data = subdf, aes(x=xvalue+pcount, y=yvalue+pcount)) +
                            geom_abline(intercept = 0, slope=1, color="grey80", size=.5) +
                            stat_bin_hex(geom="point", aes(color=log10(..count..)), binwidth=c(.04,.04), size=.5, shape=16, stroke=0) +
                            scale_fill_viridis(option="inferno") +
                            scale_color_viridis(option="inferno") +
                            scale_x_log10(limit = c(pcount, maxsignal)) +
                            scale_y_log10(limit = c(pcount, maxsignal))
                plots[[idx]] = plot
            }
        }
    }

    mat = ggmatrix(plots, nrow=n_samples, ncol=n_samples,
                   title = paste0(factor, " ChIP-seq signal, ", binsize, "bp bins"),
                   xAxisLabels = names(df), yAxisLabels = names(df), switch="both") +
                    theme_light() +
                    theme(plot.title = element_text(size=12, color="black", face="bold"),
                          axis.text = element_text(size=9),
                          strip.background = element_blank(),
                          strip.text = element_text(size=12, color="black", face="bold"),
                          strip.text.x = element_text(angle=15, hjust=1, vjust=1, size=8),
                          strip.text.y = element_text(angle=180, hjust=1),
                          strip.placement="outside",
                          strip.switch.pad.grid = unit(0, "points"),
                          strip.switch.pad.wrap = unit(0, "points"))
    # figure size grows with the number of samples (cm)
    w = 3+n_samples*4.5
    h = 9/16*w+0.5
    ggsave(outpath, mat, width=w, height=h, units="cm", limitsize=FALSE)
    print(warnings())
}

# Script entry point: output/input paths first, then rule params, then the
# wildcards used only for the plot title.
main(outpath = snakemake@output[[1]],
     intable = snakemake@input[[1]],
     samplelist = snakemake@params[["samplelist"]],
     pcount = snakemake@params[["pcount"]],
     binsize = snakemake@wildcards[["windowsize"]],
     factor = snakemake@wildcards[["factor"]])
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
library(tidyverse)
library(forcats)
library(viridis)

# Step plot of read counts across processing stages, one facet row per sample.
#
# df:          long table with a numeric `step`, a `count` and a `sample` column
# scalefactor: counts are divided by this before plotting
# ylabel:      y-axis title
#
# x breaks 2..6 carry the five stage labels; steps are drawn "vh" and nudged
# +0.5 on x. Returns the ggplot object.
survival_plot <- function(df, scalefactor, ylabel){
    stage_labels <- c("raw reads", "reads cleaned",
                      "aligned", "uniquely mapping",
                      "no duplicates")

    step_layer <- geom_step(direction="vh", position=position_nudge(x=0.5),
                            color="#114477", size=0.8)
    x_axis <- scale_x_continuous(expand=c(0,0), breaks=2:6, name=NULL,
                                 labels=stage_labels)
    # right-hand duplicate y axis, its title blanked in the theme below
    y_axis <- scale_y_continuous(sec.axis=dup_axis(), name=ylabel)

    # adjustments layered on top of theme_light(), so added after it
    tweaks <- theme(strip.placement="outside",
                    strip.background = element_blank(),
                    text = element_text(size=12, color="black", face="bold"),
                    strip.text.y = element_text(size=12, angle=-180, color="black", hjust=1),
                    axis.text.x = element_text(size=12, color="black", face="bold", angle=30, hjust=0.95),
                    axis.text.y = element_text(size=10, color="black", face="plain"),
                    axis.title.y.right = element_blank(),
                    panel.grid.major.x = element_line(color="grey40"),
                    plot.subtitle = element_text(size=12, face="plain"))

    ggplot(data = df, aes(x=step, y=count/scalefactor, group=sample)) +
        step_layer +
        x_axis +
        y_axis +
        facet_grid(sample~., switch="y") +
        theme_light() +
        tweaks
}

# Read-processing QC plots: absolute and relative library size per stage, and
# a heatmap of the percentage of reads lost at each stage.
#
# in_table:     TSV with a `sample` column followed by per-stage read counts,
#               the first of which is named `raw`. The plots label exactly five
#               count columns (columns 2:6) - TODO confirm the table never has
#               more.
# surv_abs_out: output path for the absolute library-size plot
# surv_rel_out: output path for the relative library-size plot
# loss_out:     output path for the percent-loss heatmap
main = function(in_table, surv_abs_out, surv_rel_out, loss_out){
    df = read_tsv(in_table) %>%
        mutate(sample=fct_inorder(sample, ordered=TRUE))

    nsamples = nrow(df)

    # percent of the previous stage's reads lost at each stage; lag() relies
    # on gather() stacking stages in column order within each sample
    loss = df %>% gather(step, count, -sample, factor_key=TRUE) %>%
        group_by(sample) %>%
        mutate(og_count = lag(count)) %>%
        ungroup() %>%
        filter(step != "raw") %>%
        mutate(loss = (og_count-count)/og_count*100)

    #some hacking to get a survival-curve like thing:
    #a copy of the raw count ("dummy") becomes step 1, so the curve has a flat
    #lead-in before the first labelled stage at step 2
    #TODO: make the color fill the AUC?
    survival = df %>% mutate(dummy=raw) %>%
        select(sample, dummy, 2:6) %>%
        gather(step, count, -sample, factor_key=TRUE) %>%
        mutate_at(vars(step), as.numeric)

    surv_abs = survival_plot(survival, scalefactor = 1e6, ylabel = "library size (M reads)") +
        ggtitle("read processing summary",
                subtitle = "absolute library size")

    # counts as a fraction of each sample's maximum, shown as a percentage
    surv_rel = survival_plot(survival %>% group_by(sample) %>% mutate(count=count/max(count)),
                             scalefactor = .01, ylabel = "% of raw reads") +
        ggtitle("read processing summary", subtitle = "relative to library size")

    ggsave(surv_abs_out, plot=surv_abs, width=20, height=2+2.5*nsamples, units="cm")
    ggsave(surv_rel_out, plot=surv_rel, width=20, height=2+2.5*nsamples, units="cm")

    loss_plot = ggplot(data = loss, aes(x=step, y=0, fill=loss)) +
        geom_raster() +
        geom_text(aes(label=round(loss, 2)), size=4) +
        scale_fill_viridis(name="% loss", guide=guide_colorbar(barheight = 10, barwidth=1)) +
        scale_x_discrete(labels = c("reads cleaned", "aligned",
                                    "uniquely mapping", "no duplicates"),
                         expand=c(0,0), name=NULL) +
        scale_y_continuous(breaks=0, expand=c(0,0), name=NULL) +
        facet_grid(sample~., switch="y") +
        ggtitle("read processing percent loss") +
        theme_light() +
        theme(strip.placement="outside",
              strip.background = element_blank(),
              text = element_text(size=12, color="black", face="bold"),
              strip.text.y = element_text(size=12, angle=-180, color="black", hjust=1),
              axis.text.x = element_text(size=12, color="black", face="bold", angle=30, hjust=0.95),
              axis.text.y = element_blank(),
              axis.title.y.right = element_blank(),
              plot.subtitle = element_text(size=12, face="plain"),
              panel.border = element_blank())

    ggsave(loss_out, plot=loss_plot, width=20, height=2+1.5*nsamples, units="cm")
}

# Script entry point: three plot paths from the rule's named outputs, and the
# processing-summary table from its first input.
main(loss_out = snakemake@output[["loss_out"]],
     surv_rel_out = snakemake@output[["surv_rel_out"]],
     surv_abs_out = snakemake@output[["surv_abs_out"]],
     in_table = snakemake@input[[1]])
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Smooth a BigWig track with a Gaussian kernel, chromosome by chromosome."""
import argparse

import numpy as np
from scipy.ndimage import gaussian_filter1d


def smooth_track(values, bandwidth):
    """Gaussian-smooth one chromosome's per-base values.

    Args:
        values: 1-D sequence of per-base values. NaN (what pyBigWig returns for
            bases with no entry) is treated as 0 so that gaps do not spread NaN
            through the kernel into the output.
        bandwidth: kernel standard deviation, in bases.

    Returns:
        float64 numpy array of the same length as ``values``.
    """
    arr = np.asarray(values, dtype=np.float64)
    arr = np.where(np.isnan(arr), 0.0, arr)
    return gaussian_filter1d(arr, sigma=bandwidth, order=0, mode='mirror')


def parse_args(argv=None):
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description='Smooth bigwig file with Gaussian kernel of given bandwidth.')
    parser.add_argument('-b', dest='bandwidth', type=int, default=20, help='Gaussian kernel bandwidth (standard deviation)')
    parser.add_argument('-i', dest='infile', type=str, required=True, help='path to input BigWig')
    parser.add_argument('-o', dest='outfile', type=str, required=True, help='path to smoothed output BigWig')
    return parser.parse_args(argv)


def main(argv=None):
    """Read the input BigWig, smooth every chromosome and write the output."""
    # third-party; imported here so smooth_track() can be used without it
    import pyBigWig as pybw

    args = parse_args(argv)
    inbw = pybw.open(args.infile)
    try:
        outbw = pybw.open(args.outfile, "w")
        try:
            chroms = inbw.chroms()
            # entries below are added in the same chromosome order as the header
            outbw.addHeader(list(chroms.items()))
            for chrom, length in chroms.items():
                if length == 0:
                    continue
                raw = inbw.values(chrom, 0, length, numpy=True)
                outbw.addEntries(chrom, 0, values=smooth_track(raw, args.bandwidth),
                                 span=1, step=1)
        finally:
            outbw.close()
    finally:
        inbw.close()


if __name__ == "__main__":
    main()
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
library(tidyverse)
library(gridExtra)
library(ggthemes)

# Flag Tukey outliers: values more than `coef` interquartile ranges below the
# first quartile or above the third quartile.
#
# x:    numeric vector
# coef: fence multiplier; 1.5 matches geom_boxplot's default whisker length
# Returns a logical vector the same length as x (NA where x is NA).
is_tukey_outlier = function(x, coef=1.5){
    q = quantile(x, probs=c(.25, .75), na.rm=TRUE, names=FALSE)
    iqr = q[2] - q[1]
    x < q[1] - coef*iqr | x > q[2] + coef*iqr
}

# Spike-in normalized ChIP abundance per sample and per group.
#
# in_path:     TSV with columns sample, group, experimental_counts_IP,
#              spikein_counts_IP, experimental_counts_input,
#              spikein_counts_input
# sample_list: samples to keep
# controls, conditions: parallel vectors of group names; when both are given,
#              a table of condition/control mean abundance ratios is added
# plot_out:    output image path
# stats_out:   output TSV of per-group summary statistics
main = function(in_path, sample_list, controls, conditions, plot_out, stats_out){
    # abundance = (IP experimental/spike-in ratio) / (input experimental/spike-in ratio)
    df = read_tsv(in_path) %>%
        filter(sample %in% sample_list) %>%
        mutate(abundance = (experimental_counts_IP / spikein_counts_IP ) *
                                    (spikein_counts_input / experimental_counts_input)) %>%
        mutate_at(vars(sample, group), ~(fct_inorder(., ordered=TRUE))) %>%
        group_by(group) %>%
        # outliers are judged within each group
        mutate(outlier = is_tukey_outlier(abundance)) %>%
        ungroup()

    n_samples = nrow(df)
    n_groups = df %>% pull(group) %>% n_distinct()

    barplot = ggplot(data=df, aes(x=sample, fill=group, y=abundance)) +
        geom_col() +
        geom_text(aes(label=round(abundance, 2)), size=12/75*25.4,
                  position=position_stack(vjust=0.9)) +
        scale_fill_ptol(guide="none") +
        ylab("spike-in normalized\nabundance vs. input") +
        theme_light() +
        theme(axis.text = element_text(size=10, color="black"),
              axis.text.x = element_text(angle=30, hjust=0.9),
              axis.title.x = element_blank(),
              axis.title.y = element_text(size=10, color="black",
                                          angle=0, vjust=0.5, hjust=1))

    boxplot = ggplot(data = df, aes(x=group, y=abundance, fill=group)) +
        geom_boxplot(outlier.shape=16, outlier.size=1.5, outlier.color="red", outlier.stroke=0) +
        geom_point(shape=16, size=1, stroke=0) +
        scale_fill_ptol(guide="none") +
        scale_y_continuous(name = "spike-in normalized\nabundance vs. input",
                           limits = c(0, NA)) +
        theme_light() +
        theme(axis.text = element_text(size=10, color="black"),
              axis.text.x = element_text(angle=30, hjust=0.9),
              axis.title.x = element_blank(),
              axis.title.y = element_text(size=10, color="black"))

    # per group: sample count and median over all samples; count, mean and sd
    # after removing outliers. write_tsv returns its input, so stats_table is
    # the summary table. The path is passed positionally so this works with
    # both the old (`path`) and new (`file`) readr argument names.
    stats_table = df %>%
        add_count(group, name="n") %>%
        group_by(group) %>%
        mutate(median = median(abundance)) %>%
        ungroup() %>%
        filter(!outlier) %>%
        add_count(group, name="nn") %>%
        group_by(group) %>%
        summarise(n = first(n),
                  median = first(median),
                  n_no_outlier = first(nn),
                  mean_no_outlier = mean(abundance),
                  sd_no_outlier = sd(abundance)) %>%
        write_tsv(stats_out, col_names=TRUE)

    #set width (cm) of the bar and box panels; th = height of the levels table
    wl = 1+1.6*n_samples
    wr = 1+1.8*n_groups
    th = 0
    if (!(is.null(conditions) || is.null(controls))){
        levels_df = tibble(condition=conditions, control=controls) %>%
            left_join(stats_table %>% select(group, mean_no_outlier),
                      by=c("condition"="group")) %>%
            rename(condition_abundance=mean_no_outlier) %>%
            left_join(stats_table %>% select(group, mean_no_outlier),
                      by=c("control"="group")) %>%
            rename(control_abundance=mean_no_outlier) %>%
            mutate(levels = condition_abundance/control_abundance)

        levels_table = levels_df %>%
            select(condition, control, levels) %>%
            mutate_at("levels", ~(round(., digits=3)))
        levels_draw = tableGrob(levels_table,
                                rows=NULL,
                                cols=c("condition","control","relative levels"),
                                ttheme_minimal(base_size=10))

        th = 1+length(conditions)/2
        page = arrangeGrob(barplot, boxplot, levels_draw,
                           layout_matrix=rbind(c(1,2),c(3,3)),
                           widths=unit(c(wl, wr), "cm"),
                           heights=unit(c(9,th),"cm"))
    } else {
        page = arrangeGrob(barplot, boxplot,
                           widths=unit(c(wl, wr), "cm"),
                           heights=unit(c(9,th),"cm"))

    }
    ggsave(plot_out, page, width = wl+wr, height=9+th+.5, units = "cm")
}

# Script entry point: output paths, input table, sample selection, then the
# optional condition/control group pairs (NULL when not configured).
main(plot_out = snakemake@output[["plot"]],
     stats_out = snakemake@output[["stats"]],
     in_path = snakemake@input[[1]],
     sample_list = snakemake@params[["samplelist"]],
     conditions = snakemake@params[["conditions"]],
     controls = snakemake@params[["controls"]])
ShowHide 39 more snippets with no or duplicated tags.

Login to post a comment if you would like to share your experience with this workflow.

Do you know this workflow well? If so, you can request seller status , and start supporting this workflow.

Free

Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/khalillab/coop-TF-chipseq
Name: coop-tf-chipseq
Version: v1
Badge:
workflow icon

Insert copied code into your website to add a link to this workflow.

Downloaded: 0
Copyright: Public Domain
License: None
  • Future updates

Related Workflows

cellranger-snakemake-gke
snakemake workflow to run cellranger on a given bucket using gke.
A Snakemake workflow for running cellranger on a given bucket using Google Kubernetes Engine. The usage of this workflow ...