Analyze substitution rates and mutation patterns within SARS-CoV-2 variants.

Version: VE_resubmission

Richard A. Neher

Continued evolution and adaptation of SARS-CoV-2 has led to more transmissible and immune-evasive variants with profound impact on the course of the pandemic. Here I analyze the evolution of the virus over the 2.5 years since its emergence and estimate rates of evolution for synonymous and non-synonymous changes separately for evolution within clades -- well-defined monophyletic groups with gradual evolution -- and for the pandemic overall. The rate of synonymous mutations is found to be around 6 changes per year. Synonymous rates within variants vary little from variant to variant and are compatible with the overall rate. In contrast, the rate at which variants accumulate amino acid changes (non-synonymous mutations) was initially around 12-16 changes per year, but in 2021 and 2022 dropped to 6-9 changes per year. The overall rate of non-synonymous evolution, that is across variants, is estimated to be about 25 amino acid changes per year. This 2-fold higher rate indicates that the evolutionary process that gave rise to the different variants is qualitatively different from that in typical transmission chains and likely dominated by adaptive evolution. I further quantify the spectrum of mutations and purifying selection in different SARS-CoV-2 proteins. Many accessory proteins evolve under limited evolutionary constraint with little short-term purifying selection. About half of the mutations in other proteins are strongly deleterious and rarely observed, even at low frequency.
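
For scale, these genome-wide rates can be converted to per-site rates in the same way the plotting code below does it (a minimal sketch using the constants from that script):

# convert genome-wide rates (changes per year) into per-site rates,
# using the same constants as the rate-plotting script below
L = 29903    # genome length in nucleotides
Laa = 9716   # codons in all CDS other than ORF9b
syn_rate = 6           # synonymous changes per year (within variants)
aa_rate_overall = 25   # amino acid changes per year (across variants)
print(f"synonymous: {syn_rate/Laa:.1e} per codon per year")                   # ~6.2e-04
print(f"amino acid, across variants: {aa_rate_overall/Laa:.1e} per codon per year")  # ~2.6e-03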

Repository structure

This repository contains the scripts and source files associated with a manuscript on the evolution of SARS-CoV-2.

The analysis can be run using Snakemake and requires standard Python libraries as well as TreeTime (available as phylo-treetime on PyPI).

The analysis can be run using the open data files provisioned by Nextstrain (downloading them is the default in the workflow).
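
In practice, something like "pip install snakemake phylo-treetime" followed by "snakemake --cores 4" from the repository root should run the workflow; the exact targets depend on the Snakefile.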

The directory manuscript contains the TeX files associated with the manuscript, the bibliography, and the figures.

The data directory contains some derived files, including the rate estimates, the mutation distribution, and fitness cost landscapes.
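
For example, the per-clade rate table can be inspected with pandas, mirroring how the plotting script below reads it (the file name data/rates.tsv is a placeholder; the column names follow that script):

import pandas as pd

# hypothetical path into the data directory; the actual file name may differ
rates = pd.read_csv("data/rates.tsv", sep="\t", index_col="clade")
print(rates[["nuc_rate", "aa_rate", "syn_rate"]].head())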

Code Snippets

import pandas as pd
import argparse,json
from collections import defaultdict
from datetime import datetime
import numpy as np
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False

import matplotlib.pyplot as plt
import seaborn as sns
from root_to_tip import filter_and_transform, get_clade_gts

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--metadata', type=str, nargs='+', required=True, help="input data")
    parser.add_argument('--clade-gts', type=str, required=True, help="input data")
    parser.add_argument('--clade', type=str, required=True, help="input data")
    parser.add_argument('--sub-clades', type=str, required=True, help="input data")
    parser.add_argument('--min-date', type=float, help="input data")
    parser.add_argument('--output-plot', type=str, help="plot file")
    #parser.add_argument('--output-json', type=str, help="rate file")
    args = parser.parse_args()

    clade_gt = get_clade_gts(args.clade_gts, args.sub_clades)

    d = pd.concat([pd.read_csv(x, sep='\t').fillna('') for x in args.metadata])
    filtered_data, _ = filter_and_transform(d, clade_gt, min_date=args.min_date, max_date=args.min_date + 0.3, completeness=0, swap_root=args.clade.startswith('19B+'))
    #filtered_data=filtered_data.loc[filtered_data.country!='China']

    intra_subs_dis = filtered_data["divergence"].value_counts().sort_index()
    intra_aaSubs_dis = filtered_data["aaDivergence"].value_counts().sort_index()
    intra_geno = filtered_data["intra_substitutions_str"].value_counts()


    # plot cumulative sequence counts over time, overall and for the most common genotypes
    ls = ['-', '--', '-.', ':']
    plt.figure()
    dates = sorted(filtered_data.loc[:,"numdate"])
    plt.plot(dates, (np.arange(len(dates))+1), c='k', lw=3, alpha=0.3, label="all")
    ls_counter = defaultdict(int)
    for x,i in intra_geno.items():
        nmuts = len(x.split(',')) if x else 0
        ind = filtered_data["intra_substitutions_str"]==x
        dates = sorted(filtered_data.loc[ind,"numdate"])
        factor = 3 if 'C8782T' in x else 1
        if dates[0]<args.min_date+0.2:
        #if (nmuts<2 or i>len(filtered_data)/100) and ls_counter[nmuts]<10:
            if nmuts>5:
                continue
            plt.plot(dates, factor*(np.arange(i)+1), c=f'C{nmuts}', lw=2 if nmuts else 3,
                            ls=ls[ls_counter[nmuts]%len(ls)], marker='o' if len(dates)<3 else '',
                            label=f"gt={x if x else 'founder'}"[:30])
            ls_counter[nmuts]+=1
    plt.ylim(0.5,1000)
    plt.xlim(args.min_date,args.min_date+0.3)
    plt.yscale('log')
    plt.legend()

    if args.output_plot:
        plt.savefig(args.output_plot)
    else:
        plt.show()
import argparse
import numpy as np
import pandas as pd
from scipy.stats import linregress
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False
from root_to_tip import add_panel_label

import matplotlib.pyplot as plt
import seaborn as sns


if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    fs=14
    parser.add_argument('--rate-table', type=str, required=True, help="input data")
    parser.add_argument('--output-plot', type=str, help="plot file")
    parser.add_argument('--output-plot-rates', type=str, help="plot file")
    parser.add_argument('--output-plot-rates-genes', type=str, help="plot file")
    args = parser.parse_args()

    L = 29903   # genome length in nucleotides
    Laa = 9716  # codons in all CDS other than ORF9b

    rates = pd.read_csv(args.rate_table, sep='\t', index_col='clade')
    inter_clade_rates = {}
    fig, axs = plt.subplots(2,2, figsize=(12,12))
    xticks = [2020,2020.5,2021,2021.5,2022,2022.5]
    for ax, mut_type, ax_label, panel in zip(axs.flatten(),['nuc', 'aa', 'syn'],
                        ['total divergence', 'amino acid divergence', 'synonymous divergence'], ['A', 'B', 'C']):
        ax.set_ylabel(ax_label, fontsize=fs)
        add_panel_label(ax, panel, fs=fs*1.8)
        inter_clade = []
        ci=0
        for clade, row in rates.iterrows():
            dt = np.linspace(0, 0.75,2)
            t = dt + row[f'{mut_type}_origin']
            slope = row[f'{mut_type}_rate']
            ls = '-' if ci<10 else '--'
            m = 'o' if ci<10 else 's'
            ax.plot(t, dt*slope + row[f'{mut_type}_div'], label=clade, c=f"C{ci%10}", ls=ls)
            ax.scatter([[row[f'{mut_type}_origin']]], [[row[f'{mut_type}_div']]], c=f"C{ci%10}", marker=m)
            if row[f'{mut_type}_origin']>2019.7 and row[f'{mut_type}_origin']<2022.7:
                inter_clade.append([row[f'{mut_type}_origin'], row[f'{mut_type}_div']])
            ci += 1
        inter_clade = np.array(inter_clade)
        reg = linregress(inter_clade[:,0], inter_clade[:,1])
        inter_clade_rates[mut_type] = reg.slope

        x = np.linspace(2019.8, 2022.5, 21)
        y = reg.slope*x + reg.intercept
        # Poisson standard deviation of the expected number of substitutions
        std_dev = np.sqrt(np.maximum(0,reg.slope*x + reg.intercept))
        ax.plot(x, y, c='k', lw=3, alpha=0.5)
        ax.fill_between(x, y+std_dev, np.maximum(0, y-std_dev), fc='k', lw=3, alpha=0.1, ec=None)
        ax.set_xlim(x.min(), x.max())
        ax.set_ylim(0)
        if mut_type=='nuc':
            ax.text( 0.5, 0.05,f"overall rate:\n{reg.slope:1.1f}/year\n{reg.slope/L:1.1e}/year/site",
                    fontsize=fs, transform=ax.transAxes)
        else:
            ax.text( 0.5, 0.05,f"overall rate:\n{reg.slope:1.1f}/year\n{reg.slope/Laa:1.1e}/year/codon",
                   fontsize=fs, transform=ax.transAxes)
        if mut_type == 'aa': ax.legend(ncol=2)
        ax.set_xticks(xticks, [str(x) for x in xticks], fontsize=fs)

    for i, (mut_type, rate) in enumerate(inter_clade_rates.items()):
        axs[-1,-1].fill_between([i-0.5, i+0.5], [35,35], facecolor='k', alpha=0.075*(2+i%2))
        axs[-1,-1].plot([i-0.4, i+0.4], [rate, rate], lw=3, c='k', alpha=0.5)
        clade_rates = rates[f"{mut_type}_rate"]
        x_offset = i + np.linspace(-.35, 0.35,len(clade_rates))
        axs[-1,-1].scatter(x_offset[:10],
                           clade_rates[:10], c=[f"C{ci%10}" for ci in range(10)],
                           marker='o')
        axs[-1,-1].scatter(x_offset[10:],
                           clade_rates[10:], c=[f"C{ci%10}" for ci in range(len(clade_rates) - 10)],
                           marker='s')
    axs[-1,-1].set_ylabel("substitutions per year", fontsize=fs)
    axs[-1,-1].set_xticks([0,1,2], ['total', 'amino acid', 'synonymous'], rotation=20,
                          ha='center', fontsize=fs)
    axs[-1,-1].set_ylim(0)
    add_panel_label(axs[-1,-1], 'D', fs=fs*1.8)

    if args.output_plot:
        plt.savefig(args.output_plot)
    else:
        plt.show()

    plt.figure()
    plt.plot(rates["nuc_origin"], rates["aa_rate"], 'o', label='amino acid rate')
    plt.plot(rates["nuc_origin"], rates["syn_rate"], 'o', label='synonymous rate')
    plt.plot(rates["nuc_origin"], np.ones_like(rates['nuc_origin'])*inter_clade_rates["aa"], label='inter-clade amino acid rate', c=f"C{0}", lw=3)
    plt.plot(rates["nuc_origin"], np.ones_like(rates['nuc_origin'])*inter_clade_rates["syn"], label='inter-clade synonymous rate', c=f"C{1}", lw=3)
    plt.ylabel('rate estimate [1/y]')
    plt.ylim(0)
    plt.legend()

    if args.output_plot_rates:
        plt.savefig(args.output_plot_rates)
    else:
        plt.show()

    plt.figure()
    plt.plot(rates["aa_rate"], 'o-', label='Overall amino acid rate')
    plt.plot(rates["spike_rate"], 's-', label='spike protein')
    plt.plot(rates["orf1a_rate"], 'd-', label='ORF1a')
    plt.plot(rates["orf1b_rate"], 'd-', label='ORF1b')
    plt.plot(rates["enm_rate"], 'd-', label='E,M,N')
    plt.plot(rates["aa_rate"] - rates["spike_rate"] - rates["orf1a_rate"] - rates["orf1b_rate"]- rates["enm_rate"],
             'v-', label='other ORFs')
    plt.ylabel('rate estimate [subs/y]')
    plt.legend()
    plt.xticks(range(len(rates)), rates.index, rotation=60, ha='right')
    plt.tight_layout()

    if args.output_plot_rates_genes:
        plt.savefig(args.output_plot_rates_genes)
    else:
        plt.show()
import argparse,json
import pandas as pd
import numpy as np
import matplotlib as mpl
from Bio import SeqIO,Seq
mpl.rcParams['axes.formatter.useoffset'] = False
from collections import defaultdict
from scipy.stats import scoreatpercentile

import matplotlib.pyplot as plt
import seaborn as sns

columns = ['seqName', 'Nextclade_pango', "privateNucMutations.unlabeledSubstitutions", 'qc.overallStatus', 'privateNucMutations.reversionSubstitutions']

def get_sequence(mods, root_seq):
    """Apply the nucleotide mutations in `mods` to a copy of the root sequence."""
    seq = root_seq.copy()
    for mut in mods['nuc']:
        a,pos,d = mut[0], int(mut[1:-1])-1, mut[-1]
        if a!=seq[pos]:  # warn if the expected ancestral state does not match
            print(seq[pos], mut)
        seq[pos]=d
    return seq

def translate(codon):
    """Translate a codon; return 'X' if it cannot be translated (gaps, ambiguous bases)."""
    try:
        return Seq.translate("".join(codon))
    except Exception:
        return 'X'

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--metadata', type=str, required=True, help="input data")
    parser.add_argument('--reference', type=str, required=True, help="input data")
    parser.add_argument('--pango-gts', type=str, required=True, help="input data")
    parser.add_argument('--rare-cutoff', type=int, default=2, help="minimal number of occurrences to count mutations")
    parser.add_argument('--output-fitness', type=str, required=True, help="fitness table")
    parser.add_argument('--output-events', type=str, required=True, help="events table")
    parser.add_argument('--output-mutations', type=str, required=True, help="fitness table")
    args = parser.parse_args()

    ## load reference
    ref = SeqIO.read(args.reference, 'genbank')
    ref_array = np.array(ref.seq)
    base_content = {x:ref.seq.count(x) for x in 'ACGT'}

    ## make a map of codon positions, ignore orf9b (overlaps N)
    codon_pos = np.zeros(len(ref))
    map_to_gene = {}
    gene_position = {}
    gene_length = {}
    for feat in ref.features:
        if feat.type=='CDS':
            gene_name = feat.qualifiers['gene'][0]
            gene_position[gene_name] = feat.location
            gene_length[gene_name] = (feat.location.end - feat.location.start)//3
            if gene_name!='ORF9b':
                for gpos, pos in enumerate(feat.location):
                    codon_pos[pos] = (gpos%3)+1
                    map_to_gene[pos]= gene_name


    # load pango genotypes
    with open(args.pango_gts) as fh:
        pango_gts = json.load(fh)

    # load and filter metadata
    d = pd.read_csv(args.metadata, sep='\t', usecols=columns).fillna('')
    d = d.loc[d["qc.overallStatus"]=='good',:]

    # glob all rare mutations by pango lineage
    mutation_counter = defaultdict(lambda: defaultdict(int))
    pango_counter = defaultdict(int)
    for r, row in d.iterrows():
        pango = row.Nextclade_pango
        if pango[0]=='X':
            continue

        pango_counter[pango] += 1
        muts = row["privateNucMutations.unlabeledSubstitutions"].split(',')
        rev_muts = row["privateNucMutations.reversionSubstitutions"].split(',')
        if len(muts) + len(rev_muts):
            for m in muts + rev_muts:
                if m:
                    mutation_counter[pango][m] += 1

    lineage_size_cutoff = 100
    big_lineages = [pango for pango in pango_counter if pango_counter[pango]>lineage_size_cutoff]
    nlin = len(big_lineages)
    tmp_number_of_pango_muts = defaultdict(int)
    for pango, muts in pango_gts.items():
        if pango_counter[pango]>lineage_size_cutoff:
            for m in muts['nuc']:
                tmp_number_of_pango_muts[int(m[1:-1])-1] += 1
    number_of_pango_muts = np.array([tmp_number_of_pango_muts.get(pos,0)/nlin for pos in range(len(ref))])


    # for each pango lineage, count mutations by position, as well as synonymous and non-synonymous changes by codon
    position_counter = defaultdict(lambda: defaultdict(list))
    position_freq = defaultdict(lambda: defaultdict(list))
    syn =    defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    nonsyn = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for pango, pmuts in mutation_counter.items():
        if pango not in pango_gts:
            print("missing pango", pango)
            continue
        seq = get_sequence(pango_gts[pango], ref_array)
        # translate each codon in pango reference sequence
        for mut,c in pmuts.items():
            a, pos, d = mut[0], int(mut[1:-1])-1, mut[-1]
            if a=='-':
                continue
            position_counter[pos][(a,d)].append(c)
            if pango_counter[pango]>lineage_size_cutoff and c>1:
                position_freq[pos][(a,d)].append(c/pango_counter[pango])

            if pos in map_to_gene:
                gene = map_to_gene[pos]
                pos_in_gene = pos - gene_position[gene].start
                frame = pos_in_gene%3
                codon = seq[pos-frame:pos-frame+3]
                codon_number = pos_in_gene//3
                alt = codon.copy()
                alt[frame] = d
                aa = translate(codon)
                alt_aa = translate(alt)
                if aa==alt_aa and aa!='X':
                    syn[gene][codon_number][a].append(c)
                elif 'X' not in [aa,alt_aa]:
                    nonsyn[gene][codon_number][a].append(c)

    sorted_transitions = []
    for a in 'ACGT':
        for d in 'ACGT':
            if a!=d:
                sorted_transitions.append((a,d))

    # determine the total number of events and counts at each position
    def total_len_filtered_lists(l, cutoff, rates=None):
        if rates is None:
            rates = {}
        return sum([len([y for y in x if y>=cutoff])/rates.get(a,1.0) for (a,d),x in l.items()])

    all_events = []
    for pos in range(len(ref)):
        tmp = []
        for a, d in sorted_transitions:
            tmp.append(total_len_filtered_lists({(a,d): position_counter[pos][(a,d)]}, args.rare_cutoff))
        all_events.append(tmp)

    # sort all mutations by type of mutation
    events_by_transition = defaultdict(int)
    for m in position_counter.values():
        for t in m:
            events_by_transition[t]+=len(m[t])
    # determine the mutation rate for each transition type, as well as the total rate out of each nucleotide
    transition_rate = dict()
    away_rate = defaultdict(float)
    total_rate = np.sum(list(events_by_transition.values()))
    for a,d in events_by_transition:
        if a in base_content:
            transition_rate[(a,d)] = events_by_transition[(a,d)]/base_content[a]/total_rate*len(ref)
            away_rate[a] += transition_rate[(a,d)]


    total_events = np.array([total_len_filtered_lists(position_counter[pos], args.rare_cutoff)
                            if pos in position_counter else 0
                            for pos in range(len(ref))])

    total_events_weighted = np.array([total_len_filtered_lists(position_counter[pos], args.rare_cutoff, away_rate)
                            if pos in position_counter else 0
                            for pos in range(len(ref))])

    total_counts = np.array([sum([sum(x) for x in position_counter[pos].values()])
                             if pos in position_counter else 0
                             for pos in range(len(ref))])

    # average mutation rate used to scale rates
    total_events_rescaled = np.array([c/away_rate[nuc] for c,nuc in zip(total_events, ref.seq)])
    total_events_rescaled /= np.median(total_events_rescaled)
    total_events_weighted /= np.median(total_events_weighted)

    with open(args.output_events, 'w') as fh:
        fh.write('\t'.join(['position', 'ref_state', 'lineage_fraction_with_changes', 'gene', 'codon', 'pos_in_codon']
                            + [f"{a}->{d}" for a,d in sorted_transitions]) + '\n')
        for pos, nuc in enumerate(ref_array):
            gene = map_to_gene.get(pos,"")
            data = "\t".join([f"{all_events[pos][i]}" for i in range(len(sorted_transitions))])
            if gene:
                gene_pos = pos - gene_position[gene].start
                codon = gene_pos//3 + 1
                cp = gene_pos%3 + 1
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t{gene}\t{codon}\t{cp}\t' + data + "\n")
            else:
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t\t\t\t' + data + "\n")

    with open(args.output_fitness, 'w') as fh:
        fh.write('\t'.join(['position', 'ref_state', 'lineage_fraction_with_changes', 'gene', 'codon', 'pos_in_codon', 'total_count', 'total_events', 'tolerance']) + '\n')
        for pos, nuc in enumerate(ref_array):
            gene = map_to_gene.get(pos,"")
            if gene:
                gene_pos = pos - gene_position[gene].start
                codon = gene_pos//3 + 1
                cp = gene_pos%3 + 1
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t{gene}\t{codon}\t{cp}\t{total_counts[pos]}\t{total_events[pos]}\t{total_events_weighted[pos]:1.3f}\n')
            else:
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t\t\t\t{total_counts[pos]}\t{total_events[pos]}\t{total_events_weighted[pos]:1.3f}\n')

    with open(args.output_mutations, 'w') as fh:
        fh.write('\t'.join(['mutation', 'raw_counts', 'origin_sites', 'scaled_rate']) + '\n')
        for a in 'ACTG':
            for d in 'ACGT':
                if a==d: continue
                key = (a,d)
                fh.write(f"{a}->{d}\t{events_by_transition[key]}\t{base_content[a]}\t{transition_rate[key]:1.3f}\n")


    # # calculate synonymous and non-synonymous distributions
    # total_events_by_type = {}
    # for mut_counts, label in [(syn, 'syn'), (nonsyn, 'nonsyn')]:
    #     total_events_by_type[label] = {}
    #     for gene in mut_counts:
    #         total_events_by_type[label][gene] = [np.sum([len([x for x in mut_counts[gene][pos][a] if x>cutoff])/away_rate[a]
    #                                                  for a in mut_counts[gene][pos]])
    #                                                for pos in range(gene_length[gene])]


    # def calc_fitness_cost(freqs, transition):
    #     mu = 0.0004/50
    #     mut_rate = transition_rate[transition]*mu
    #     avg_freq = np.sum(freqs)/nlin
    #     return mut_rate/(avg_freq + 1e-3/nlin)

    # fitness_cost = []
    # for pos,nuc in enumerate(ref_array):
    #     total_muts = np.sum([len(v) for v in position_freq[pos].values()])
    #     reference_muts = np.sum([len(v) for (a,d), v in position_freq[pos].items() if a==nuc])
    #     if reference_muts>0.9*total_muts:
    #         fitness_cost.append([np.inf if d==nuc
    #                              else calc_fitness_cost(position_freq[pos][(nuc,d)],(nuc,d))
    #                             for d in 'ACGT'])
    #     else:
    #         fitness_cost.append([np.nan, np.nan, np.nan, np.nan])
import pandas as pd
import argparse,json
from datetime import datetime
import numpy as np
from root_to_tip import filter_and_transform, get_clade_gts
from collections import defaultdict

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--metadata', nargs='+', type=str, required=True, help="input data")
    parser.add_argument('--clade-gts', type=str, required=True, help="input data")
    parser.add_argument('--clade', type=str, required=True, help="input data")
    parser.add_argument('--sub-clades', type=str, required=True, help="input data")
    parser.add_argument('--min-date', type=float, help="input data")
    parser.add_argument('--max-date', type=float, help="input data")
    parser.add_argument('--max-group', type=int, help="input data")
    parser.add_argument('--query', type=str, help="filters")
    parser.add_argument('--bin-size', type=float, help="input data")
    parser.add_argument('--output-json', type=str, help="rate file")
    args = parser.parse_args()

    clade_gt = get_clade_gts(args.clade_gts, args.sub_clades)

    d = pd.concat([pd.read_csv(x, sep='\t').fillna('') for x in args.metadata])
    filtered_data, _ = filter_and_transform(d, clade_gt, min_date=args.min_date, max_date=args.max_date,
                                         query = args.query, max_group=args.max_group, QC_threshold=80 if args.clade=='21H' else 30,
                                         completeness=0, swap_root=args.clade.startswith('19B+'))
    filtered_data["day"] = filtered_data.datetime.apply(lambda x:x.toordinal())

    print("clade", args.clade, "done filtering")
    bins = np.arange(np.min(filtered_data.day),np.max(filtered_data.day), args.bin_size)
    all_sequences = np.histogram(filtered_data.day, bins=bins)[0]
    cumulative_sum = np.zeros_like(bins[:-1])

    # count sequences with exactly n within-clade mutations per time bin; '5+' collects the remainder
    mutation_number = {}
    nmax = 5
    nmax_extra = 15
    for n in range(nmax):
        ind = filtered_data["divergence"]==n
        mutation_number[n] = np.histogram(filtered_data.loc[ind,"day"], bins=bins)[0]
        cumulative_sum += mutation_number[n]
    mutation_number[f'{nmax}+'] = all_sequences - cumulative_sum
    for n in range(nmax,nmax_extra):
        ind = filtered_data["divergence"]==n
        mutation_number[n] = np.histogram(filtered_data.loc[ind,"day"], bins=bins)[0]
        cumulative_sum += mutation_number[n]

    print("clade", args.clade, "done mutation_number")

    mutation_counts = defaultdict(int)
    cutoff = args.min_date + (args.max_date - args.min_date)/2
    window = cutoff - 14/365, cutoff + 14/365

    ind_early = filtered_data.numdate<cutoff
    ind_window = (filtered_data.numdate>=window[0]) & (filtered_data.numdate<window[1])
    n_early = ind_early.sum()
    n_window = ind_window.sum()

    # Find common mutations
    for muts in filtered_data.loc[ind_early, "intra_substitutions"]:
        for m in muts:
            mutation_counts[m] +=1

    relevant_muts = [x[0] for x in sorted(list(mutation_counts.items()), key=lambda k:k[1])[-10:]]
    mutations = {}
    nmax = 5
    for mut in relevant_muts:
        ind = filtered_data["intra_substitutions"].apply(lambda x: mut in x)
        mutations[mut] = np.histogram(filtered_data.loc[ind,"day"], bins=bins)[0]

    # Get mutation _spectrum
    mutation_spectrum = defaultdict(float)
    for muts in filtered_data.loc[ind_window, "intra_substitutions"]:
        for m in muts:
            mutation_spectrum[m] += 1.0/n_window


    print("clade", args.clade, "done mutations")

    intra_geno = filtered_data.loc[ind_early,"intra_substitutions_str"].value_counts()
    genotypes = {}
    gt_count = 0
    for x,i in intra_geno.items():
        nmuts = len(x.split(',')) if x else 0
        ind = filtered_data["intra_substitutions_str"]==x
        if (i<10 and nmuts) or nmuts>4:  # skip rare non-founder genotypes and those with >4 mutations
            continue
        if nmuts==0 or (nmuts==1 and ind.sum()>0.005*n_early) or (ind.sum()>0.01*n_early):
            genotypes[x] = np.histogram(filtered_data.loc[ind,"day"], bins=bins)[0]
            gt_count +=1
            if gt_count>8:
                break

    print("clade", args.clade, "done genotypes")
    with open(args.output_json, 'w') as fh:
        json.dump({'bins':[int(x) for x in bins],
                   'all_samples':[int(x) for x in all_sequences],
                   'mutation_number': {k:[int(x) for x in v] for k,v in mutation_number.items()},
                   'mutation_spectrum': {k: float(v) for k,v in mutation_spectrum.items()},
                   'mutations': {k:[int(x) for x in v] for k,v in mutations.items()},
                   'genotypes': {k:[int(x) for x in v] for k,v in genotypes.items()}}, fh)
import pandas as pd
from collections import defaultdict
import argparse,json

def genotype_struct():
    return {'nuc':{}, 'aa':defaultdict(dict)}

def assign_genotypes(n):
    """Recursively propagate genotypes from the root to all children, applying branch mutations."""
    gt = n["genotype"]
    if "children" in n:
        for c in n["children"]:
            cgt = genotype_struct()

            cgt["nuc"].update({k:v for k,v in gt["nuc"].items()})
            if "nuc" in c["branch_attrs"]["mutations"]:
                for mut in c["branch_attrs"]["mutations"]["nuc"]:
                    a,pos,d = mut[0], int(mut[1:-1]), mut[-1]
                    cgt["nuc"][pos] = d


            for gene in set(c["branch_attrs"]["mutations"].keys()).union(set(gt["aa"].keys())):
                if gene=='nuc':
                    continue
                cgt["aa"][gene].update({k:v for k,v in gt["aa"][gene].items()})
                if gene in c["branch_attrs"]["mutations"]:
                    for mut in c["branch_attrs"]["mutations"][gene]:
                        a,pos,d = mut[0], int(mut[1:-1]), mut[-1]
                        cgt['aa'][gene][pos] = d
            c["genotype"] = cgt
            assign_genotypes(c)

def get_pango_genotypes(n, pango_gts):
    # collect genotypes of tips whose names look like pango lineages (e.g. 'A', 'B', 'BA.1')
    name = n["name"]
    if "children" in n and len(n["children"]):
        for c in n["children"]:
            get_pango_genotypes(c, pango_gts)
    elif (name in ['A', 'B'] or '.' in name) and not ('/' in name):
        pango_gts[n["name"]] = n["genotype"]


if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="get variant genotypes",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--tree', type=str, required=True, help="input data")
    parser.add_argument('--root', type=str, required=True, help="input data")
    parser.add_argument('--output', type=str, required=True, help="output data")
    args = parser.parse_args()

    with open(args.tree) as fh:
        tree = json.load(fh)['tree']

    with open(args.root) as fh:
        root_sequence = json.load(fh)

    tree["genotype"] = genotype_struct()
    assign_genotypes(tree)
    pango_gts = {}
    get_pango_genotypes(tree, pango_gts)

    subs = {}
    for clade in pango_gts:
        tmp_nuc = []
        if 'nuc' in pango_gts[clade]:
            for pos, d in pango_gts[clade]['nuc'].items():
                a=root_sequence['nuc'][pos-1]
                if d not in ['N', '-'] and d!=a:
                    tmp_nuc.append(f"{a}{pos}{d}")

        tmp_aa = []
        for gene in pango_gts[clade]['aa']:
            for pos, d in pango_gts[clade]['aa'][gene].items():
                a=root_sequence[gene][pos-1]
                if d not in ['X', '-'] and d!=a:
                    tmp_aa.append(f"{gene}:{a}{pos}{d}")
        subs[clade] = {'nuc':tmp_nuc, 'aa':tmp_aa}

    with open(args.output, 'w') as fh:
        json.dump(subs, fh)
import argparse,json
from datetime import datetime
import numpy as np
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False

import matplotlib.pyplot as plt
import pandas as pd

from plot_genotype_counts import fit_poisson


if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--counts', nargs='+', type=str, required=True, help="input data")
    parser.add_argument('--output-plot', type=str, help="figure file")
    parser.add_argument('--output-rates', type=str, help="table")
    args = parser.parse_args()
    counts = {}
    for fi,fname in enumerate(args.counts):
        clade = fname.split('/')[-1].split('_')[0]
        with open(fname) as fh:
            counts[clade] = json.load(fh)

    fig, axs = plt.subplots(1,1, figsize = (8,6))
    ax = axs
    ls = ['-', '-.', '--']
    for fi,(clade,d) in enumerate(counts.items()):
        ax.plot(sorted(d["mutation_spectrum"].values()), np.linspace(1,0,len(d["mutation_spectrum"])+1)[:-1],
                       label=clade, ls=ls[fi//10])

    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_xlabel('mutation frequency')
    ax.set_ylabel('fraction above')
    x = np.logspace(-4,-0.3,21)
    ax.plot([1e-5,1e-1], [1e-0,1e-4], c='k', lw=3, alpha=0.5, label='1/x')
    # ax.plot(x, np.log(x)/np.log(x[0]), c='k', lw=3, alpha=0.5, label='~log(x)')
    plt.legend(ncol=3)
    plt.savefig(args.output_plot)


    ls = ['-', '-.', '--']
    data = []
    for fi,(clade,d) in enumerate(counts.items()):
        dates =   np.array([x for x in d['bins'][:-1]])
        datetimes =   np.array([datetime.fromordinal(x) for x in d['bins'][:-1]])
        t0 = dates[0]
        dates -= t0
        total =   np.array(d["all_samples"])
        k_vecs = { i: np.array(d["mutation_number"][f'{i}']) for i in ['0', '1', '2', '3', '4', '5+']}

        ind = total>0
        res = fit_poisson(total[ind], {x:k_vecs[x][ind] for x in k_vecs}, dates[ind])
        start_date = datetime.fromordinal(int(t0 + res['offset'])).strftime("%Y-%m-%d")

        # refit the origin with the rate fixed to 15 substitutions per year
        res_fixed_rate = fit_poisson(total[ind], {x:k_vecs[x][ind] for x in k_vecs}, dates[ind], rate=15/365)
        start_date_fixed = datetime.fromordinal(int(t0 + res_fixed_rate['offset'])).strftime("%Y-%m-%d")

        data.append({'clade':clade, 'rate':res['rate']*365, 'origin':start_date, 'origin_fixed_rate':start_date_fixed})

    pd.DataFrame(data).to_csv(args.output_rates, sep='\t')
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--fitness', type=str, required=True, help="input data")
    parser.add_argument('--output-fitness-landscape', type=str, help="fitness figure")
    args = parser.parse_args()

    fitness = pd.read_csv(args.fitness, sep='\t', index_col=0)
    genes = [x for x in fitness.gene.unique() if not pd.isna(x)]
    pos = np.arange(len(fitness))

    fig = plt.figure(figsize=(15,12))
    gene_groups = [ ['ORF3a', 'E', 'M', 'ORF6','ORF7a', 'ORF7b', 'ORF8'],['S', 'N'],['ORF1b'],['ORF1a'] ]
    gene_length = {k:(fitness.gene==k).sum() for k in genes}

    fs = 14
    ws = 20                      # coarse smoothing window (codons)
    w = np.ones(ws)/ws
    ws_fine = 7                  # fine smoothing window (codons)
    w_fine = np.ones(ws_fine)/ws_fine
    lower_fitness_cutoff = 0.03  # floor to keep the log scale finite
    n_rows = len(gene_groups)
    h_spread = 0.003
    h_margin = 0.05
    v_spread = 0.02
    v_margin = 0.05
    height = (1 - v_spread*n_rows - v_margin)/n_rows
    for row,group in enumerate(gene_groups):
        axis_bottom = row*(v_spread + height) + v_margin
        available_width = 1 - len(group)*h_spread - h_margin
        left = h_margin
        total_genome = np.sum([gene_length[g] for g in group])
        for gi,gene in enumerate(group):
            width = gene_length[gene]/total_genome*available_width
            ax = fig.add_axes((left,axis_bottom, width, height))
            left += width + h_spread
            for cp in range(3):
                ind = (fitness.pos_in_codon==(cp+1)) & (fitness.gene==gene)
                gene_pos = (fitness.codon[ind]*3 + cp - 1)/3
                gene_pos_smooth = np.convolve(gene_pos, w, mode='valid')
                ax.plot(gene_pos_smooth,
                        np.convolve(np.log10(fitness.tolerance[ind]+lower_fitness_cutoff), w, mode='valid'),
                                    c=f'C{cp}', label=f'codon pos {cp+1}')

                gene_pos_smooth = np.convolve(gene_pos, w_fine, mode='valid')
                ax.plot(gene_pos_smooth,
                        np.convolve(np.log10(fitness.tolerance[ind]+lower_fitness_cutoff), w_fine, mode='valid'),
                                    c=f'C{cp}', alpha = 0.5)

                ax.plot(gene_pos,
                    np.log10(fitness.lineage_fraction_with_changes[ind]+.001)/3 - 1.2, 'o', c=f'C{cp}')

            ax.plot(gene_pos, np.zeros_like(gene_pos), c='k', alpha=0.3, lw=2)
            ax.set_ylim(-2.0, 1)
            ax.set_xlim(0, len(gene_pos))

            if gi==0:
                ax.set_ylabel('log10 scaled tolerance')
                if row==n_rows-1:
                    ax.legend(ncol=3, fontsize=fs)
            else:
                ax.set_yticklabels([])

            ax.text(0.15,0.75, gene, fontsize=fs*1.1)

    plt.savefig(args.output_fitness_landscape)
import argparse,json
import pandas as pd
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from scipy.stats import scoreatpercentile

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--fitness', type=str, required=True, help="input data")
    parser.add_argument('--mutations', type=str, required=True, help="input data")
    parser.add_argument('--output-fitness', type=str, help="fitness figure")
    parser.add_argument('--output-fitness-by-gene', type=str, help="fitness figure")
    parser.add_argument('--output-mutations', type=str, help="fitness figure")
    args = parser.parse_args()

    mutations = pd.read_csv(args.mutations, sep='\t', index_col=0)
    fitness = pd.read_csv(args.fitness, sep='\t', index_col=0)

    ## figure with mutation distributions
    plt.figure()
    sorted_muts = mutations.sort_values('scaled_rate', ascending=False)
    mut_sum = np.sum(sorted_muts.scaled_rate)
    muts = [(t,v/mut_sum) for t,v in sorted_muts.scaled_rate.items()]
    plt.bar(np.arange(len(muts)), height=[m[1] for m in muts])
    plt.xticks(np.arange(len(muts)), [m[0] for m in muts], rotation=30)
    plt.ylabel('fraction')
    plt.savefig(args.output_mutations)

    ## Figure with 1st, 2nd, 3rd positions
    plt.figure()
    for p in range(1,4):
        ind = fitness.pos_in_codon==p
        plt.plot(sorted(fitness.tolerance[ind]), np.linspace(0,1,ind.sum()),
                        label = f"codon pos={p}")

    ind = fitness.pos_in_codon.isna()
    plt.plot(sorted(fitness.tolerance[ind]), np.linspace(0,1,ind.sum()),
                    label="non-coding")

    # use the 10th percentile of third-position (mostly synonymous) tolerance as a neutral reference
    ind = fitness.pos_in_codon==3
    syn_cutoff = scoreatpercentile(fitness.tolerance[ind],10)
    plt.plot([syn_cutoff, syn_cutoff], [0,1], c='k', alpha=0.3)
    for i in range(1,3):
        ind = fitness.pos_in_codon==i
        print("codon pos", i, np.mean(fitness.tolerance[ind]<syn_cutoff))
    ind = fitness.pos_in_codon.isna()
    print("non coding", np.mean(fitness.tolerance[ind]<syn_cutoff))

    plt.xscale('log')
    plt.ylabel("fraction below")
    plt.xlabel("scaled number of lineages with mutations")
    plt.legend()
    plt.savefig(args.output_fitness)

    # # figure with pdfs instead of cdfs
    # # plt.figure()
    # measure = total_events_rescaled
    # bins = np.logspace(0, np.ceil(np.log10(measure.max())), 101)
    # bc = np.sqrt(bins[:-1]*bins[1:])
    # bins[0] = 0
    # rate_estimate = {}
    # for p in range(4):
    #     ind = codon_pos==p
    #     y,x = np.histogram(measure[ind], bins=bins)
    #     rate_estimate[p] = {"mean": np.sum(bc*y/y.sum()),
    #                         "geo-mean": np.exp(np.sum(np.log(bc)*y/y.sum())),
    #                         "median": np.median(measure[ind])}
    #     # plt.plot(bc,y/y.sum(), label = 'non-coding' if p==0 else f"codon pos={p}")

    # two panel figure with 1/2nd and 3rd position mutations
    fig, axs = plt.subplots(2,1, figsize = (6,10), sharex=True)
    pos = np.arange(len(fitness))
    genes = [x for x in fitness.gene.unique() if not pd.isna(x)]
    for i,gene in enumerate(genes):
        c = f"C{i}"
        ls = '--' if i>9 else '-'
        ind = (fitness.pos_in_codon==3) & (fitness.gene==gene)
        axs[1].plot(sorted(fitness.tolerance[ind]), np.linspace(0,1,ind.sum()), ls=ls, c=c,label = f'{gene}')

        ind = (fitness.pos_in_codon>0) & (fitness.pos_in_codon<3) & (fitness.gene==gene)
        axs[0].plot(sorted(fitness.tolerance[ind]), np.linspace(0,1,ind.sum()),
                        ls=ls, c=c)

    ind = fitness.pos_in_codon.isna()
    axs[1].plot(sorted(fitness.tolerance[ind]), np.linspace(0,1,ind.sum()),
                       label = 'non-coding', c='k')

    plt.xscale('log')
    plt.xlim(0.01, 10)
    for ax in axs:
        ax.grid()

    axs[1].legend(ncol=2)
    axs[1].set_title("3rd codon positions or non-coding")
    axs[0].set_title("1st and 2nd codon positions")
    axs[0].set_ylabel("fraction below")
    axs[1].set_ylabel("fraction below")
    axs[1].set_xlabel("scaled number of lineages with mutations")
    plt.savefig(args.output_fitness_by_gene)

    # fitness_cost = np.array(fitness_cost)
    # plt.figure()
    pos = np.arange(len(fitness))
    ws = 10
    w = np.ones(ws)/ws
    for i,gene in enumerate(genes):
        c = f"C{i}"
        ls = '--' if i>9 else '-'
        plt.figure()
        for cp in range(3):
            ind = (fitness.pos_in_codon==(cp+1)) & (fitness.gene==gene)
            #ind = (codon_pos==(cp+1)) & (pos>=gene_range.start) & (pos<gene_range.end) & (~np.isnan(fitness_cost).all(axis=1))
            #plt.hist(np.min(np.log(fitness_cost[ind]), axis=1))
            gene_pos = (fitness.codon[ind]*3 + cp - 1)/3
            gene_pos_smooth = np.convolve(gene_pos, w, mode='valid')
            plt.plot(gene_pos_smooth,
                     np.convolve(np.log10(fitness.tolerance[ind]+.03), w, mode='valid'),
                                c=f'C{cp}', label=f'codon pos {cp+1}')
            plt.plot(gene_pos, np.zeros_like(gene_pos), c='k', alpha=0.3, lw=2)
            plt.plot(gene_pos,
                     np.log10((np.log(fitness.lineage_fraction_with_changes[ind]+.001)+3.01)), 'o', c=f'C{cp}')
            plt.ylabel('log10 scaled mutations')
        plt.ylim(-1.5, 1)
        plt.title(gene)
plot_genotype_counts.py:
import argparse,json
from datetime import datetime
import numpy as np
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False

import matplotlib.pyplot as plt
import seaborn as sns
from root_to_tip import add_panel_label
fs=12
def poisson(lam, order):
    """Poisson probability of `order` mutations; a string like '5+' yields the tail probability."""
    p = np.exp(-lam)
    try:
        for i in range(0,int(order)):
            p *= lam/(i+1)
    except ValueError:  # order such as '5+': accumulate P(k<5) and return the complement
        p_cum = np.copy(p)
        for i in range(0,int(order[:-1])-1):
            p *= lam/(i+1)
            p_cum += p
        p = 1 - p_cum

    return np.maximum(0.0, np.minimum(1.0,p))

def fit_poisson(n,k_vecs,t, rate=None):
    '''Fit the clade origin (offset) and mutation rate by maximizing the likelihood
    of observing k_vecs[order] out of n sequences with `order` mutations,
    assuming Poisson-distributed mutation numbers.'''
    from scipy.optimize import minimize
    eps = 1e-16
    def binom(x,n,k_vecs,t, rate):
        res = 0
        mu = rate or x[1]**2  # parametrize the rate as a square to keep it positive
        tu = np.maximum(0,mu*(t-x[0]))
        for order, k in k_vecs.items():
            p = poisson(tu, order)
            res -= np.sum(((n-k)*np.log(1-p+eps) + k*np.log(p+eps)))

        return res

    if rate is None:
        sol = minimize(binom, (-10,0.25), args=(n,k_vecs,t,rate), method="Powell")
        res = {'offset':sol['x'][0], 'rate':sol['x'][1]**2}
    else:
        sol = minimize(binom, (-10,), args=(n,k_vecs,t,rate), method="Powell")
        res = {'offset':sol['x'][0], 'rate':rate}

    return res
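
# Example (hypothetical numbers): given weekly sample totals n, per-order counts
# k_vecs = {'0': k0, '1': k1, ..., '5+': k5plus}, and days t since the first bin,
# fit_poisson(n, k_vecs, t) returns {'offset': origin in days, 'rate': mutations per day}.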


if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--counts', type=str, required=True, help="input data")
    parser.add_argument('--output-plot', type=str, help="figure file")
    args = parser.parse_args()

    with open(args.counts) as fh:
        counts = json.load(fh)

    plt.figure()
    nmuts = [m for m in counts['mutation_number'] if m[-1]!='+']
    for week in range(5,len(counts['mutation_number']['0'])):
        c = np.array([counts['mutation_number'][str(n)][week] for n in nmuts])
        if c.sum()>100:
            plt.plot(range(len(nmuts)), c/c.sum(), '-')


    dates = np.array([datetime.fromordinal(x) for x in counts['bins'][:-1]])
    t0 = counts['bins'][0]
    rel_date = np.array([x-t0 for x in counts['bins'][:-1]])
    fig, axs = plt.subplots(1,3, figsize = (18,6))

    ax = axs[2]
    ax.set_title("mutations per genome", fontsize=1.2*fs)
    # ax.plot(dates, counts['all_samples'], lw=3, c='k', alpha=0.3)
    total = np.array(counts['all_samples'])
    ind = total>0
    k_vecs = {m: np.array(counts['mutation_number'][m]) for m in ['0', '1', '2', '3', '4', '5+']}
    res = fit_poisson(total[ind], {x:k_vecs[x][ind] for x in k_vecs}, rel_date[ind])
    tu = (rel_date[ind]-res['offset'])*res['rate']
    for i, (m,k) in enumerate(k_vecs.items()):
        ax.plot(dates[ind], k[ind]/total[ind], 'o', label=f'{m} mutations',c=f"C{i}")
        ax.plot(dates[ind], poisson(tu,m), ls='-', c=f'C{i}')

    ax.set_ylim(8e-5, 2)

    ax = axs[1]
    ax.set_title("mutation frequencies", fontsize=1.2*fs)
    ax.plot(dates, total, lw=3, c='k', alpha=0.3, label='total')
    for m in sorted(counts['mutations'].keys(), key=lambda x:int(x[1:-1])):
        ax.plot(dates, counts['mutations'][m], '-o', label=f'{m}')

    ax = axs[0]
    ax.set_title("genotype frequencies", fontsize=1.2*fs)
    ax.plot(dates, total, lw=3, c='k', alpha=0.3, label='total')
    for m in sorted(counts['genotypes'].keys(), key=lambda x: len(x)):
        ax.plot(dates, counts['genotypes'][m], '-o', label=f'{m}' if m else "founder")

    fig.autofmt_xdate()
    for ax, label in zip(axs, 'DEF'):
        ax.set_yscale('log')
        ax.legend()
        add_panel_label(ax, label, fs=fs*1.8)


    plt.savefig(args.output_plot)
root_to_tip.py:
import pandas as pd
import argparse,json
from treetime.utils import numeric_date, datestring_from_numeric
from datetime import datetime
import numpy as np
from scipy.stats import linregress, scoreatpercentile
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from collections import defaultdict
mpl.rcParams['axes.formatter.useoffset'] = False

fs=14
date_reference = datetime(2020,1,1).toordinal()
def date_to_week_since2020(d):
    return (d.toordinal() - date_reference)//7

def week_since2020_to_date(d):
    return datetime.fromordinal(int(d*7 + date_reference))

def week_since2020_to_numdate(d):
    return numeric_date(week_since2020_to_date(d))

def filter_and_transform(d, clade_gt, min_date=None, max_date=None, query=None, completeness=None, swap_root=False, max_group=None, QC_threshold=30):
    dropped_seqs = {}
    # filter for incomplete dates
    d = d.loc[d.date.apply(lambda x:len(x)==10 and 'X' not in x)]

    if query:
        pre = len(d)
        d = d.query(query)
        dropped_seqs['query'] = pre - len(d)

    d['datetime'] = d.date.apply(lambda x: datetime.strptime(x, '%Y-%m-%d'))
    d['numdate'] = d.datetime.apply(lambda x: numeric_date(x))
    d['CW'] = d.datetime.apply(date_to_week_since2020)
    # filter date range
    if min_date:
        pre = len(d)
        d = d.loc[d.numdate>min_date]
        dropped_seqs['min_date'] = pre - len(d)
    if max_date:
        pre = len(d)
        d = d.loc[d.numdate<max_date]
        dropped_seqs['max_date'] = pre - len(d)

    pre = len(d)
    d = d.loc[d.QC_overall_score<QC_threshold]
    dropped_seqs['QC'] = pre - len(d)

    # look for clade defining substitutions
    d["clade_substitutions"] = d.substitutions.apply(lambda x:     [y for y in x.split(',') if y in clade_gt['nuc']] if x else [])
    # assign number of substitutions missing in the clade definition
    d["missing_subs"] = d.clade_substitutions.apply(lambda x: len(clade_gt['nuc'])-len(x))

    # define "with-in clade substitutions"
    d["intra_substitutions"] = d.substitutions.apply(lambda x:     [y for y in x.split(',')
                                                     if all([y not in clade_gt['nuc'], int(y[1:-1])>150,  int(y[1:-1])<29753])] if x else [])
    # define "with-in clade substitutions"
    d["intra_aaSubstitutions"] = d.aaSubstitutions.apply(lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and 'ORF9' not in y] if x else [])
    d["intra_SpikeSubstitutions"] = d.aaSubstitutions.apply(lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[0]=='S'] if x else [])
    d["intra_ORF1aSubstitutions"] = d.aaSubstitutions.apply(lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[:5]=='ORF1a'] if x else [])
    d["intra_ORF1bSubstitutions"] = d.aaSubstitutions.apply(lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[:5]=='ORF1b'] if x else [])
    d["intra_ENMSubstitutions"] = d.aaSubstitutions.apply(lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[0] in ['E','N','M']] if x else [])

    if swap_root:
        # when re-rooting onto clade 19B, swap the states at the two root-distinguishing positions
        muts = [("C8782T","T8782C"), ("T28144C","C28144T")]
        def swap(mutations, pair):
            return [y for y in mutations if y!=pair[0]] if pair[0] in mutations else mutations + [pair[1]]
        for m in muts:
            d["intra_substitutions"] = d.intra_substitutions.apply(lambda x: swap(x,m))
        aa_mut = ("ORF8:L84S","ORF8:S84L")
        d["intra_aaSubstitutions"] = d.intra_aaSubstitutions.apply(lambda x: swap(x,aa_mut))


    # make a hashable string representation
    d["intra_substitutions_str"] = d.intra_substitutions.apply(lambda x: ','.join(x))
    # within clade divergence
    d["divergence"] =   d.intra_substitutions.apply(lambda x:   len(x))
    d["aaDivergence"] = d.intra_aaSubstitutions.apply(lambda x: len(x))
    d["spikeDivergence"] = d.intra_SpikeSubstitutions.apply(lambda x: len(x))
    d["orf1aDivergence"] = d.intra_ORF1aSubstitutions.apply(lambda x: len(x))
    d["orf1bDivergence"] = d.intra_ORF1bSubstitutions.apply(lambda x: len(x))
    d["enmDivergence"] = d.intra_ENMSubstitutions.apply(lambda x: len(x))
    d["synDivergence"] = d["divergence"] - d["aaDivergence"]


    # filter
    if completeness is not None:
        pre = len(d)
        d = d.loc[d.missing_subs<=completeness]
        dropped_seqs['completeness'] = pre - len(d)

    if max_group:
        # subsample to at most max_group sequences per calendar week and country
        d = d.groupby(['CW', 'country']).sample(max_group, replace=True).drop_duplicates(subset='strain')

    return d, dropped_seqs

def weighted_regression(x,y,w):
    '''
    Determine slope and intercept by minimizing
    sum_i w_i*(y_i - f(x_i))^2 with f(x) = slope*x + intercept.

    Setting the derivatives with respect to intercept and slope to zero gives:
    sum_i w_i*(y_i - f(x_i)) = 0     => sum_i w_i y_i = intercept * sum_i w_i + slope * sum_i w_i x_i
    sum_i w_i*(y_i - f(x_i)) x_i = 0 => sum_i w_i y_i x_i = intercept * sum_i w_i x_i + slope * sum_i w_i x_i^2
    '''
    wa=np.array(w)
    xa=np.array(x)
    ya=np.array(y)
    wx = np.sum(xa*wa)
    wy = np.sum(ya*wa)
    wxy = np.sum(xa*ya*wa)
    wxx = np.sum(xa**2*wa)
    wsum = np.sum(wa)
    wmean = np.mean(wa)

    slope = (wy*wx - wxy*wsum)/(wx**2 - wxx*wsum)
    intercept = (wy - slope*wx)/wsum

    # not correct
    # hessianinv = np.linalg.inv(np.array([[wxx, wx], [wx, wsum]])/wsum)
    # stderrs = wmean*np.sqrt(hessianinv.diagonal())

    return {"slope": slope, "intercept":intercept} #, 'slope_err':stderrs[0], 'intercept_err':stderrs[1]}


def regression_by_week(d, field, min_count=5):
    val = d.loc[:,[field,'CW']].groupby('CW').mean()
    std = d.loc[:,[field,'CW']].groupby('CW').std()
    count = d.loc[:,[field,'CW']].groupby('CW').count()
    ind = count[field]>min_count
    #reg = linregress(val.index[ind], val.loc[ind, field])
    # slope = reg.slope * 365 / 7
    # intercept = reg.intercept - 2020*slope
    reg = weighted_regression(val.index[ind], val.loc[ind, field], np.array((count[ind]-min_count)**0.25).squeeze())
    slope = reg["slope"] * 365 / 7
    intercept = reg["intercept"] - 2020*slope

    return {"slope":slope, "intercept":intercept,
            "origin": -intercept/slope,
            "date":[week_since2020_to_numdate(x) for x in val.index[ind]],
            "mean":[x for x in val.loc[ind, field]],
            "stderr":[x for x in std.loc[ind, field]],
            "count":[x for x in count.loc[ind, field]]}

def make_date_ticks(ax):
    ax.set_xlabel('')
    ax.set_xticklabels([datestring_from_numeric(x) for x in ax.get_xticks()], rotation=30, horizontalalignment='right')

def get_clade_gts(all_gts, subclade_str):
    with open(all_gts) as fh:
        clade_gts = json.load(fh)

    subclades = subclade_str.split(',')
    if len(subclades)>1:
        clade_gt = {'nuc':{}, 'aa':{}}
        clade_gt['nuc'] = set.intersection(*[set(clade_gts[x]['nuc']) for x in subclades])
        clade_gt['aa'] = set.intersection(*[set(clade_gts[x]['aa']) for x in subclades])
    else:
        clade_gt = clade_gts[subclade_str]

    return clade_gt

def add_panel_label(ax, t, fs=16):
    ax.text(-0.1,1,t, fontsize=fs,transform=ax.transAxes)

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--metadata',  type=str, required=True, help="input data")
    parser.add_argument('--clade-gts', type=str, required=True, help="input data")
    parser.add_argument('--clade', type=str, required=True, help="input data")
    parser.add_argument('--sub-clades', type=str, required=True, help="input data")
    parser.add_argument('--qc-cutoff-scale', type=float, default=2, help="number of std dev of rtt filter")
    parser.add_argument('--qc-cutoff-offset', type=float, default=3, help="number extra mutation")
    parser.add_argument('--min-date', type=float, help="input data")
    parser.add_argument('--max-date', type=float, help="input data")
    parser.add_argument('--max-group', type=int, help="input data")
    parser.add_argument('--query', type=str, help="filters")
    parser.add_argument('--output-plot', type=str, help="plot file")
    parser.add_argument('--output-json', type=str, help="rate file")
    args = parser.parse_args()

    clade_gt = get_clade_gts(args.clade_gts, args.sub_clades)

    d = pd.read_csv(args.metadata, sep='\t').fillna('')
    # clade-specific quality thresholds; note that '22D' is compared lexicographically to the clade name
    filtered_data, dropped_seqs = filter_and_transform(d, clade_gt, min_date=args.min_date, max_date=args.max_date,
                                         query=args.query, max_group=args.max_group,
                                         QC_threshold=80 if args.clade=='21H' else 30,
                                         completeness=3 if args.clade>='22D' else 0,
                                         swap_root=args.clade.startswith('19B+'))

    # initial fit on all data to define an outlier tolerance
    regression = linregress(filtered_data.numdate, filtered_data.divergence)
    filtered_data["residuals"] = filtered_data.apply(lambda x: x.divergence - (regression.intercept + regression.slope*x.numdate), axis=1)

    # tolerate qc_cutoff_offset extra mutations plus qc_cutoff_scale times the
    # Poisson standard deviation of the expected divergence
    tolerance = lambda t: args.qc_cutoff_offset + args.qc_cutoff_scale*np.sqrt(np.maximum(0, (regression.intercept + regression.slope*t)))
    filtered_data["outlier"] = filtered_data.apply(lambda x: np.abs(x.residuals) > tolerance(x.numdate), axis=1)
    ind = filtered_data.outlier == False

    # refit on the outlier-free data, aggregated by calendar week
    regression_clean = regression_by_week(filtered_data.loc[ind], "divergence")
    regression_clean_aa = regression_by_week(filtered_data.loc[ind], "aaDivergence")
    regression_clean_syn = regression_by_week(filtered_data.loc[ind], "synDivergence")
    regression_clean_spike = regression_by_week(filtered_data.loc[ind], "spikeDivergence")
    regression_clean_ORF1a = regression_by_week(filtered_data.loc[ind], "orf1aDivergence")
    regression_clean_ORF1b = regression_by_week(filtered_data.loc[ind], "orf1bDivergence")
    regression_clean_ENM = regression_by_week(filtered_data.loc[ind], "enmDivergence")

    fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)
    ymax = 20
    bins = (20, np.arange(-0.5, ymax+0.5))
    sns.histplot(x=filtered_data.numdate, y=np.minimum(ymax*1.5, filtered_data.divergence), bins=bins, ax=axs[0])
    x = np.linspace(*axs[0].get_xlim(), 101)
    axs[0].set_title('all differences', fontsize=fs*1.2)
    axs[0].plot(x, regression_clean["intercept"] + regression_clean["slope"]*x, lw=4, label=f"slope = {regression_clean['slope']:1.1f} subs/year")
    axs[0].errorbar(regression_clean["date"], regression_clean["mean"], regression_clean["stderr"])
    # outlier tolerance around the initial fit
    axs[0].plot(x, regression.intercept + regression.slope*x + tolerance(x), lw=4)

    axs[1].set_title('amino acid differences', fontsize=fs*1.2)
    sns.histplot(x=filtered_data.numdate[ind], y=np.minimum(ymax*1.5, filtered_data.aaDivergence[ind]), bins=bins, ax=axs[1])
    axs[1].plot(x, regression_clean_aa["intercept"] + regression_clean_aa["slope"]*x, lw=4, label=f"slope = {regression_clean_aa['slope']:1.1f} subs/year")
    axs[1].errorbar(regression_clean_aa["date"], regression_clean_aa["mean"], regression_clean_aa["stderr"])

    axs[2].set_title('synonymous differences', fontsize=fs*1.2)
    sns.histplot(x=filtered_data.numdate[ind], y=np.minimum(ymax*1.5, filtered_data.synDivergence[ind]), bins=bins, ax=axs[2])
    axs[2].plot(x, regression_clean_syn["intercept"] + regression_clean_syn["slope"]*x, lw=4, label=f"slope = {regression_clean_syn['slope']:1.1f} subs/year")
    axs[2].errorbar(regression_clean_syn["date"], regression_clean_syn["mean"], regression_clean_syn["stderr"])

    axs[2].text(0.8,0.9, args.clade, fontsize=fs*1.5, transform=axs[2].transAxes)
    axs[0].set_ylabel("Divergence", fontsize=fs)
    for ax,label in zip(axs,'ABC'):
        make_date_ticks(ax)
        ax.set_yticks(np.arange(0,ymax,3))
        ax.legend(loc=2, fontsize=fs)
        ax.set_ylim(-0.5,ymax-0.5)
        add_panel_label(ax, label, fs=fs*1.8)

    if args.output_plot:
        plt.savefig(args.output_plot)
    else:
        plt.show()

    # count how often each substitution occurred within the clade
    aaCounter = defaultdict(int)
    nucCounter = defaultdict(int)
    for subs in filtered_data.intra_aaSubstitutions:
        for a in subs: aaCounter[a] += 1
    for subs in filtered_data.intra_substitutions:
        for a in subs: nucCounter[a] += 1

    # frequencies of the 100 most common substitutions
    total = len(filtered_data)
    top_aaSubs = {a: v/total for a, v in sorted(aaCounter.items(), key=lambda x: x[1], reverse=True)[:100]}
    top_nucSubs = {a: v/total for a, v in sorted(nucCounter.items(), key=lambda x: x[1], reverse=True)[:100]}

    rate_data = {'clade':args.clade, 'nuc':regression_clean, 'aa':regression_clean_aa,
                 'syn':regression_clean_syn, 'spike':regression_clean_spike,
                 'orf1a':regression_clean_ORF1a,'orf1b':regression_clean_ORF1b,'enm':regression_clean_ENM,
                 "top_aaSubs": top_aaSubs,  "top_nucSubs": top_nucSubs,
                 "outliers_removed": int(np.sum(filtered_data.outlier)),
                 "qc_filter_fail": dropped_seqs["QC"],
                 "incomplete": dropped_seqs["completeness"],
                 "total_sequences": len(filtered_data)}
    if args.output_json:
        with open(args.output_json, 'w') as fh:
            json.dump(rate_data, fh)
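
For reference, a hypothetical stand-alone invocation of this script (paths and the clade label are placeholders; the Snakefile rules further below show how the workflow actually calls it):

python3 scripts/root_to_tip.py --metadata subsets/21K.tsv \
                               --clade-gts data/clade_gts.json \
                               --clade 21K --sub-clades 21K \
                               --min-date 2021.9 --max-date 2022.6 \
                               --output-plot figures/rtt_21K.png \
                               --output-json data/rate_21K.json

The next snippet, scripts/split_by_variant.py, produces the per-variant metadata subsets that the regression script consumes: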
import pandas as pd
import numpy as np
import argparse

# columns to retain when reading the (large) metadata file
columns = ['strain', 'virus', 'gisaid_epi_isl', 'genbank_accession',
           'date', 'region', 'country', 'length', 'host', 'Nextstrain_clade',
           'pango_lineage', 'Nextclade_pango', 'missing_data', 'divergence',
           'nonACGTN', 'rare_mutations', 'reversion_mutations',
           'potential_contaminants', 'QC_overall_score', 'QC_overall_status',
           'frame_shifts', 'deletions', 'insertions', 'substitutions',
           'aaSubstitutions', 'clock_deviation']


if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--metadata', type=str, required=True, help="input data")
    parser.add_argument('--variant-labels', nargs='+', type=str, required=True, help="input data")
    parser.add_argument('--variants', nargs='+', type=str, required=True, help="input data")
    args = parser.parse_args()

    files = [f"subsets/{v}.tsv" for v in args.variant_labels]

    d = pd.read_csv(args.metadata, sep='\t', usecols=columns).fillna('')
    d["Nextstrain_clade"] = d.Nextstrain_clade.apply(lambda x:x.split()[0] if x else '')

    for v, name in zip(args.variants, files):
        if len(v.split(','))>1:
            ind = np.any([d.Nextstrain_clade==y for y in v.split(',')], axis=0)
        else:
            ind = d.Nextstrain_clade==v

        subset = d.loc[ind]
        print(name, v, len(subset))
        subset.to_csv(name, sep='\t')
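
A hypothetical invocation (clade strings are placeholders; comma-separated clades are pooled under one label, and the subsets/ output directory must already exist):

python3 scripts/split_by_variant.py --metadata data/metadata.tsv \
                                    --variants 21K 21I,21J \
                                    --variant-labels 21K 21I+21J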
shell:
    """
    curl https://data.nextstrain.org/files/ncov/open/metadata.tsv.gz -o {output}
    """
shell:
    """
    curl https://data.nextstrain.org/files/ncov/open/nextclade.tsv.gz -o {output}
    """
shell:
    """
    curl https://data.nextstrain.org/nextclade_sars-cov-2.json | gunzip > {output.tree}
    curl https://data.nextstrain.org/nextclade_sars-cov-2_root-sequence.json | gunzip > {output.root}
    """
shell:
    """
    python3 scripts/split_by_variant.py --metadata {input.metadata} --variants {params.variants} --variant-labels {params.variant_labels}
    """
SnakeMake From line 100 of master/Snakefile
shell:
    """
    python3 scripts/get_genotypes_pango.py --tree {input.tree} --root {input.root} --output {output.gt}
    """
SnakeMake From line 112 of master/Snakefile
    shell:
        """
        python3 scripts/root_to_tip.py --metadata {input.metadata} --clade {params.clade} --sub-clades {params.clades} \
                                       --clade-gts data/clade_gts.json \
                                       --min-date {params.mindate} \
                                       --max-date {params.maxdate} \
                                       --qc-cutoff-scale 2 --qc-cutoff-offset 3 \
                                       {params.filter_query} \
                                       --output-plot {output.figure} \
                                       --output-json {output.json}
        """
SnakeMake From line 130 of master/Snakefile
shell:
    """
    python3 scripts/get_genotype_counts.py --metadata {input.metadata} --clade {params.clade} --sub-clades {params.clades} \
                                   --clade-gts data/clade_gts.json \
                                   --min-date {params.mindate} \
                                   --max-date {params.maxdate} \
                                   {params.filter_query} \
                                   --bin-size {params.bin_size} \
                                   --output-json {output.json}
    """
SnakeMake From line 155 of master/Snakefile
shell:
    """
    python3 scripts/plot_genotype_counts.py --counts {input.json} --output-plot {output.fig}
    """
SnakeMake From line 171 of master/Snakefile
shell:
    """
    python3 scripts/clone_growth.py --metadata {input.metadata} --clade {params.clade} --sub-clades {params.clades} \
                                   --clade-gts data/clade_gts.json \
                                   --min-date {params.mindate} \
                                   --output-plot {output.figure}
    """
SnakeMake From line 187 of master/Snakefile
shell:
    """
    python3 scripts/plot_af.py --counts {input.count_files} --output-plot {output.af_fig} --output-rates {output.rates}
    """
SnakeMake From line 201 of master/Snakefile
run:
    import json
    import pandas as pd
    from treetime.utils import datestring_from_numeric
    with open(input.gt) as fin:
        clade_gts = json.load(fin)

    # ad hoc offsets of the founder divergence relative to the root of the
    # pandemic; clade 19B is handled with the opposite sign
    offset_nuc = lambda clade: -2 if clade=='19B' else 2
    offset_nonsyn = lambda clade: -1 if clade=='19B' else 1

    data = []
    qc_data = []
    for fname in input.rate_files:
        with open(fname) as fin:
            d = json.load(fin)
        base_clade = d['clade'][:3]
        # founder divergence of the clade; ORF9b (overlapping N) is excluded
        aa_div = len([x for x in clade_gts[base_clade]['aa'] if 'ORF9' not in x]) + offset_nonsyn(base_clade)
        nuc_div = len(clade_gts[base_clade]['nuc']) + offset_nuc(base_clade)
        data.append({'clade':d['clade'], 'nseqs':d['total_sequences'],
                     'nuc_rate': d['nuc']['slope'],
                     'nuc_origin': d['nuc']['origin'], 'nuc_origin_date': datestring_from_numeric(d['nuc']['origin']),
                     'aa_rate': d['aa']['slope'],
                     'aa_origin':d['aa']['origin'], 'aa_origin_date':datestring_from_numeric(d['aa']['origin']),
                     'syn_rate': d['syn']['slope'],
                     'syn_origin':d['syn']['origin'],'syn_origin_date':datestring_from_numeric(d['syn']['origin']),
                     'spike_rate': d['spike']['slope'],
                     'orf1a_rate': d['orf1a']['slope'],
                     'orf1b_rate': d['orf1b']['slope'],
                     'enm_rate': d['enm']['slope'],
                     'nuc_div': nuc_div, 'aa_div':aa_div, 'syn_div':nuc_div-aa_div})
        qc_data.append({'clade':d['clade'], 'nseqs':d['total_sequences'], "outliers":d["outliers_removed"],
                        "qc_fail":d["qc_filter_fail"]})

    # one row per clade; written both as tsv and as a LaTeX table
    df = pd.DataFrame(data)
    df.to_csv(output.rate_table, sep='\t')
    df.to_latex(output.rate_table_tex, float_format="%.2f", index=False)

    # bookkeeping of sequences removed by the quality filters
    dfqc = pd.DataFrame(qc_data)
    dfqc.to_csv(output.qc_table, sep='\t')
shell:
    """
    python3 scripts/combine_fits.py --rate-table {input.rate_table}\
                                   --output-plot {output.figure} \
                                   --output-plot-rates {output.figure_rates} \
                                   --output-plot-rates-genes {output.figure_rates_genes}
    """
SnakeMake From line 274 of master/Snakefile
shell:
    """
    python3 scripts/count_mutations.py --metadata {input.nextclade}\
             --reference {input.ref} \
             --pango-gts {input.pango_gts} \
             --output-fitness {output.fitness_costs} \
             --output-events {output.all_events} \
             --output-mutations {output.mutation_rates}
    """
SnakeMake From line 292 of master/Snakefile
shell:
    """
    python3 scripts/plot_fitness.py --fitness {input.fitness_costs}\
             --mutations {input.mutation_rates} \
             --output-fitness {output.fitness_figure} \
             --output-fitness-by-gene {output.fitness_figure_by_gene} \
             --output-mutations {output.mutation_figure}
    """
SnakeMake From line 310 of master/Snakefile
shell:
    """
    python3 scripts/plot_fitness_landscape.py --fitness {input.fitness_costs}\
             --output-fitness-landscape {output.fitness_landscape}
    """
SnakeMake From line 324 of master/Snakefile

Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/neherlab/SC2_variant_rates
Name: sc2_variant_rates
Version: VE_resubmission
Copyright: Public Domain
License: None