A Snakemake workflow to analyze substitution rates and mutation behavior within SARS-CoV-2 variants.
Richard A. Neher
Continued evolution and adaptation of SARS-CoV-2 has led to more transmissible and immune-evasive variants with a profound impact on the course of the pandemic. Here I analyze the evolution of the virus over the 2.5 years since its emergence and estimate rates of evolution for synonymous and non-synonymous changes separately, both for evolution within clades -- well-defined monophyletic groups with gradual evolution -- and for the pandemic overall. The rate of synonymous mutations is found to be around 6 changes per year. Synonymous rates within variants vary little from variant to variant and are compatible with the overall rate. In contrast, the rate at which variants accumulate amino acid changes (non-synonymous mutations) was initially around 12-16 changes per year, but dropped to 6-9 changes per year in 2021 and 2022. The overall rate of non-synonymous evolution, that is, across variants, is estimated to be about 25 amino acid changes per year. This 2-fold higher rate indicates that the evolutionary process that gave rise to the different variants is qualitatively different from that in typical transmission chains and is likely dominated by adaptive evolution. I further quantify the spectrum of mutations and purifying selection in different SARS-CoV-2 proteins. Many accessory proteins evolve under limited evolutionary constraint with little short-term purifying selection, while about half of the mutations in other proteins are strongly deleterious and rarely observed, even at low frequency.
Repository structure
This repository contains scripts and source files associated with a manuscript on SARS-CoV-2 virus evolution.
The analysis can be run using `snakemake` and requires standard Python libraries, as well as `treetime` (`phylo-treetime` on pip).
The analysis uses the open data files provisioned by Nextstrain, which the workflow downloads by default.
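For example, a minimal invocation might look like this (a sketch only: the package list beyond `snakemake` and `phylo-treetime` is an assumption based on the imports in the scripts below, and the default target of the Snakefile is assumed to build all outputs):

```sh
# install the workflow's dependencies (package list inferred from the script imports)
pip install snakemake phylo-treetime pandas numpy scipy matplotlib seaborn biopython

# run the full workflow from the repository root with four parallel jobs
snakemake --cores 4
```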
The `manuscript` directory contains the TeX files associated with the manuscript, the bibliography, and the figures.
The `data` directory contains some derived files, including the rate estimates, the mutation distribution, and fitness cost landscapes.
Code Snippets
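The first snippet appears to be `scripts/clone_growth.py`, judging by the matching invocation in the workflow rules further down; it plots the cumulative growth of individual within-clade genotypes during the first months of a clade: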
```python
import pandas as pd
import argparse, json
from collections import defaultdict
from datetime import datetime
import numpy as np
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False
import matplotlib.pyplot as plt
import seaborn as sns
from root_to_tip import filter_and_transform, get_clade_gts

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--metadata', type=str, nargs='+', required=True, help="input data")
    parser.add_argument('--clade-gts', type=str, required=True, help="input data")
    parser.add_argument('--clade', type=str, required=True, help="input data")
    parser.add_argument('--sub-clades', type=str, required=True, help="input data")
    parser.add_argument('--min-date', type=float, help="input data")
    parser.add_argument('--output-plot', type=str, help="plot file")
    #parser.add_argument('--output-json', type=str, help="rate file")
    args = parser.parse_args()

    clade_gt = get_clade_gts(args.clade_gts, args.sub_clades)
    d = pd.concat([pd.read_csv(x, sep='\t').fillna('') for x in args.metadata])

    filtered_data, _ = filter_and_transform(d, clade_gt, min_date=args.min_date,
                                            max_date=args.min_date + 0.3, completeness=0,
                                            swap_root=args.clade.startswith('19B+'))
    #filtered_data = filtered_data.loc[filtered_data.country!='China']

    intra_subs_dis = filtered_data["divergence"].value_counts().sort_index()
    intra_aaSubs_dis = filtered_data["aaDivergence"].value_counts().sort_index()
    intra_geno = filtered_data["intra_substitutions_str"].value_counts()

    ls = ['-', '--', '-.', ':']
    plt.figure()
    dates = sorted(filtered_data.loc[:, "numdate"])
    plt.plot(dates, (np.arange(len(dates))+1), c='k', lw=3, alpha=0.3, label="all")
    ls_counter = defaultdict(int)
    for x, i in intra_geno.items():
        nmuts = len(x.split(',')) if x else 0
        ind = filtered_data["intra_substitutions_str"]==x
        dates = sorted(filtered_data.loc[ind, "numdate"])
        factor = 3 if 'C8782T' in x else 1
        if dates[0] < args.min_date + 0.2:
            #if (nmuts<2 or i>len(filtered_data)/100) and ls_counter[nmuts]<10:
            if nmuts > 5:
                continue
            plt.plot(dates, factor*(np.arange(i)+1), c=f'C{nmuts}',
                     lw=2 if nmuts else 3, ls=ls[ls_counter[nmuts]%len(ls)],
                     marker='o' if len(dates)<3 else '',
                     label=f"gt={x if x else 'founder'}"[:30])
            ls_counter[nmuts] += 1

    plt.ylim(0.5, 1000)
    plt.xlim(args.min_date, args.min_date + 0.3)
    plt.yscale('log')
    plt.legend()
    if args.output_plot:
        plt.savefig(args.output_plot)
    else:
        plt.show()
```
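The next snippet appears to be `scripts/combine_fits.py`; it combines the per-clade rate estimates into divergence summary panels, an inter-clade regression, and gene-level rate comparisons: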
```python
import argparse
import numpy as np
import pandas as pd
from scipy.stats import linregress
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False
from root_to_tip import add_panel_label
import matplotlib.pyplot as plt
import seaborn as sns

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    fs = 14
    parser.add_argument('--rate-table', type=str, required=True, help="input data")
    parser.add_argument('--output-plot', type=str, help="plot file")
    parser.add_argument('--output-plot-rates', type=str, help="plot file")
    parser.add_argument('--output-plot-rates-genes', type=str, help="plot file")
    args = parser.parse_args()

    L = 29903
    Laa = 9716  # all CDS other than ORF9b
    rates = pd.read_csv(args.rate_table, sep='\t', index_col='clade')

    inter_clade_rates = {}
    fig, axs = plt.subplots(2, 2, figsize=(12, 12))
    xticks = [2020, 2020.5, 2021, 2021.5, 2022, 2022.5]
    for ax, mut_type, ax_label, panel in zip(axs.flatten(), ['nuc', 'aa', 'syn'],
                                             ['total divergence', 'amino acid divergence', 'synonymous divergence'],
                                             ['A', 'B', 'C']):
        ax.set_ylabel(ax_label, fontsize=fs)
        add_panel_label(ax, panel, fs=fs*1.8)
        inter_clade = []
        ci = 0
        for clade, row in rates.iterrows():
            dt = np.linspace(0, 0.75, 2)
            t = dt + row[f'{mut_type}_origin']
            slope = row[f'{mut_type}_rate']
            ls = '-' if ci < 10 else '--'
            m = 'o' if ci < 10 else 's'
            ax.plot(t, dt*slope + row[f'{mut_type}_div'], label=clade, c=f"C{ci%10}", ls=ls)
            ax.scatter([[row[f'{mut_type}_origin']]], [[row[f'{mut_type}_div']]], c=f"C{ci%10}", marker=m)
            if row[f'{mut_type}_origin'] > 2019.7 and row[f'{mut_type}_origin'] < 2022.7:
                inter_clade.append([row[f'{mut_type}_origin'], row[f'{mut_type}_div']])
            ci += 1

        inter_clade = np.array(inter_clade)
        reg = linregress(inter_clade[:, 0], inter_clade[:, 1])
        inter_clade_rates[mut_type] = reg.slope
        x = np.linspace(2019.8, 2022.5, 21)
        y = reg.slope*x + reg.intercept
        std_dev = np.sqrt(np.maximum(0, reg.slope*x + reg.intercept))
        ax.plot(x, y, c='k', lw=3, alpha=0.5)
        ax.fill_between(x, y + std_dev, np.maximum(0, y - std_dev), fc='k', lw=3, alpha=0.1, ec=None)
        ax.set_xlim(x.min(), x.max())
        ax.set_ylim(0)
        if mut_type=='nuc':
            ax.text(0.5, 0.05, f"overall rate:\n{reg.slope:1.1f}/year\n{reg.slope/L:1.1e}/year/site",
                    fontsize=fs, transform=ax.transAxes)
        else:
            ax.text(0.5, 0.05, f"overall rate:\n{reg.slope:1.1f}/year\n{reg.slope/Laa:1.1e}/year/codon",
                    fontsize=fs, transform=ax.transAxes)
        if mut_type == 'aa':
            ax.legend(ncol=2)
        ax.set_xticks(xticks, [str(x) for x in xticks], fontsize=fs)

    for i, (mut_type, rate) in enumerate(inter_clade_rates.items()):
        axs[-1,-1].fill_between([i-0.5, i+0.5], [35, 35], facecolor='k', alpha=0.075*(2+i%2))
        axs[-1,-1].plot([i-0.4, i+0.4], [rate, rate], lw=3, c='k', alpha=0.5)
        clade_rates = rates[f"{mut_type}_rate"]
        x_offset = i + np.linspace(-.35, 0.35, len(clade_rates))
        axs[-1,-1].scatter(x_offset[:10], clade_rates[:10],
                           c=[f"C{ci%10}" for ci in range(10)], marker='o')
        axs[-1,-1].scatter(x_offset[10:], clade_rates[10:],
                           c=[f"C{ci%10}" for ci in range(len(clade_rates) - 10)], marker='s')

    axs[-1,-1].set_ylabel("substitutions per year", fontsize=fs)
    axs[-1,-1].set_xticks([0, 1, 2], ['total', 'amino acid', 'synonymous'],
                          rotation=20, ha='center', fontsize=fs)
    axs[-1,-1].set_ylim(0)
    add_panel_label(axs[-1,-1], 'D', fs=fs*1.8)
    if args.output_plot:
        plt.savefig(args.output_plot)
    else:
        plt.show()

    plt.figure()
    plt.plot(rates["nuc_origin"], rates["aa_rate"], 'o', label='amino acid rate')
    plt.plot(rates["nuc_origin"], rates["syn_rate"], 'o', label='synonymous rate')
    plt.plot(rates["nuc_origin"], np.ones_like(rates['nuc_origin'])*inter_clade_rates["aa"],
             label='inter-clade amino acid rate', c=f"C{0}", lw=3)
    plt.plot(rates["nuc_origin"], np.ones_like(rates['nuc_origin'])*inter_clade_rates["syn"],
             label='inter-clade synonymous rate', c=f"C{1}", lw=3)
    plt.ylabel('rate estimate [1/y]')
    plt.ylim(0)
    plt.legend()
    if args.output_plot_rates:
        plt.savefig(args.output_plot_rates)
    else:
        plt.show()

    plt.figure()
    plt.plot(rates["aa_rate"], 'o-', label='Overall amino acid rate')
    plt.plot(rates["spike_rate"], 's-', label='spike protein')
    plt.plot(rates["orf1a_rate"], 'd-', label='ORF1a')
    plt.plot(rates["orf1b_rate"], 'd-', label='ORF1b')
    plt.plot(rates["enm_rate"], 'd-', label='E,M,N')
    plt.plot(rates["aa_rate"] - rates["spike_rate"] - rates["orf1a_rate"]
             - rates["orf1b_rate"] - rates["enm_rate"], 'v-', label='other ORFs')
    plt.ylabel('rate estimate [subs/y]')
    plt.legend()
    plt.xticks(range(len(rates)), rates.index, rotation=60, ha='right')
    plt.tight_layout()
    if args.output_plot_rates_genes:
        plt.savefig(args.output_plot_rates_genes)
    else:
        plt.show()
```
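This snippet appears to be `scripts/count_mutations.py`; it tallies rare private mutations per pango lineage, classifies them as synonymous or non-synonymous, and writes per-site event, fitness, and mutation-rate tables: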
```python
import argparse, json
import pandas as pd
import numpy as np
import matplotlib as mpl
from Bio import SeqIO, Seq
mpl.rcParams['axes.formatter.useoffset'] = False
from collections import defaultdict
from scipy.stats import scoreatpercentile
import matplotlib.pyplot as plt
import seaborn as sns

columns = ['seqName', 'Nextclade_pango', 'privateNucMutations.unlabeledSubstitutions',
           'qc.overallStatus', 'privateNucMutations.reversionSubstitutions']

def get_sequence(mods, root_seq):
    seq = root_seq.copy()
    for mut in mods['nuc']:
        a, pos, d = mut[0], int(mut[1:-1])-1, mut[-1]
        if a!=seq[pos]:
            print(seq[pos], mut)
        seq[pos] = d
    return seq

def translate(codon):
    try:
        return Seq.translate("".join(codon))
    except:
        return 'X'

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--metadata', type=str, required=True, help="input data")
    parser.add_argument('--reference', type=str, required=True, help="input data")
    parser.add_argument('--pango-gts', type=str, required=True, help="input data")
    parser.add_argument('--rare-cutoff', type=int, default=2, help="minimal number of occurrences to count mutations")
    parser.add_argument('--output-fitness', type=str, required=True, help="fitness table")
    parser.add_argument('--output-events', type=str, required=True, help="events table")
    parser.add_argument('--output-mutations', type=str, required=True, help="fitness table")
    args = parser.parse_args()

    ## load reference
    ref = SeqIO.read(args.reference, 'genbank')
    ref_array = np.array(ref.seq)
    base_content = {x: ref.seq.count(x) for x in 'ACGT'}

    ## make a map of codon positions, ignore ORF9b (overlaps N)
    codon_pos = np.zeros(len(ref))
    map_to_gene = {}
    gene_position = {}
    gene_length = {}
    for feat in ref.features:
        if feat.type=='CDS':
            gene_name = feat.qualifiers['gene'][0]
            gene_position[gene_name] = feat.location
            gene_length[gene_name] = (feat.location.end - feat.location.start)//3
            if gene_name!='ORF9b':
                for gpos, pos in enumerate(feat.location):
                    codon_pos[pos] = (gpos%3)+1
                    map_to_gene[pos] = gene_name

    # load pango genotypes
    with open(args.pango_gts) as fh:
        pango_gts = json.load(fh)

    # load and filter metadata
    d = pd.read_csv(args.metadata, sep='\t', usecols=columns).fillna('')
    d = d.loc[d["qc.overallStatus"]=='good', :]

    # collect all rare mutations by pango lineage
    mutation_counter = defaultdict(lambda: defaultdict(int))
    pango_counter = defaultdict(int)
    for r, row in d.iterrows():
        pango = row.Nextclade_pango
        if pango[0]=='X':
            continue
        pango_counter[pango] += 1
        muts = row["privateNucMutations.unlabeledSubstitutions"].split(',')
        rev_muts = row["privateNucMutations.reversionSubstitutions"].split(',')
        if len(muts) + len(rev_muts):
            for m in muts + rev_muts:
                if m:
                    mutation_counter[pango][m] += 1

    lineage_size_cutoff = 100
    big_lineages = [pango for pango in pango_counter if pango_counter[pango]>lineage_size_cutoff]
    nlin = len(big_lineages)

    tmp_number_of_pango_muts = defaultdict(int)
    for pango, muts in pango_gts.items():
        if pango_counter[pango]>lineage_size_cutoff:
            for m in muts['nuc']:
                tmp_number_of_pango_muts[int(m[1:-1])-1] += 1
    number_of_pango_muts = np.array([tmp_number_of_pango_muts.get(pos, 0)/nlin for pos in range(len(ref))])

    # for each pango lineage, count mutations by position, as well as synonymous and non-synonymous by codon
    position_counter = defaultdict(lambda: defaultdict(list))
    position_freq = defaultdict(lambda: defaultdict(list))
    syn = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    nonsyn = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for pango, pmuts in mutation_counter.items():
        if pango not in pango_gts:
            print("missing pango", pango)
            continue
        # translate each codon in pango reference sequence
        seq = get_sequence(pango_gts[pango], ref_array)
        for mut, c in pmuts.items():
            a, pos, d = mut[0], int(mut[1:-1])-1, mut[-1]
            if a=='-':
                continue
            position_counter[pos][(a,d)].append(c)
            if pango_counter[pango]>lineage_size_cutoff and c>1:
                position_freq[pos][(a,d)].append(c/pango_counter[pango])
            if pos in map_to_gene:
                gene = map_to_gene[pos]
                pos_in_gene = pos - gene_position[gene].start
                frame = pos_in_gene%3
                codon = seq[pos-frame:pos-frame+3]
                codon_number = pos_in_gene//3
                alt = codon.copy()
                alt[frame] = d
                aa = translate(codon)
                alt_aa = translate(alt)
                if aa==alt_aa and aa!='X':
                    syn[gene][codon_number][a].append(c)
                elif 'X' not in [aa, alt_aa]:
                    nonsyn[gene][codon_number][a].append(c)

    sorted_transitions = []
    for a in 'ACGT':
        for d in 'ACGT':
            if a!=d:
                sorted_transitions.append((a,d))

    # determine the total number of events and counts at each position
    def total_len_filtered_lists(l, cutoff, rates=None):
        if rates is None:
            rates = {}
        return sum([len([y for y in x if y>=cutoff])/rates.get(a, 1.0) for (a,d), x in l.items()])

    all_events = []
    for pos in range(len(ref)):
        tmp = []
        for a, d in sorted_transitions:
            tmp.append(total_len_filtered_lists({(a,d): position_counter[pos][(a,d)]}, args.rare_cutoff))
        all_events.append(tmp)

    # sort all mutations by type of mutation
    events_by_transition = defaultdict(int)
    for m in position_counter.values():
        for t in m:
            events_by_transition[t] += len(m[t])

    # determine the mutation rate for each type, as well as the rate out of a nucleotide
    transition_rate = dict()
    away_rate = defaultdict(float)
    total_rate = np.sum(list(events_by_transition.values()))
    for a, d in events_by_transition:
        if a in base_content:
            transition_rate[(a,d)] = events_by_transition[(a,d)]/base_content[a]/total_rate*len(ref)
            away_rate[a] += transition_rate[(a,d)]

    total_events = np.array([total_len_filtered_lists(position_counter[pos], args.rare_cutoff)
                             if pos in position_counter else 0 for pos in range(len(ref))])
    total_events_weighted = np.array([total_len_filtered_lists(position_counter[pos], args.rare_cutoff, away_rate)
                                      if pos in position_counter else 0 for pos in range(len(ref))])
    total_counts = np.array([sum([sum(x) for x in position_counter[pos].values()])
                             if pos in position_counter else 0 for pos in range(len(ref))])

    # average mutation rate used to scale rates
    total_events_rescaled = np.array([c/away_rate[nuc] for c, nuc in zip(total_events, ref.seq)])
    total_events_rescaled /= np.median(total_events_rescaled)
    total_events_weighted /= np.median(total_events_weighted)

    with open(args.output_events, 'w') as fh:
        fh.write('\t'.join(['position', 'ref_state', 'lineage_fraction_with_changes', 'gene', 'codon', 'pos_in_codon']
                           + [f"{a}->{d}" for a, d in sorted_transitions]) + '\n')
        for pos, nuc in enumerate(ref_array):
            gene = map_to_gene.get(pos, "")
            data = "\t".join([f"{all_events[pos][i]}" for i in range(len(sorted_transitions))])
            if gene:
                gene_pos = pos - gene_position[gene].start
                codon = gene_pos//3 + 1
                cp = gene_pos%3 + 1
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t{gene}\t{codon}\t{cp}\t' + data + "\n")
            else:
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t\t\t\t' + data + "\n")

    with open(args.output_fitness, 'w') as fh:
        fh.write('\t'.join(['position', 'ref_state', 'lineage_fraction_with_changes', 'gene', 'codon',
                            'pos_in_codon', 'total_count', 'total_events', 'tolerance']) + '\n')
        for pos, nuc in enumerate(ref_array):
            gene = map_to_gene.get(pos, "")
            if gene:
                gene_pos = pos - gene_position[gene].start
                codon = gene_pos//3 + 1
                cp = gene_pos%3 + 1
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t{gene}\t{codon}\t{cp}\t{total_counts[pos]}\t{total_events[pos]}\t{total_events_weighted[pos]:1.3f}\n')
            else:
                fh.write(f'{pos+1}\t{nuc}\t{number_of_pango_muts[pos]:1.3f}\t\t\t\t{total_counts[pos]}\t{total_events[pos]}\t{total_events_weighted[pos]:1.3f}\n')

    with open(args.output_mutations, 'w') as fh:
        fh.write('\t'.join(['mutation', 'raw_counts', 'origin_sites', 'scaled_rate']) + '\n')
        for a in 'ACTG':
            for d in 'ACGT':
                if a==d:
                    continue
                key = (a,d)
                fh.write(f"{a}->{d}\t{events_by_transition[key]}\t{base_content[a]}\t{transition_rate[key]:1.3f}\n")

    # # calculate synonymous and non-synonymous distributions
    # total_events_by_type = {}
    # for mut_counts, label in [(syn, 'syn'), (nonsyn, 'nonsyn')]:
    #     total_events_by_type[label] = {}
    #     for gene in mut_counts:
    #         total_events_by_type[label][gene] = [np.sum([len([x for x in mut_counts[gene][pos][a] if x>cutoff])/away_rate[a]
    #                                                      for a in mut_counts[gene][pos]])
    #                                              for pos in range(gene_length[gene])]

    # def calc_fitness_cost(freqs, transition):
    #     mu = 0.0004/50
    #     mut_rate = transition_rate[transition]*mu
    #     avg_freq = np.sum(freqs)/nlin
    #     return mut_rate/(avg_freq + 1e-3/nlin)

    # fitness_cost = []
    # for pos, nuc in enumerate(ref_array):
    #     total_muts = np.sum([len(v) for v in position_freq[pos].values()])
    #     reference_muts = np.sum([len(v) for (a,d), v in position_freq[pos].items() if a==nuc])
    #     if reference_muts>0.9*total_muts:
    #         fitness_cost.append([np.inf if d==nuc
    #                              else calc_fitness_cost(position_freq[pos][(nuc,d)], (nuc,d))
    #                              for d in 'ACGT'])
    #     else:
    #         fitness_cost.append([np.nan, np.nan, np.nan, np.nan])
```
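This snippet appears to be `scripts/get_genotype_counts.py`; it bins sequences by collection day and records mutation-number trajectories, common mutations, the mutation spectrum, and common genotypes per clade: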
```python
import pandas as pd
import argparse, json
from datetime import datetime
import numpy as np
from root_to_tip import filter_and_transform, get_clade_gts
from collections import defaultdict

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--metadata', nargs='+', type=str, required=True, help="input data")
    parser.add_argument('--clade-gts', type=str, required=True, help="input data")
    parser.add_argument('--clade', type=str, required=True, help="input data")
    parser.add_argument('--sub-clades', type=str, required=True, help="input data")
    parser.add_argument('--min-date', type=float, help="input data")
    parser.add_argument('--max-date', type=float, help="input data")
    parser.add_argument('--max-group', type=int, help="input data")
    parser.add_argument('--query', type=str, help="filters")
    parser.add_argument('--bin-size', type=float, help="input data")
    parser.add_argument('--output-json', type=str, help="rate file")
    args = parser.parse_args()

    clade_gt = get_clade_gts(args.clade_gts, args.sub_clades)
    d = pd.concat([pd.read_csv(x, sep='\t').fillna('') for x in args.metadata])

    filtered_data, _ = filter_and_transform(d, clade_gt, min_date=args.min_date, max_date=args.max_date,
                                            query=args.query, max_group=args.max_group,
                                            QC_threshold=80 if args.clade=='21H' else 30,
                                            completeness=0, swap_root=args.clade.startswith('19B+'))
    filtered_data["day"] = filtered_data.datetime.apply(lambda x: x.toordinal())
    print("clade", args.clade, "done filtering")

    bins = np.arange(np.min(filtered_data.day), np.max(filtered_data.day), args.bin_size)
    all_sequences = np.histogram(filtered_data.day, bins=bins)[0]
    cumulative_sum = np.zeros_like(bins[:-1])
    mutation_number = {}
    nmax = 5
    nmax_extra = 15
    for n in range(nmax):
        ind = filtered_data["divergence"]==n
        mutation_number[n] = np.histogram(filtered_data.loc[ind, "day"], bins=bins)[0]
        cumulative_sum += mutation_number[n]
    mutation_number[f'{nmax}+'] = all_sequences - cumulative_sum
    for n in range(nmax, nmax_extra):
        ind = filtered_data["divergence"]==n
        mutation_number[n] = np.histogram(filtered_data.loc[ind, "day"], bins=bins)[0]
        cumulative_sum += mutation_number[n]
    print("clade", args.clade, "done mutation_number")

    mutation_counts = defaultdict(int)
    cutoff = args.min_date + (args.max_date - args.min_date)/2
    window = cutoff - 14/365, cutoff + 14/365
    ind_early = filtered_data.numdate < cutoff
    ind_window = (filtered_data.numdate >= window[0]) & (filtered_data.numdate < window[1])
    n_early = ind_early.sum()
    n_window = ind_window.sum()

    # Find common mutations
    for muts in filtered_data.loc[ind_early, "intra_substitutions"]:
        for m in muts:
            mutation_counts[m] += 1
    relevant_muts = [x[0] for x in sorted(list(mutation_counts.items()), key=lambda k: k[1])[-10:]]
    mutations = {}
    nmax = 5
    for mut in relevant_muts:
        ind = filtered_data["intra_substitutions"].apply(lambda x: mut in x)
        mutations[mut] = np.histogram(filtered_data.loc[ind, "day"], bins=bins)[0]

    # Get mutation spectrum
    mutation_spectrum = defaultdict(float)
    for muts in filtered_data.loc[ind_window, "intra_substitutions"]:
        for m in muts:
            mutation_spectrum[m] += 1.0/n_window

    relevant_muts = [x[0] for x in sorted(list(mutation_counts.items()), key=lambda k: k[1])[-10:]]
    mutations = {}
    nmax = 5
    for mut in relevant_muts:
        ind = filtered_data["intra_substitutions"].apply(lambda x: mut in x)
        mutations[mut] = np.histogram(filtered_data.loc[ind, "day"], bins=bins)[0]
    print("clade", args.clade, "done mutations")

    intra_geno = filtered_data.loc[ind_early, "intra_substitutions_str"].value_counts()
    genotypes = {}
    gt_count = 0
    for x, i in intra_geno.items():
        nmuts = len(x.split(',')) if x else 0
        ind = filtered_data["intra_substitutions_str"]==x
        if i<10 and nmuts or nmuts>4:
            continue
        if nmuts==0 or (nmuts==1 and ind.sum()>0.005*n_early) or (ind.sum()>0.01*n_early):
            genotypes[x] = np.histogram(filtered_data.loc[ind, "day"], bins=bins)[0]
            gt_count += 1
        if gt_count>8:
            break
    print("clade", args.clade, "done genotypes")

    with open(args.output_json, 'w') as fh:
        json.dump({'bins': [int(x) for x in bins],
                   'all_samples': [int(x) for x in all_sequences],
                   'mutation_number': {k: [int(x) for x in v] for k, v in mutation_number.items()},
                   'mutation_spectrum': {k: float(v) for k, v in mutation_spectrum.items()},
                   'mutations': {k: [int(x) for x in v] for k, v in mutations.items()},
                   'genotypes': {k: [int(x) for x in v] for k, v in genotypes.items()}}, fh)
```
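This snippet appears to be `scripts/get_genotypes_pango.py`; it walks the Nextclade reference tree to reconstruct the nucleotide and amino acid genotype of each pango lineage relative to the root sequence: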
```python
import pandas as pd
from collections import defaultdict
import argparse, json

def genotype_struct():
    return {'nuc': {}, 'aa': defaultdict(dict)}

def assign_genotypes(n):
    gt = n["genotype"]
    if "children" in n:
        for c in n["children"]:
            cgt = genotype_struct()
            cgt["nuc"].update({k: v for k, v in gt["nuc"].items()})
            if "nuc" in c["branch_attrs"]["mutations"]:
                for mut in c["branch_attrs"]["mutations"]["nuc"]:
                    a, pos, d = mut[0], int(mut[1:-1]), mut[-1]
                    cgt["nuc"][pos] = d
            for gene in set(c["branch_attrs"]["mutations"].keys()).union(set(gt["aa"].keys())):
                if gene=='nuc':
                    continue
                cgt["aa"][gene].update({k: v for k, v in gt["aa"][gene].items()})
                if gene in c["branch_attrs"]["mutations"]:
                    for mut in c["branch_attrs"]["mutations"][gene]:
                        a, pos, d = mut[0], int(mut[1:-1]), mut[-1]
                        cgt['aa'][gene][pos] = d
            c["genotype"] = cgt
            assign_genotypes(c)

def get_pango_genotypes(n, pango_gts):
    name = n["name"]
    if "children" in n and len(n["children"]):
        for c in n["children"]:
            get_pango_genotypes(c, pango_gts)
    elif (name in ['A', 'B'] or '.' in name) and not ('/' in name):
        pango_gts[n["name"]] = n["genotype"]

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="get variant genotypes",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--tree', type=str, required=True, help="input data")
    parser.add_argument('--root', type=str, required=True, help="input data")
    parser.add_argument('--output', type=str, required=True, help="output data")
    args = parser.parse_args()

    with open(args.tree) as fh:
        tree = json.load(fh)['tree']
    with open(args.root) as fh:
        root_sequence = json.load(fh)

    tree["genotype"] = genotype_struct()
    assign_genotypes(tree)
    pango_gts = {}
    get_pango_genotypes(tree, pango_gts)

    subs = {}
    for clade in pango_gts:
        tmp_nuc = []
        if 'nuc' in pango_gts[clade]:
            for pos, d in pango_gts[clade]['nuc'].items():
                a = root_sequence['nuc'][pos-1]
                if d not in ['N', '-'] and d!=a:
                    tmp_nuc.append(f"{a}{pos}{d}")
        tmp_aa = []
        for gene in pango_gts[clade]['aa']:
            for pos, d in pango_gts[clade]['aa'][gene].items():
                a = root_sequence[gene][pos-1]
                if d not in ['X', '-'] and d!=a:
                    tmp_aa.append(f"{gene}:{a}{pos}{d}")
        subs[clade] = {'nuc': tmp_nuc, 'aa': tmp_aa}

    with open(args.output, 'w') as fh:
        json.dump(subs, fh)
```
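This snippet appears to be `scripts/plot_af.py`; it plots the within-clade mutation frequency spectrum across clades and fits Poisson models (via `fit_poisson` from `plot_genotype_counts.py`, reproduced further down) to estimate per-clade rates and origin dates: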
```python
import argparse, json
from datetime import datetime
import numpy as np
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False
import matplotlib.pyplot as plt
import pandas as pd
from plot_genotype_counts import fit_poisson

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--counts', nargs='+', type=str, required=True, help="input data")
    parser.add_argument('--output-plot', type=str, help="figure file")
    parser.add_argument('--output-rates', type=str, help="table")
    args = parser.parse_args()

    counts = {}
    for fi, fname in enumerate(args.counts):
        clade = fname.split('/')[-1].split('_')[0]
        with open(fname) as fh:
            counts[clade] = json.load(fh)

    fig, axs = plt.subplots(1, 1, figsize=(8, 6))
    ax = axs
    ls = ['-', '-.', '--']
    for fi, (clade, d) in enumerate(counts.items()):
        ax.plot(sorted(d["mutation_spectrum"].values()),
                np.linspace(1, 0, len(d["mutation_spectrum"])+1)[:-1],
                label=clade, ls=ls[fi//10])
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_xlabel('mutation frequency')
    ax.set_ylabel('fraction above')
    x = np.logspace(-4, -0.3, 21)
    ax.plot([1e-5, 1e-1], [1e-0, 1e-4], c='k', lw=3, alpha=0.5, label='1/x')
    # ax.plot(x, np.log(x)/np.log(x[0]), c='k', lw=3, alpha=0.5, label='~log(x)')
    plt.legend(ncol=3)
    plt.savefig(args.output_plot)

    ls = ['-', '-.', '--']
    data = []
    for fi, (clade, d) in enumerate(counts.items()):
        dates = np.array([x for x in d['bins'][:-1]])
        datetimes = np.array([datetime.fromordinal(x) for x in d['bins'][:-1]])
        t0 = dates[0]
        dates -= t0
        total = np.array(d["all_samples"])
        k_vecs = {i: np.array(d["mutation_number"][f'{i}'])
                  for i in ['0', '1', '2', '3', '4', '5+']}
        ind = total>0
        res = fit_poisson(total[ind], {x: k_vecs[x][ind] for x in k_vecs}, dates[ind])
        start_date = datetime.fromordinal(int(t0 + res['offset'])).strftime("%Y-%m-%d")
        res_fixed_rate = fit_poisson(total[ind], {x: k_vecs[x][ind] for x in k_vecs}, dates[ind], rate=15/365)
        start_date_fixed = datetime.fromordinal(int(t0 + res_fixed_rate['offset'])).strftime("%Y-%m-%d")
        data.append({'clade': clade, 'rate': res['rate']*365,
                     'origin': start_date, 'origin_fixed_rate': start_date_fixed})

    pd.DataFrame(data).to_csv(args.output_rates, sep='\t')
```
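This snippet appears to be `scripts/plot_fitness_landscape.py`; it draws smoothed per-gene tolerance profiles by codon position across the genome: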
```python
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--fitness', type=str, required=True, help="input data")
    parser.add_argument('--output-fitness-landscape', type=str, help="fitness figure")
    args = parser.parse_args()

    fitness = pd.read_csv(args.fitness, sep='\t', index_col=0)
    genes = [x for x in fitness.gene.unique() if not pd.isna(x)]
    pos = np.arange(len(fitness))

    fig = plt.figure(figsize=(15, 12))
    gene_groups = [['ORF3a', 'E', 'M', 'ORF6', 'ORF7a', 'ORF7b', 'ORF8'],
                   ['S', 'N'], ['ORF1b'], ['ORF1a']]
    gene_length = {k: (fitness.gene==k).sum() for k in genes}

    fs = 14
    ws = 20
    w = np.ones(ws)/ws
    ws_fine = 7
    w_fine = np.ones(ws_fine)/ws_fine
    lower_fitness_cutoff = 0.03
    n_rows = len(gene_groups)
    h_spread = 0.003
    h_margin = 0.05
    v_spread = 0.02
    v_margin = 0.05
    height = (1 - v_spread*n_rows - v_margin)/n_rows
    for row, group in enumerate(gene_groups):
        axis_bottom = row*(v_spread + height) + v_margin
        available_width = 1 - len(group)*h_spread - h_margin
        left = h_margin
        total_genome = np.sum([gene_length[g] for g in group])
        for gi, gene in enumerate(group):
            width = gene_length[gene]/total_genome*available_width
            ax = fig.add_axes((left, axis_bottom, width, height))
            left += width + h_spread
            for cp in range(3):
                ind = (fitness.pos_in_codon==(cp+1)) & (fitness.gene==gene)
                gene_pos = (fitness.codon[ind]*3 + cp - 1)/3
                gene_pos_smooth = np.convolve(gene_pos, w, mode='valid')
                ax.plot(gene_pos_smooth,
                        np.convolve(np.log10(fitness.tolerance[ind]+lower_fitness_cutoff), w, mode='valid'),
                        c=f'C{cp}', label=f'codon pos {cp+1}')
                gene_pos_smooth = np.convolve(gene_pos, w_fine, mode='valid')
                ax.plot(gene_pos_smooth,
                        np.convolve(np.log10(fitness.tolerance[ind]+lower_fitness_cutoff), w_fine, mode='valid'),
                        c=f'C{cp}', alpha=0.5)
                ax.plot(gene_pos, np.log10(fitness.lineage_fraction_with_changes[ind]+.001)/3 - 1.2,
                        'o', c=f'C{cp}')
            ax.plot(gene_pos, np.zeros_like(gene_pos), c='k', alpha=0.3, lw=2)
            ax.set_ylim(-2.0, 1)
            ax.set_xlim(0, len(gene_pos))
            if gi==0:
                ax.set_ylabel('log10 scaled tolerance')
                if row==n_rows-1:
                    ax.legend(ncol=3, fontsize=fs)
            else:
                ax.set_yticklabels([])
            ax.text(0.15, 0.75, gene, fontsize=fs*1.1)

    plt.savefig(args.output_fitness_landscape)
```
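This snippet appears to be `scripts/plot_fitness.py`; it plots the mutation-type distribution and cumulative tolerance distributions by codon position and gene: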
```python
import argparse, json
import pandas as pd
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from scipy.stats import scoreatpercentile

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--fitness', type=str, required=True, help="input data")
    parser.add_argument('--mutations', type=str, required=True, help="input data")
    parser.add_argument('--output-fitness', type=str, help="fitness figure")
    parser.add_argument('--output-fitness-by-gene', type=str, help="fitness figure")
    parser.add_argument('--output-mutations', type=str, help="fitness figure")
    args = parser.parse_args()

    mutations = pd.read_csv(args.mutations, sep='\t', index_col=0)
    fitness = pd.read_csv(args.fitness, sep='\t', index_col=0)

    ## figure with mutation distributions
    plt.figure()
    sorted_muts = mutations.sort_values('scaled_rate', ascending=False)
    mut_sum = np.sum(sorted_muts.scaled_rate)
    muts = [(t, v/mut_sum) for t, v in sorted_muts.scaled_rate.items()]
    plt.bar(np.arange(len(muts)), height=[m[1] for m in muts])
    plt.xticks(np.arange(len(muts)), [m[0] for m in muts], rotation=30)
    plt.ylabel('fraction')
    plt.savefig(args.output_mutations)

    ## figure with 1st, 2nd, 3rd positions
    plt.figure()
    for p in range(1, 4):
        ind = fitness.pos_in_codon==p
        plt.plot(sorted(fitness.tolerance[ind]), np.linspace(0, 1, ind.sum()), label=f"codon pos={p}")
    ind = fitness.pos_in_codon.isna()
    plt.plot(sorted(fitness.tolerance[ind]), np.linspace(0, 1, ind.sum()), label="non-coding")

    ind = fitness.pos_in_codon==3
    syn_cutoff = scoreatpercentile(fitness.tolerance[ind], 10)
    plt.plot([syn_cutoff, syn_cutoff], [0, 1], c='k', alpha=0.3)
    for i in range(1, 3):
        ind = fitness.pos_in_codon==i
        print("codon pos", i, np.mean(fitness.tolerance[ind]<syn_cutoff))
    ind = fitness.pos_in_codon.isna()
    print("non coding", np.mean(fitness.tolerance[ind]<syn_cutoff))

    plt.xscale('log')
    plt.ylabel("fraction below")
    plt.xlabel("scaled number of lineages with mutations")
    plt.legend()
    plt.savefig(args.output_fitness)

    # # figure with pdfs instead of cdfs
    # plt.figure()
    # measure = total_events_rescaled
    # bins = np.logspace(0, np.ceil(np.log10(measure.max())), 101)
    # bc = np.sqrt(bins[:-1]*bins[1:])
    # bins[0] = 0
    # rate_estimate = {}
    # for p in range(4):
    #     ind = codon_pos==p
    #     y, x = np.histogram(measure[ind], bins=bins)
    #     rate_estimate[p] = {"mean": np.sum(bc*y/y.sum()),
    #                         "geo-mean": np.exp(np.sum(np.log(bc)*y/y.sum())),
    #                         "median": np.median(measure[ind])}
    #     # plt.plot(bc, y/y.sum(), label='non-coding' if p==0 else f"codon pos={p}")

    # two-panel figure with 1st/2nd and 3rd position mutations
    fig, axs = plt.subplots(2, 1, figsize=(6, 10), sharex=True)
    pos = np.arange(len(fitness))
    genes = [x for x in fitness.gene.unique() if not pd.isna(x)]
    for i, gene in enumerate(genes):
        c = f"C{i}"
        ls = '--' if i>9 else '-'
        ind = (fitness.pos_in_codon==3) & (fitness.gene==gene)
        axs[1].plot(sorted(fitness.tolerance[ind]), np.linspace(0, 1, ind.sum()), ls=ls, c=c, label=f'{gene}')
        ind = (fitness.pos_in_codon>0) & (fitness.pos_in_codon<3) & (fitness.gene==gene)
        axs[0].plot(sorted(fitness.tolerance[ind]), np.linspace(0, 1, ind.sum()), ls=ls, c=c)

    ind = fitness.pos_in_codon.isna()
    axs[1].plot(sorted(fitness.tolerance[ind]), np.linspace(0, 1, ind.sum()), label='non-coding', c='k')
    plt.xscale('log')
    plt.xlim(0.01, 10)
    for ax in axs:
        ax.grid()
    axs[1].legend(ncol=2)
    axs[1].set_title("3rd codon positions or non-coding")
    axs[0].set_title("1st and 2nd codon positions")
    axs[0].set_ylabel("fraction below")
    axs[1].set_ylabel("fraction below")
    axs[1].set_xlabel("scaled number of lineages with mutations")
    plt.savefig(args.output_fitness_by_gene)

    # fitness_cost = np.array(fitness_cost)
    # plt.figure()
    pos = np.arange(len(fitness))
    ws = 10
    w = np.ones(ws)/ws
    for i, gene in enumerate(genes):
        c = f"C{i}"
        ls = '--' if i>9 else '-'
        plt.figure()
        for cp in range(3):
            ind = (fitness.pos_in_codon==(cp+1)) & (fitness.gene==gene)
            #ind = (codon_pos==(cp+1)) & (pos>=gene_range.start) & (pos<gene_range.end) & (~np.isnan(fitness_cost).all(axis=1))
            #plt.hist(np.min(np.log(fitness_cost[ind]), axis=1))
            gene_pos = (fitness.codon[ind]*3 + cp - 1)/3
            gene_pos_smooth = np.convolve(gene_pos, w, mode='valid')
            plt.plot(gene_pos_smooth, np.convolve(np.log10(fitness.tolerance[ind]+.03), w, mode='valid'),
                     c=f'C{cp}', label=f'codon pos {cp+1}')
            plt.plot(gene_pos, np.zeros_like(gene_pos), c='k', alpha=0.3, lw=2)
            plt.plot(gene_pos, np.log10((np.log(fitness.lineage_fraction_with_changes[ind]+.001)+3.01)),
                     'o', c=f'C{cp}')
        plt.ylabel('log10 scaled mutations')
        plt.ylim(-1.5, 1)
        plt.title(gene)
```
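This snippet appears to be `scripts/plot_genotype_counts.py`; it defines the `poisson`/`fit_poisson` helpers imported by `plot_af.py` above and plots genotype, mutation, and mutation-number trajectories for a single clade: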
```python
import argparse, json
from datetime import datetime
import numpy as np
import matplotlib as mpl
mpl.rcParams['axes.formatter.useoffset'] = False
import matplotlib.pyplot as plt
import seaborn as sns
from root_to_tip import add_panel_label

fs = 12

def poisson(lam, order):
    p = np.exp(-lam)
    try:
        for i in range(0, int(order)):
            p *= lam/(i+1)
    except:
        # orders such as '5+' denote "n or more" mutations
        p_cum = np.copy(p)
        for i in range(0, int(order[:-1])-1):
            p *= lam/(i+1)
            p_cum += p
        p = 1 - p_cum
    return np.maximum(0.0, np.minimum(1.0, p))

def fit_poisson(n, k_vecs, t, rate=None):
    from scipy.optimize import minimize
    eps = 1e-16
    def binom(x, n, k_vecs, t, rate):
        res = 0
        mu = rate or x[1]**2
        tu = np.maximum(0, mu*(t - x[0]))
        for order, k in k_vecs.items():
            p = poisson(tu, order)
            res -= np.sum(((n-k)*np.log(1-p+eps) + k*np.log(p+eps)))
        return res

    if rate is None:
        sol = minimize(binom, (-10, 0.25), args=(n, k_vecs, t, rate), method="Powell")
        res = {'offset': sol['x'][0], 'rate': sol['x'][1]**2}
    else:
        sol = minimize(binom, (-10,), args=(n, k_vecs, t, rate), method="Powell")
        res = {'offset': sol['x'][0], 'rate': rate}
    return res

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--counts', type=str, required=True, help="input data")
    parser.add_argument('--output-plot', type=str, help="figure file")
    args = parser.parse_args()

    with open(args.counts) as fh:
        counts = json.load(fh)

    plt.figure()
    nmuts = [m for m in counts['mutation_number'] if m[-1]!='+']
    for week in range(5, len(counts['mutation_number']['0'])):
        c = np.array([counts['mutation_number'][str(n)][week] for n in nmuts])
        if c.sum()>100:
            plt.plot(range(len(nmuts)), c/c.sum(), '-')

    dates = np.array([datetime.fromordinal(x) for x in counts['bins'][:-1]])
    t0 = counts['bins'][0]
    rel_date = np.array([x - t0 for x in counts['bins'][:-1]])

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    ax = axs[2]
    ax.set_title("mutations per genome", fontsize=1.2*fs)
    # ax.plot(dates, counts['all_samples'], lw=3, c='k', alpha=0.3)
    total = np.array(counts['all_samples'])
    ind = total>0
    k_vecs = {m: np.array(counts['mutation_number'][m]) for m in ['0', '1', '2', '3', '4', '5+']}
    res = fit_poisson(total[ind], {x: k_vecs[x][ind] for x in k_vecs}, rel_date[ind])
    tu = (rel_date[ind] - res['offset'])*res['rate']
    for i, (m, k) in enumerate(k_vecs.items()):
        ax.plot(dates[ind], k[ind]/total[ind], 'o', label=f'{m} mutations', c=f"C{i}")
        plt.plot(dates[ind], poisson(tu, m), ls='-', c=f'C{i}')
    ax.set_ylim(8e-5, 2)

    ax = axs[1]
    ax.set_title("mutation frequencies", fontsize=1.2*fs)
    ax.plot(dates, total, lw=3, c='k', alpha=0.3, label='total')
    for m in sorted(counts['mutations'].keys(), key=lambda x: int(x[1:-1])):
        ax.plot(dates, counts['mutations'][m], '-o', label=f'{m}')

    ax = axs[0]
    ax.set_title("genotype frequencies", fontsize=1.2*fs)
    ax.plot(dates, total, lw=3, c='k', alpha=0.3, label='total')
    for m in sorted(counts['genotypes'].keys(), key=lambda x: len(x)):
        ax.plot(dates, counts['genotypes'][m], '-o', label=f'{m}' if m else "founder")

    fig.autofmt_xdate()
    for ax, label in zip(axs, 'DEF'):
        ax.set_yscale('log')
        ax.legend()
        add_panel_label(ax, label, fs=fs*1.8)

    plt.savefig(args.output_plot)
```
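This snippet appears to be `scripts/root_to_tip.py`, the module the other scripts import from; it filters metadata, defines within-clade substitutions, and performs the weighted root-to-tip regressions of divergence against time: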
```python
import pandas as pd
import argparse, json
from treetime.utils import numeric_date, datestring_from_numeric
from datetime import datetime
import numpy as np
from scipy.stats import linregress, scoreatpercentile
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from collections import defaultdict

mpl.rcParams['axes.formatter.useoffset'] = False
fs = 14

date_reference = datetime(2020, 1, 1).toordinal()

def date_to_week_since2020(d):
    return (d.toordinal() - date_reference)//7

def week_since2020_to_date(d):
    return datetime.fromordinal(int(d*7 + date_reference))

def week_since2020_to_numdate(d):
    return numeric_date(week_since2020_to_date(d))

def filter_and_transform(d, clade_gt, min_date=None, max_date=None, query=None,
                         completeness=None, swap_root=False, max_group=None, QC_threshold=30):
    dropped_seqs = {}
    # filter for incomplete dates
    d = d.loc[d.date.apply(lambda x: len(x)==10 and 'X' not in x)]
    if query:
        pre = len(d)
        d = d.query(query)
        dropped_seqs['query'] = pre - len(d)

    d['datetime'] = d.date.apply(lambda x: datetime.strptime(x, '%Y-%m-%d'))
    d['numdate'] = d.datetime.apply(lambda x: numeric_date(x))
    d['CW'] = d.datetime.apply(date_to_week_since2020)

    # filter date range
    if min_date:
        pre = len(d)
        d = d.loc[d.numdate>min_date]
        dropped_seqs['min_date'] = pre - len(d)
    if max_date:
        pre = len(d)
        d = d.loc[d.numdate<max_date]
        dropped_seqs['max_date'] = pre - len(d)

    pre = len(d)
    d = d.loc[d.QC_overall_score<QC_threshold]
    dropped_seqs['QC'] = pre - len(d)

    # look for clade defining substitutions
    d["clade_substitutions"] = d.substitutions.apply(
        lambda x: [y for y in x.split(',') if y in clade_gt['nuc']] if x else [])
    # assign number of substitutions missing in the clade definition
    d["missing_subs"] = d.clade_substitutions.apply(lambda x: len(clade_gt['nuc'])-len(x))
    # define "within-clade substitutions"
    d["intra_substitutions"] = d.substitutions.apply(
        lambda x: [y for y in x.split(',')
                   if all([y not in clade_gt['nuc'], int(y[1:-1])>150, int(y[1:-1])<29753])] if x else [])
    d["intra_aaSubstitutions"] = d.aaSubstitutions.apply(
        lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and 'ORF9' not in y] if x else [])
    d["intra_SpikeSubstitutions"] = d.aaSubstitutions.apply(
        lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[0]=='S'] if x else [])
    d["intra_ORF1aSubstitutions"] = d.aaSubstitutions.apply(
        lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[:5]=='ORF1a'] if x else [])
    d["intra_ORF1bSubstitutions"] = d.aaSubstitutions.apply(
        lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[:5]=='ORF1b'] if x else [])
    d["intra_ENMSubstitutions"] = d.aaSubstitutions.apply(
        lambda x: [y for y in x.split(',') if y not in clade_gt['aa'] and y[0] in ['E','N','M']] if x else [])

    if swap_root:
        muts = [("C8782T","T8782C"), ("T28144C","C28144T")]
        def swap(mutations, pair):
            return [y for y in mutations if y!=pair[0]] if pair[0] in mutations else mutations + [pair[1]]
        for m in muts:
            d["intra_substitutions"] = d.intra_substitutions.apply(lambda x: swap(x, m))
        aa_mut = ("ORF8:L84S","ORF8:S84L")
        d["intra_aaSubstitutions"] = d.intra_aaSubstitutions.apply(lambda x: swap(x, aa_mut))

    # make a hashable string representation
    d["intra_substitutions_str"] = d.intra_substitutions.apply(lambda x: ','.join(x))
    # within clade divergence
    d["divergence"] = d.intra_substitutions.apply(lambda x: len(x))
    d["aaDivergence"] = d.intra_aaSubstitutions.apply(lambda x: len(x))
    d["spikeDivergence"] = d.intra_SpikeSubstitutions.apply(lambda x: len(x))
    d["orf1aDivergence"] = d.intra_ORF1aSubstitutions.apply(lambda x: len(x))
    d["orf1bDivergence"] = d.intra_ORF1bSubstitutions.apply(lambda x: len(x))
    d["enmDivergence"] = d.intra_ENMSubstitutions.apply(lambda x: len(x))
    d["synDivergence"] = d["divergence"] - d["aaDivergence"]

    # filter
    if completeness is not None:
        pre = len(d)
        d = d.loc[d.missing_subs<=completeness]
        dropped_seqs['completeness'] = pre - len(d)

    if max_group:
        return (d.groupby(['CW', 'country']).sample(max_group, replace=True).drop_duplicates(subset='strain'),
                dropped_seqs)
    return d, dropped_seqs

def weighted_regression(x, y, w):
    '''
    Determine slope and intercept by minimizing

        sum_i w_i*(y_i - f(x_i))^2   with   f(x) = slope*x + intercept

    The normal equations are
        sum_i w_i*(y_i - f(x_i)) = 0     =>  sum_i w_i y_i     = intercept*sum_i w_i     + slope*sum_i w_i x_i
        sum_i w_i*(y_i - f(x_i))*x_i = 0 =>  sum_i w_i y_i x_i = intercept*sum_i w_i x_i + slope*sum_i w_i x_i^2
    '''
    wa = np.array(w)
    xa = np.array(x)
    ya = np.array(y)
    wx = np.sum(xa*wa)
    wy = np.sum(ya*wa)
    wxy = np.sum(xa*ya*wa)
    wxx = np.sum(xa**2*wa)
    wsum = np.sum(wa)
    wmean = np.mean(wa)
    slope = (wy*wx - wxy*wsum)/(wx**2 - wxx*wsum)
    intercept = (wy - slope*wx)/wsum
    # not correct
    # hessianinv = np.linalg.inv(np.array([[wxx, wx], [wx, wsum]])/wsum)
    # stderrs = wmean*np.sqrt(hessianinv.diagonal())
    return {"slope": slope, "intercept": intercept}  #, 'slope_err': stderrs[0], 'intercept_err': stderrs[1]}

def regression_by_week(d, field, min_count=5):
    val = d.loc[:, [field, 'CW']].groupby('CW').mean()
    std = d.loc[:, [field, 'CW']].groupby('CW').std()
    count = d.loc[:, [field, 'CW']].groupby('CW').count()
    ind = count[field]>min_count
    #reg = linregress(val.index[ind], val.loc[ind, field])
    # slope = reg.slope * 365 / 7
    # intercept = reg.intercept - 2020*slope
    reg = weighted_regression(val.index[ind], val.loc[ind, field],
                              np.array((count[ind]-min_count)**0.25).squeeze())
    slope = reg["slope"] * 365 / 7
    intercept = reg["intercept"] - 2020*slope
    return {"slope": slope, "intercept": intercept, "origin": -intercept/slope,
            "date": [week_since2020_to_numdate(x) for x in val.index[ind]],
            "mean": [x for x in val.loc[ind, field]],
            "stderr": [x for x in std.loc[ind, field]],
            "count": [x for x in count.loc[ind, field]]}

def make_date_ticks(ax):
    ax.set_xlabel('')
    ax.set_xticklabels([datestring_from_numeric(x) for x in ax.get_xticks()],
                       rotation=30, horizontalalignment='right')

def get_clade_gts(all_gts, subclade_str):
    with open(all_gts) as fh:
        clade_gts = json.load(fh)
    subclades = subclade_str.split(',')
    if len(subclades)>1:
        clade_gt = {'nuc': {}, 'aa': {}}
        clade_gt['nuc'] = set.intersection(*[set(clade_gts[x]['nuc']) for x in subclades])
        clade_gt['aa'] = set.intersection(*[set(clade_gts[x]['aa']) for x in subclades])
    else:
        clade_gt = clade_gts[subclade_str]
    return clade_gt

def add_panel_label(ax, t, fs=16):
    ax.text(-0.1, 1, t, fontsize=fs, transform=ax.transAxes)

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--metadata', type=str, required=True, help="input data")
    parser.add_argument('--clade-gts', type=str, required=True, help="input data")
    parser.add_argument('--clade', type=str, required=True, help="input data")
    parser.add_argument('--sub-clades', type=str, required=True, help="input data")
    parser.add_argument('--qc-cutoff-scale', type=float, default=2, help="number of std dev of rtt filter")
    parser.add_argument('--qc-cutoff-offset', type=float, default=3, help="number extra mutation")
    parser.add_argument('--min-date', type=float, help="input data")
    parser.add_argument('--max-date', type=float, help="input data")
    parser.add_argument('--max-group', type=int, help="input data")
    parser.add_argument('--query', type=str, help="filters")
    parser.add_argument('--output-plot', type=str, help="plot file")
    parser.add_argument('--output-json', type=str, help="rate file")
    args = parser.parse_args()

    clade_gt = get_clade_gts(args.clade_gts, args.sub_clades)
    d = pd.read_csv(args.metadata, sep='\t').fillna('')

    filtered_data, dropped_seqs = filter_and_transform(
        d, clade_gt, min_date=args.min_date, max_date=args.max_date,
        query=args.query, max_group=args.max_group,
        QC_threshold=80 if args.clade=='21H' else 30,
        completeness=3 if args.clade>='22D' else 0,
        swap_root=args.clade.startswith('19B+'))

    regression = linregress(filtered_data.numdate, filtered_data.divergence)
    filtered_data["residuals"] = filtered_data.apply(
        lambda x: x.divergence - (regression.intercept + regression.slope*x.numdate), axis=1)
    # iqd = scoreatpercentile(filtered_data.residuals, 75) - scoreatpercentile(filtered_data.residuals, 25)
    # filtered_data["outlier"] = filtered_data.residuals.apply(lambda x: np.abs(x)>5*iqd)
    tolerance = lambda t: args.qc_cutoff_offset + args.qc_cutoff_scale*np.sqrt(
        np.maximum(0, (regression.intercept + regression.slope*t)))
    filtered_data["outlier"] = filtered_data.apply(lambda x: np.abs(x.residuals)>tolerance(x.numdate), axis=1)
    ind = filtered_data.outlier==False

    # regression_clean = linregress(filtered_data.numdate[ind], filtered_data.divergence[ind])
    # regression_clean_aa = linregress(filtered_data.numdate[ind], filtered_data.aaDivergence[ind])
    regression_clean = regression_by_week(filtered_data.loc[ind], "divergence")
    regression_clean_aa = regression_by_week(filtered_data.loc[ind], "aaDivergence")
    regression_clean_syn = regression_by_week(filtered_data.loc[ind], "synDivergence")
    regression_clean_spike = regression_by_week(filtered_data.loc[ind], "spikeDivergence")
    regression_clean_ORF1a = regression_by_week(filtered_data.loc[ind], "orf1aDivergence")
    regression_clean_ORF1b = regression_by_week(filtered_data.loc[ind], "orf1bDivergence")
    regression_clean_ENM = regression_by_week(filtered_data.loc[ind], "enmDivergence")

    fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)
    ymax = 20
    bins = (20, np.arange(-0.5, ymax+0.5))
    sns.histplot(x=filtered_data.numdate, y=np.minimum(ymax*1.5, filtered_data.divergence),
                 bins=bins, ax=axs[0])
    x = np.linspace(*axs[0].get_xlim(), 101)
    axs[0].set_title('all differences', fontsize=fs*1.2)
    axs[0].plot(x, regression_clean["intercept"] + regression_clean["slope"]*x, lw=4,
                label=f"slope = {regression_clean['slope']:1.1f} subs/year")
    axs[0].errorbar(regression_clean["date"], regression_clean["mean"], regression_clean["stderr"])
    axs[0].plot(x, regression.intercept + regression.slope*x + tolerance(x), lw=4)

    axs[1].set_title('amino acid differences', fontsize=fs*1.2)
    sns.histplot(x=filtered_data.numdate[ind], y=np.minimum(ymax*1.5, filtered_data.aaDivergence[ind]),
                 bins=bins, ax=axs[1])
    axs[1].plot(x, regression_clean_aa["intercept"] + regression_clean_aa["slope"]*x, lw=4,
                label=f"slope = {regression_clean_aa['slope']:1.1f} subs/year")
    axs[1].errorbar(regression_clean_aa["date"], regression_clean_aa["mean"], regression_clean_aa["stderr"])

    axs[2].set_title('synonymous differences', fontsize=fs*1.2)
    sns.histplot(x=filtered_data.numdate[ind], y=np.minimum(ymax*1.5, filtered_data.synDivergence[ind]),
                 bins=bins, ax=axs[2])
    axs[2].plot(x, regression_clean_syn["intercept"] + regression_clean_syn["slope"]*x, lw=4,
                label=f"slope = {regression_clean_syn['slope']:1.1f} subs/year")
    axs[2].errorbar(regression_clean_syn["date"], regression_clean_syn["mean"], regression_clean_syn["stderr"])
    axs[2].text(0.8, 0.9, args.clade, fontsize=fs*1.5, transform=axs[2].transAxes)

    axs[0].set_ylabel("Divergence", fontsize=fs)
    for ax, label in zip(axs, 'ABC'):
        make_date_ticks(ax)
        ax.set_yticks(np.arange(0, ymax, 3))
        ax.legend(loc=2, fontsize=fs)
        ax.set_ylim(-0.5, ymax-0.5)
        add_panel_label(ax, label, fs=fs*1.8)

    if args.output_plot:
        plt.savefig(args.output_plot)
    else:
        plt.show()

    aaCounter = defaultdict(int)
    nucCounter = defaultdict(int)
    for subs in filtered_data.intra_aaSubstitutions:
        for a in subs:
            aaCounter[a] += 1
    for subs in filtered_data.intra_substitutions:
        for a in subs:
            nucCounter[a] += 1
    total = len(filtered_data)
    top_aaSubs = {a: v/total for a, v in sorted(aaCounter.items(), key=lambda x: x[1], reverse=True)[:100]}
    top_nucSubs = {a: v/total for a, v in sorted(nucCounter.items(), key=lambda x: x[1], reverse=True)[:100]}

    rate_data = {'clade': args.clade,
                 'nuc': regression_clean, 'aa': regression_clean_aa, 'syn': regression_clean_syn,
                 'spike': regression_clean_spike, 'orf1a': regression_clean_ORF1a,
                 'orf1b': regression_clean_ORF1b, 'enm': regression_clean_ENM,
                 "top_aaSubs": top_aaSubs, "top_nucSubs": top_nucSubs,
                 "outliers_removed": int(np.sum(filtered_data.outlier)),
                 "qc_filter_fail": dropped_seqs["QC"],
                 "incomplete": dropped_seqs["completeness"],
                 "total_sequences": len(filtered_data)}

    if args.output_json:
        with open(args.output_json, 'w') as fh:
            json.dump(rate_data, fh)
```
```python
import pandas as pd
import numpy as np
import argparse

columns = ['strain', 'virus', 'gisaid_epi_isl', 'genbank_accession', 'date', 'region', 'country',
           'length', 'host', 'Nextstrain_clade', 'pango_lineage', 'Nextclade_pango', 'missing_data',
           'divergence', 'nonACGTN', 'rare_mutations', 'reversion_mutations', 'potential_contaminants',
           'QC_overall_score', 'QC_overall_status', 'frame_shifts', 'deletions', 'insertions',
           'substitutions', 'aaSubstitutions', 'clock_deviation']

if __name__=="__main__":
    parser = argparse.ArgumentParser(
        description="remove time info",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--metadata', type=str, required=True, help="input data")
    parser.add_argument('--variant-labels', nargs='+', type=str, required=True, help="input data")
    parser.add_argument('--variants', nargs='+', type=str, required=True, help="input data")
    args = parser.parse_args()

    files = [f"subsets/{v}.tsv" for v in args.variant_labels]
    d = pd.read_csv(args.metadata, sep='\t', usecols=columns).fillna('')
    d["Nextstrain_clade"] = d.Nextstrain_clade.apply(lambda x: x.split()[0] if x else '')

    for v, name in zip(args.variants, files):
        if len(v.split(','))>1:
            ind = np.any([d.Nextstrain_clade==y for y in v.split(',')], axis=0)
        else:
            ind = d.Nextstrain_clade==v
        subset = d.loc[ind]
        print(name, v, len(subset))
        subset.to_csv(name, sep='\t')
```
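The remaining snippets appear to be `shell` and `run` blocks from the workflow's Snakefile; they download the Nextstrain open data files and chain the scripts above together: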
```
shell:
    """
    curl https://data.nextstrain.org/files/ncov/open/metadata.tsv.gz -o {output}
    """
```
```
shell:
    """
    curl https://data.nextstrain.org/files/ncov/open/nextclade.tsv.gz -o {output}
    """
```
```
shell:
    """
    curl https://data.nextstrain.org/nextclade_sars-cov-2.json | gunzip > {output.tree}
    curl https://data.nextstrain.org/nextclade_sars-cov-2_root-sequence.json | gunzip > {output.root}
    """
```
```
shell:
    """
    python3 scripts/split_by_variant.py --metadata {input.metadata} --variants {params.variants} --variant-labels {params.variant_labels}
    """
```
```
shell:
    """
    python3 scripts/get_genotypes_pango.py --tree {input.tree} --root {input.root} --output {output.gt}
    """
```
```
shell:
    """
    python3 scripts/root_to_tip.py --metadata {input.metadata} --clade {params.clade} --sub-clades {params.clades} \
        --clade-gts data/clade_gts.json \
        --min-date {params.mindate} \
        --max-date {params.maxdate} \
        --qc-cutoff-scale 2 --qc-cutoff-offset 3 \
        {params.filter_query} \
        --output-plot {output.figure} \
        --output-json {output.json}
    """
```
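For reference, a hypothetical expansion of this rule for a single clade could look as follows; the `subsets/21J.tsv` path follows the naming used by `scripts/split_by_variant.py`, while the clade, dates, and output paths are purely illustrative:

```sh
python3 scripts/root_to_tip.py --metadata subsets/21J.tsv --clade 21J --sub-clades 21J \
    --clade-gts data/clade_gts.json \
    --min-date 2021.3 --max-date 2022.0 \
    --qc-cutoff-scale 2 --qc-cutoff-offset 3 \
    --output-plot figures/21J_rtt.png --output-json rates/21J.json
```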
```
shell:
    """
    python3 scripts/get_genotype_counts.py --metadata {input.metadata} --clade {params.clade} --sub-clades {params.clades} \
        --clade-gts data/clade_gts.json \
        --min-date {params.mindate} \
        --max-date {params.maxdate} \
        {params.filter_query} \
        --bin-size {params.bin_size} \
        --output-json {output.json}
    """
```
```
shell:
    """
    python3 scripts/plot_genotype_counts.py --counts {input.json} --output-plot {output.fig}
    """
```
```
shell:
    """
    python3 scripts/clone_growth.py --metadata {input.metadata} --clade {params.clade} --sub-clades {params.clades} \
        --clade-gts data/clade_gts.json \
        --min-date {params.mindate} \
        --output-plot {output.figure}
    """
```
```
shell:
    """
    python3 scripts/plot_af.py --counts {input.count_files} --output-plot {output.af_fig} --output-rates {output.rates}
    """
```
```
run:
    import json
    import pandas as pd
    from treetime.utils import datestring_from_numeric

    with open(input.gt) as fin:
        clade_gts = json.load(fin)

    offset_nuc = lambda clade: -2 if clade=='19B' else 2
    offset_nonsyn = lambda clade: -1 if clade=='19B' else 1

    data = []
    qc_data = []
    for fname in input.rate_files:
        with open(fname) as fin:
            d = json.load(fin)
        base_clade = d['clade'][:3]
        aa_div = len([x for x in clade_gts[base_clade]['aa'] if 'ORF9' not in x]) + offset_nonsyn(base_clade)
        nuc_div = len(clade_gts[base_clade[:3]]['nuc']) + offset_nuc(base_clade)
        data.append({'clade': d['clade'], 'nseqs': d['total_sequences'],
                     'nuc_rate': d['nuc']['slope'], 'nuc_origin': d['nuc']['origin'],
                     'nuc_origin_date': datestring_from_numeric(d['nuc']['origin']),
                     'aa_rate': d['aa']['slope'], 'aa_origin': d['aa']['origin'],
                     'aa_origin_date': datestring_from_numeric(d['aa']['origin']),
                     'syn_rate': d['syn']['slope'], 'syn_origin': d['syn']['origin'],
                     'syn_origin_date': datestring_from_numeric(d['syn']['origin']),
                     'spike_rate': d['spike']['slope'], 'orf1a_rate': d['orf1a']['slope'],
                     'orf1b_rate': d['orf1b']['slope'], 'enm_rate': d['enm']['slope'],
                     'nuc_div': nuc_div, 'aa_div': aa_div, 'syn_div': nuc_div - aa_div})
        qc_data.append({'clade': d['clade'], 'nseqs': d['total_sequences'],
                        "outliers": d["outliers_removed"], "qc_fail": d["qc_filter_fail"]})

    df = pd.DataFrame(data)
    df.to_csv(output.rate_table, sep='\t')
    df.to_latex(output.rate_table_tex, float_format="%.2f", index=False)
    dfqc = pd.DataFrame(qc_data)
    dfqc.to_csv(output.qc_table, sep='\t')
```
```
shell:
    """
    python3 scripts/combine_fits.py --rate-table {input.rate_table}\
        --output-plot {output.figure} \
        --output-plot-rates {output.figure_rates} \
        --output-plot-rates-genes {output.figure_rates_genes}
    """
```
```
shell:
    """
    python3 scripts/count_mutations.py --metadata {input.nextclade}\
        --reference {input.ref} \
        --pango-gts {input.pango_gts} \
        --output-fitness {output.fitness_costs} \
        --output-events {output.all_events} \
        --output-mutations {output.mutation_rates}
    """
```
```
shell:
    """
    python3 scripts/plot_fitness.py --fitness {input.fitness_costs}\
        --mutations {input.mutation_rates} \
        --output-fitness {output.fitness_figure} \
        --output-fitness-by-gene {output.fitness_figure_by_gene} \
        --output-mutations {output.mutation_figure}
    """
```
```
shell:
    """
    python3 scripts/plot_fitness_landscape.py --fitness {input.fitness_costs}\
        --output-fitness-landscape {output.fitness_landscape}
    """
```