Estimating flu clade and mutation frequencies


This README is for the data analysis pipeline. For the web interface, see web/README.md.

Development

Setup

Using Nextstrain CLI

# Linux
curl -fsSL --proto '=https' https://nextstrain.org/cli/installer/linux | bash
# Mac
curl -fsSL --proto '=https' https://nextstrain.org/cli/installer/mac | bash

You can set it up to use Docker or a Nextstrain-managed conda environment (completely independent of any other conda environments you may have).

# Managed conda
nextstrain setup --set-default conda
# Docker
nextstrain setup --set-default docker

Run the analysis:

nextstrain build . --profile profiles/flu

Using custom conda or Python environment

You will need at least the following packages/binaries installed:

  • Python

    • snakemake

    • augur

    • polars

  • nextclade

Then run the workflow with:

snakemake --profile profiles/flu

Viewing results in the web app

Copy the Snakemake workflow results to data_web/inputs, making sure the expected filenames are used, e.g.:

cp results/h3n2/region-country-frequencies.csv data_web/inputs/flu-h3n2.csv

Then convert the CSV files to JSON:

python scripts/web_convert.py --input-pathogens-json data_web/inputs/pathogens.json --output-dir web/public/data

TODO

  • Provide a mamba environment file for simpler setup

  • Agree on formatters to use (snakefmt and black?)

Code Snippets

scripts/fit_hierarchical_frequencies.py, lines 11–226 — hierarchical frequency fit with a shared regional trend and per-country adjustments:

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from fit_single_frequencies import load_and_aggregate, zero_one_clamp


def geo_label_map(x):
    if x=='China': return 'China(PRC)'
    return x

def fit_hierarchical_frequencies(totals, counts, time_bins, stiffness=0.5, stiffness_minor=0.1,
                                 mu=0.3, use_inverse_for_confidence=True):

    # Create copy but with "other" counts set to zero
    # Just for purpose of fitting, as if there was no data for "other"
    counts = counts.copy()
    counts['other'] = {}
    totals = totals.copy()
    totals['other'] = {}

    minor_cats = list(totals.keys())
    n_tp = len(time_bins)
    pc=3
    values, column, row = [], [], []
    b = []
    sq_confidence = []
    # deal with major frequency parameters
    for ti, t in enumerate(time_bins):
        if ti==0:
            diag = stiffness
            values.append(-stiffness); row.append(ti); column.append(ti+1)
        elif ti==n_tp-1:
            diag = stiffness
            values.append(-stiffness); row.append(ti); column.append(ti-1)
        else:
            diag = 2*stiffness
            values.append(-stiffness); row.append(ti); column.append(ti+1)
            values.append(-stiffness); row.append(ti); column.append(ti-1)

        res = 0
        b_res = 0
        total_n = 0
        total_k = 0
        for ci, cat in enumerate(minor_cats):
            k = counts.get(cat, {}).get(t, 0)
            n = totals.get(cat, {}).get(t, 0)
            pre_fac = n**2/(k + pc)/(n - k + pc)
            values.append(n*pre_fac); row.append(ti); column.append(ti + (ci+1)*n_tp)
            total_n += n
            total_k += k
            res += n*pre_fac
            b_res += k*pre_fac

        pre_fac = total_n**2/(total_k + pc)/(total_n - total_k + pc)
        extra_major = 0.2
        values.append(diag + res + extra_major*total_n*pre_fac); row.append(ti); column.append(ti)
        sq_confidence.append((total_k + pc)*(total_n - total_k + pc)/(total_n**3+pc))
        b.append(b_res + extra_major*total_k*pre_fac)


    # deal with frequency adjustment parameters
    for ci, cat in enumerate(minor_cats):
        for ti, t in enumerate(time_bins):
            row_index = ti + (ci+1)*n_tp
            if ti==0:
                diag = stiffness_minor
                values.append(-stiffness_minor); row.append(row_index); column.append(row_index+1)
            elif ti==n_tp-1:
                diag = stiffness_minor
                values.append(-stiffness_minor); row.append(row_index); column.append(row_index-1)
            else:
                diag = 2*stiffness_minor
                values.append(-stiffness_minor); row.append(row_index); column.append(row_index+1)
                values.append(-stiffness_minor); row.append(row_index); column.append(row_index-1)

            k = counts.get(cat, {}).get(t, 0)
            n = totals.get(cat, {}).get(t, 0)
            pre_fac = n**2/(k + pc)/(n - k + pc)
            diag += n*pre_fac
            values.append(diag + mu); row.append(row_index); column.append(row_index)
            values.append(n*pre_fac); row.append(row_index); column.append(ti)
            sq_confidence.append(1.0/((n**3+pc)/(k+pc)/(n-k+pc) + mu))
            b.append(k*pre_fac)

    from scipy.sparse import csr_matrix
    from scipy.sparse.linalg import spsolve
    A = csr_matrix((values, (row, column)), shape=(len(b), len(b)))
    sol = spsolve(A,b)

    if use_inverse_for_confidence:
        try:
            window = len(time_bins)
            matrix_conf_intervals = []
            for wi in range(len(b)//window):
                matrix_conf_intervals.extend(np.diag(np.linalg.inv(A[wi*window:(wi+1)*window,wi*window:(wi+1)*window].todense())))
            conf_to_use = matrix_conf_intervals
        except:
            print("Confidence through matrix inversion didn't work, using diagonal approximation instead.")
            conf_to_use = sq_confidence
    else:
        conf_to_use = sq_confidence

    freqs = {"time_points": time_bins}
    freqs["major_frequencies"] = {t:{"val": zero_one_clamp(sol[ti]),
                                     "upper":zero_one_clamp(sol[ti]+np.sqrt(conf_to_use[ti])),
                                     "lower":zero_one_clamp(sol[ti]-np.sqrt(conf_to_use[ti]))}
                                     for ti,t in enumerate(time_bins)}
    for ci, cat in enumerate(minor_cats):
        freqs[cat] = {}
        for ti,t in enumerate(time_bins):
            row_index = ti + (ci+1)*n_tp
            val = zero_one_clamp(sol[ti] + sol[row_index])
            dev = np.sqrt(conf_to_use[ti] + conf_to_use[row_index])
            freqs[cat][t] = {"val": val, "upper": zero_one_clamp(val+dev), "lower": zero_one_clamp(val-dev)}

    return freqs

if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata", type=str, help="filename with metadata")
    parser.add_argument("--frequency-category", type=str, help="field to use for frequency categories")
    parser.add_argument("--geo-categories", nargs='+', type=str, help="field to use for geographic categories")
    parser.add_argument("--days", default=7, type=int, help="number of days in one time bin")
    parser.add_argument("--min-date", type=str, help="date to start frequency calculation")
    parser.add_argument("--output-csv", type=str, help="output csv file")
    parser.add_argument("--inclusive-clades", type=str, help="whether or not to generate inclusive clade/lineage categories")

    args = parser.parse_args()

    d = pl.read_csv(args.metadata, separator='\t', try_parse_dates=False, columns=args.geo_categories + [args.frequency_category, 'date'])
    freq_cat = args.frequency_category

    d = d.with_columns(pl.col("date").str.strptime(pl.Date, format="%Y-%m-%d", strict=False))

    data, totals, counts, time_bins = load_and_aggregate(d, args.geo_categories, freq_cat,
                                                         bin_size=args.days, min_date=args.min_date, inclusive_clades=args.inclusive_clades)

    dates = [time_bins[k] for k in time_bins]

    major_geo_cats = set([tuple(k[:-2]) for k in totals])
    output_data = []
    # major_geo_cats = set([('Europe',)])
    stiffness = 5000/args.days
    for geo_cat in major_geo_cats:
        geo_label = ','.join(geo_cat)
        minor_geo_cats = set([k[-2] for k in totals if k[:-2]==geo_cat])
        sub_totals = {"other": {}}
        data_totals = {}
        for minor_geo_cat  in minor_geo_cats:
            tmp = {k[-1]:v for k,v in totals.items() if k[:-2]==geo_cat and k[-2]==minor_geo_cat}
            data_totals[minor_geo_cat] = sum(tmp.values())
            if data_totals[minor_geo_cat]>20:
                sub_totals[minor_geo_cat] = tmp
            else:
                # Put all minor categories with less than 20 sequences into one special category
                sub_totals["other"] = {k: tmp.get(k, 0) + sub_totals["other"].get(k, 0) for k in set(tmp) | set(sub_totals["other"])}

        sub_counts = {}
        frequencies = {}
        for fcat in counts.keys():
            sub_counts[fcat] = {"other": {}}
            for minor_geo_cat in minor_geo_cats:
                tmp = {k[-1]:v for k,v in counts[fcat].items() if k[:-2]==geo_cat and k[-2]==minor_geo_cat}
                if minor_geo_cat in sub_totals:
                    sub_counts[fcat][minor_geo_cat] = tmp
                else:
                    sub_counts[fcat]["other"] = {k: tmp.get(k, 0) + sub_counts[fcat].get("other",{}).get(k, 0) for k in set(tmp) | set(sub_counts[fcat].get("other",{}))}
            frequencies[fcat] = fit_hierarchical_frequencies(sub_totals, sub_counts[fcat],
                                    sorted(time_bins.keys()), stiffness=stiffness,
                                    stiffness_minor=stiffness, mu=5.0)

            region_counts = {}
            region_totals = {}
            ## append entries for individual countries.
            for minor_geo_cat in sub_totals:
                for k, date in time_bins.items():
                    output_data.append({"date": date.strftime('%Y-%m-%d'), "region": geo_label, "country": geo_label_map(minor_geo_cat),
                                        "variant":fcat, "count": sub_counts[fcat][minor_geo_cat].get(k,0),
                                        "total": sub_totals[minor_geo_cat].get(k,0),
                                        "freqMi":frequencies[fcat][minor_geo_cat][k]['val'],
                                        "freqLo":frequencies[fcat][minor_geo_cat][k]['lower'],
                                        "freqUp":frequencies[fcat][minor_geo_cat][k]['upper']})
                    region_counts[k] = region_counts.get(k, 0) + sub_counts[fcat][minor_geo_cat].get(k,0)
                    region_totals[k] = region_totals.get(k, 0) + sub_totals[minor_geo_cat].get(k,0)

            ## append entries for region frequencies
            for k, date in time_bins.items():
                output_data.append({"date": date.strftime('%Y-%m-%d'), "region": geo_label,
                                    "country": geo_label, "variant":fcat,
                                    "count": region_counts[k], "total": region_totals[k],
                                    "freqMi":frequencies[fcat]["major_frequencies"][k]['val'],
                                    "freqLo":frequencies[fcat]["major_frequencies"][k]['lower'],
                                    "freqUp":frequencies[fcat]["major_frequencies"][k]['upper']})


    df = pl.DataFrame(output_data, schema={'date':str, 'region':str, 'country':str, 'variant':str,
                                           'count':int, 'total':int, 'freqMi':float, 'freqLo':float, 'freqUp':float})

    # region_totals = {(r[0], r[1]): r[2] for r in df.select(['date', 'region', 'count'])
    #                         .groupby(['date', 'region']).sum().iter_rows()}
    # region_counts = {(r[0], r[1], r[2]): r[3] for r in df.select(['date', 'region', 'variant', 'count'])
    #                         .groupby(['date', 'region', 'variant']).sum().iter_rows()}

    # df = df.with_columns([
    #       pl.struct(['date','region', 'country', 'total']).apply(
    #                     lambda x:region_totals.get((x['date'], x['region']),0)
    #                              if x['region']==x['country'] else x['total'])
    #         .alias('total'),
    #       pl.struct(['date','region', 'country', 'variant', 'count']).apply(
    #                     lambda x:region_counts.get((x['date'], x['region'], x['variant']), 0)
    #                               if x['region']==x['country'] else x['count'])
    #         .alias('count')
    # ])

    df.write_csv(args.output_csv, float_precision=4)
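
For orientation, here is a minimal sketch of calling fit_hierarchical_frequencies directly on toy data (all names and numbers below are invented; it only illustrates the expected input structure: per-country dictionaries mapping time-bin index to sequence counts, as built in the __main__ block above). It assumes the scripts directory is on the Python path.

# Minimal sketch with invented numbers; two countries, four time bins.
from fit_hierarchical_frequencies import fit_hierarchical_frequencies

time_bins = [0, 1, 2, 3]                                  # consecutive bin indices
totals = {"DEU": {0: 50, 1: 60, 2: 55, 3: 40},            # all sequences per bin
          "FRA": {0: 30, 1: 20, 2: 35, 3: 25}}
counts = {"DEU": {0: 5, 1: 12, 2: 20, 3: 22},             # sequences of one variant per bin
          "FRA": {0: 2, 1: 5, 2: 10, 3: 15}}

freqs = fit_hierarchical_frequencies(totals, counts, time_bins,
                                     stiffness=350, stiffness_minor=350, mu=5.0)
print(freqs["major_frequencies"][1])                      # regional estimate with confidence band
print(freqs["DEU"][1])                                    # country-level estimate for bin 1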

scripts/fit_single_frequencies.py, lines 9–189 — data loading/aggregation and the single-category frequency fit:

from collections import defaultdict
from datetime import datetime

import numpy as np
import polars as pl


def zero_one_clamp(x):
    if np.isnan(x): return x
    return max(0,min(1,x))

def parse_dates(x):
    try:
        return datetime.strptime(x, "%Y-%m-%d")
    except:
        return None

def to_day_count(x, start_date):
    try:
        return x.toordinal()-start_date
    except:
        #print(x)
        return -1

def day_count_to_date(x, start_date):
    return datetime.fromordinal(start_date + x)

def load_and_aggregate(data, geo_categories, freq_category, min_date="2021-01-01", bin_size=7, inclusive_clades=""):
    if type(data)==str:
        d = pl.read_csv(data, separator='\t', try_parse_dates=True, columns = geo_categories + [freq_category, 'date'])
    else:
        d=data

    start_date = datetime.strptime(min_date, "%Y-%m-%d").toordinal()
    d = d.filter((~pl.col('date').is_null())&(~pl.col(freq_category).is_null()))
    d = d.with_columns([pl.col('date').apply(lambda x: to_day_count(x, start_date)).alias("day_count")])
    d = d.filter(pl.col("day_count")>=0)
    d = d.with_columns([(pl.col('day_count')//bin_size).alias("time_bin")])

    totals = dict()
    for row in d.groupby(by=geo_categories + ["time_bin"]).count().iter_rows():
        totals[row[:-1]] = row[-1]

    fcats = d[freq_category].unique()
    counts = {}
    for fcat in fcats:
        tmp = {}
        for row in d.filter(pl.col(freq_category)==fcat).groupby(by=geo_categories + ["time_bin"]).count().iter_rows():
            tmp[row[:-1]] = row[-1]
        counts[fcat] = tmp

    # For each cat in fcats, add a new category that also includes all children clades
    children = defaultdict(list)
    for fcat in fcats:
        for fcat2 in fcats:
            if fcat2.startswith(fcat):
                children[fcat].append(fcat2)

    if inclusive_clades == "flu":
        for lineage, child_clades in children.items():
            if len(child_clades)<=1: continue
            tmp = {}
            for child in child_clades:
                # add the per-bin counts of each child clade, summing values for shared keys
                tmp = {k: tmp.get(k, 0) + counts[child].get(k, 0) for k in set(tmp) | set(counts[child])}
            counts[lineage + "*"] = tmp

    timebins = {int(x): day_count_to_date(x*bin_size, start_date) for x in sorted(d["time_bin"].unique())}

    return d, totals, counts, timebins

def fit_single_category(totals, counts, time_bins, stiffness=0.3, pc=3, nstd = 2):

    values, column, row = [], [], []
    b = []
    for ti, t in enumerate(time_bins):
        if t==time_bins[0]:
            diag = stiffness
            values.append(-stiffness)
            row.append(ti)
            column.append(ti+1)
        elif t==time_bins[-1]:
            diag = stiffness
            values.append(-stiffness)
            row.append(ti)
            column.append(ti-1)
        else:
            diag = 2*stiffness
            values.append(-stiffness)
            row.append(ti)
            column.append(ti+1)
            values.append(-stiffness)
            row.append(ti)
            column.append(ti-1)

        k = counts.get(t, 0)
        n = totals.get(t, 0)
        try:
            pre_fac = n**2/(k + pc)/(n - k + pc)
        except:
            print(n,k,pc)
            pre_fac = n**2/(k + pc)/(n - k + pc)

        diag += n*pre_fac
        values.append(diag)
        row.append(ti)
        column.append(ti)
        b.append(k*pre_fac)

    from numpy.linalg import inv
    from scipy.sparse import csr_matrix
    from scipy.sparse.linalg import spsolve
    A = csr_matrix((values, (row, column)), shape=(len(b), len(b)))
    sol = spsolve(A,b)
    confidence = np.sqrt(np.diag(inv(A.todense())))

    return {t:{'val': sol[ti],
               'upper': min(1.0, sol[ti] + nstd*confidence[ti]),
               'lower': max(0.0, sol[ti] - nstd*confidence[ti])} for ti,t in enumerate(time_bins)}, A



if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata", type=str, help="filename with metadata")
    parser.add_argument("--frequency-category", type=str, help="field to use for frequency categories")
    parser.add_argument("--geo-categories", nargs='+', type=str, help="field to use for geographic categories")
    parser.add_argument("--days", default=7, type=int, help="number of days in one time bin")
    parser.add_argument("--min-date", type=str, help="date to start frequency calculation")
    parser.add_argument("--output-csv", type=str, help="file for csv output")
    parser.add_argument("--inclusive-clades", type=str, help="whether or not to generate inclusive clade/lineage categories")

    args = parser.parse_args()
    stiffness = 5000/args.days

    if args.frequency_category.startswith('mutation-'):
        d = pl.read_csv(args.metadata, separator='\t', try_parse_dates=False, columns=args.geo_categories + ["aaSubstitutions", 'date'])
        mutation = args.frequency_category.split('-')[-1]
        def extract_mut(muts):
            if type(muts)==str:
                a = [y for y in muts.split(',') if y.startswith(mutation)]
                return a[0] if len(a) else 'WT'
            else:
                return 'WT'
        d = d.with_columns([d["aaSubstitutions",:].apply(extract_mut).alias("mutation")])

        print(d["mutation"].value_counts())
        freq_cat = "mutation"
    else:
        d = pl.read_csv(args.metadata, separator='\t', try_parse_dates=False, columns=args.geo_categories + [args.frequency_category, 'date'])
        freq_cat = args.frequency_category
    d = d.with_columns(pl.col("date").str.strptime(pl.Date, format="%Y-%m-%d", strict=False))
    data, totals, counts, time_bins = load_and_aggregate(d, args.geo_categories, freq_cat,
                                                         bin_size=args.days, min_date=args.min_date, inclusive_clades=args.inclusive_clades)


    dates = [time_bins[k] for k in time_bins]
    geo_cats = set([k[:-1] for k in totals])
    output_data = []
    for geo_cat in geo_cats:
        geo_label = ','.join(geo_cat)
        frequencies = {}
        sub_counts = {}
        sub_totals = {k[-1]:v for k,v in totals.items() if tuple(k[:-1])==geo_cat}
        for fcat in counts.keys():
            sub_counts[fcat] = {k[-1]:v for k,v in counts[fcat].items() if tuple(k[:-1])==geo_cat}
            if sum(sub_counts[fcat].values())>10:
                frequencies[fcat],A = fit_single_category(sub_totals, sub_counts[fcat],
                                        sorted(time_bins.keys()), stiffness=stiffness)
                for k, date in time_bins.items():
                    output_data.append({"date": date.strftime('%Y-%m-%d'), "region": geo_label, "country": None,
                                        "count": sub_counts[fcat].get(k, 0), "total": sub_totals.get(k, 0),
                                        "variant":fcat,
                                        "freqMi":frequencies[fcat][k]['val'], "freqLo":frequencies[fcat][k]['lower'], "freqUp":frequencies[fcat][k]['upper']})

    df = pl.DataFrame(output_data, schema={'date':str, 'region':str, 'country':str, 'variant':str,
                                    'count':int, 'total':int,
                                    'freqMi':float, 'freqLo':float, 'freqUp':float})
    df.write_csv(args.output_csv, float_precision=4)
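
As a quick sanity check, fit_single_category can also be called directly on toy counts for a single geography; the numbers below are invented and only illustrate the expected {time_bin: count} dictionaries.

# Minimal sketch with invented numbers; one variant rising over five time bins.
from fit_single_frequencies import fit_single_category

time_bins = [0, 1, 2, 3, 4]
totals = {0: 40, 1: 35, 2: 50, 3: 45, 4: 30}   # all sequences per bin
counts = {0: 2, 1: 5, 2: 12, 3: 20, 4: 18}     # sequences of the variant per bin

freqs, A = fit_single_category(totals, counts, time_bins, stiffness=350)
for t in time_bins:
    print(t, round(freqs[t]["val"], 3), round(freqs[t]["lower"], 3), round(freqs[t]["upper"], 3))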

scripts/plot_country.py, lines 1–44 — plots clade frequencies for a single country:

import matplotlib.pyplot as plt
import datetime
import polars  as pl
import numpy as np

if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--frequencies", type=str, help="csv")
    parser.add_argument("--region", type=str, help="region to plot")
    parser.add_argument("--country", type=str, help="country to plot")
    parser.add_argument("--max-freq", type=float, help="plot clades above this frequencies")
    parser.add_argument("--output", type=str, help="mask containing `{cat}` to plot")

    args = parser.parse_args()

    d = pl.read_csv(args.frequencies, try_parse_dates=True, dtypes={
            "variant": pl.Categorical, 
            "region": pl.Categorical, 
            "country": pl.Categorical
        })

    region = args.region.replace('_', ' ')
    country = args.country.replace('-', ' ')
    clades = sorted(d['variant'].fill_null('other').unique())
    fig = plt.figure()

    d = d.filter((pl.col('region')==region)&(pl.col('country')==country))
    plt.title(args.country)
    for ci,clade in enumerate(clades):
        subset = d.filter( pl.col('variant')==clade ).sort(by='date')
        dates = list(subset['date'])
        if len(subset)==0 or max(subset['freqMi'])<args.max_freq: continue

        plt.plot(dates, [subset[i,'count']/subset[i,'total'] if subset[i,'total'] else np.nan
                                for i in range(len(dates))], 'o', c=f"C{ci}")
        plt.plot(dates, list(subset['freqMi']), c=f"C{ci}", label=clade)
        plt.fill_between(dates,
                        list(subset["freqLo"]),
                        list(subset["freqUp"]), color=f"C{ci}", alpha=0.2)
        print(clade, max(subset['freqUp']))
    fig.autofmt_xdate()
    plt.legend(loc=2)
    plt.savefig(args.output)

scripts/plot_multi-region.py, lines 1–76 — plots clade frequencies for several regions in one figure:

import matplotlib.pyplot as plt
import polars as pl
import numpy as np
import json

if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--frequencies", type=str, help="json")
    parser.add_argument("--regions", nargs='+', type=str, help="regions to plot")
    parser.add_argument("--max-freq", type=float, help="plot clades above this frequencies")
    parser.add_argument("--auspice-config", help="Auspice config JSON with custom colorings for clades defined in a scale")
    parser.add_argument("--output", type=str, help="mask containing `{cat}` to plot")

    args = parser.parse_args()

    color_map = {}
    if args.auspice_config:
        with open(args.auspice_config, "r", encoding="utf-8") as fh:
            auspice_config = json.load(fh)

        if "colorings" in auspice_config:
            for coloring in auspice_config["colorings"]:
                if coloring["key"] == "clade_membership":
                    if "scale" in coloring:
                        print(f"Using color map defined in {args.auspice_config}")
                        color_map = {
                            clade: color
                            for clade, color in coloring["scale"]
                        }

                    break


    d = pl.read_csv(args.frequencies, try_parse_dates=True)
    clades = sorted(d['variant'].unique())
    all_dates = sorted(d['date'].unique())

    regions = [x.replace('_', ' ') for x in args.regions]

    n_regions = len(regions)
    n_rows = int(np.ceil(n_regions/2))
    fig, axs = plt.subplots(n_rows,2, sharex=True, sharey=True, figsize=(10,3*n_rows))
    clades_to_plot = set()

    for region in regions:
        for clade in clades:
            subset = d.filter((pl.col('region')==region)&(pl.col('variant')==clade)).sort(by='date')
            if len(subset) and max(subset['freqMi'])>args.max_freq:
                clades_to_plot.add(clade)

    clades_to_plot = sorted(clades_to_plot)

    for ri, region in enumerate(regions):
        ax = axs[ri//2, ri%2]
        ax.grid(color='grey', alpha=0.2)
        for ci,clade in enumerate(clades_to_plot):
            clade_color = color_map.get(clade, f"C{ci}")
            subset = d.filter((pl.col('region')==region)&(pl.col('variant')==clade)).sort(by='date')
            dates = subset['date']
            if len(subset):
                ax.plot(dates, [subset[i,'count']/subset[i,'total'] if subset[i,'total'] else np.nan
                                        for i in range(len(dates))], 'o', c=clade_color)
                ax.plot(dates, subset['freqMi'], c=clade_color, label=clade)
                ax.fill_between(dates,
                                subset["freqLo"],
                                subset["freqUp"], color=clade_color, alpha=0.2)
            else:
                ax.plot(all_dates[:2], [0,0], label=clade, c=clade_color)
        ax.plot(all_dates, np.ones(len(all_dates)), c='k', alpha=0.5)
        ax.text(all_dates[len(all_dates)//2], 1.1, region)
        ax.set_ylim(0,1.2)
    fig.autofmt_xdate()
    axs[0,0].legend(loc=3, ncol=2)
    plt.tight_layout()
    plt.savefig(args.output)

scripts/plot_region.py, lines 1–35 — plots clade frequencies for a single region:

import matplotlib.pyplot as plt
import polars as pl
import numpy as np

if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--frequencies", type=str, help="csv")
    parser.add_argument("--region", type=str, help="regions to plot")
    parser.add_argument("--max-freq", type=float, help="plot clades above this frequencies")
    parser.add_argument("--output", type=str, help="mask containing `{cat}` to plot")

    args = parser.parse_args()
    d = pl.read_csv(args.frequencies, try_parse_dates=True, infer_schema_length=1_000_000)
    clades = sorted(d['variant'].unique())
    region = args.region.replace('_', ' ')
    d = d.filter(pl.col('region')==region)

    fig = plt.figure()
    plt.title(region)
    for ci,clade in enumerate(clades):
        subset = d.filter( pl.col('variant')==clade ).sort(by='date')
        dates = list(subset['date'])
        if len(subset)==0 or max(subset['freqMi'])<args.max_freq: continue

        plt.plot(dates, [subset[i,'count']/subset[i,'total'] if subset[i,'total'] else np.nan
                                for i in range(len(dates))], 'o', c=f"C{ci}")
        plt.plot(dates, list(subset['freqMi']), c=f"C{ci}", label=clade)
        plt.fill_between(dates,
                        list(subset["freqLo"]),
                        list(subset["freqUp"]), color=f"C{ci}", alpha=0.2)

    fig.autofmt_xdate()
    plt.legend(loc=2)
    plt.savefig(args.output)

scripts/pop_weighted_aggregates.py, lines 12–282 — population-weighted regional and global frequency aggregates:

from typing import Annotated

import polars as pl
from polars import col as c
import typer
from typer import Option


def read_tsv(path, *args, **kwargs):
    """
    Like polars.read_csv() but with default separator set to tab
    """
    kwargs.setdefault("separator", "\t")
    return pl.read_csv(path, *args, **kwargs)


def country_region_population(
    country_to_population: pl.DataFrame, country_to_region: pl.DataFrame
) -> pl.DataFrame:
    """
    Calculate population of each region
    ### Input
    country_to_population: country, population
    country_to_region: country, region
    ### Output
    columns: country, region, country_population, region_population
    """
    df_out = (
        country_to_population.rename({"population": "country_population"})
        .join(country_to_region, on="country")
        .with_columns(
            region_population=c("country_population").sum().over("region"),
        )
    )
    return df_out


def prepare_data(
    df: pl.DataFrame,
    country_to_population: pl.DataFrame,
    country_to_region: pl.DataFrame,
) -> pl.DataFrame:
    """
    Prepare data for aggregation
    Input df should have the following columns:
        - date
        - region
        - country
        - variant
        - freqMi
        - freqLo
        - freqUp
    Input country_to_population should have the following columns:
        - country
        - population
    Input country_to_region should have the following columns:
        - country
        - region
    Output DataFrame will have the following columns:
        - date
        - region
        - country
        - variant
        - freqMi
        - freqLo
        - freqUp
        - country_population
        - region_population
    """
    # Other country gets difference between region and sum(country).over(region)
    # Each region has represented population (with other counting 0)
    missing_population = (
        country_region_population(country_to_population, country_to_region)
        .filter(c("country").is_in(df.get_column("country")))
        .with_columns(
            missing_population=c("region_population")
            - (c("country_population").sum().over("region")),
        )
        # .with_columns(
        #     missing_population=c("region_population")
        #     - c("represented_population")
        # )
    )
    # Add "other" rows: like df.exclude(["country","country_population"]) with added country="other", country_population=missing_population
    population_with_other = missing_population.vstack(
        missing_population.with_columns(
            country=pl.lit("other"),
            country_population=c("missing_population"),
        ).unique()
    ).select(pl.exclude("missing_population"))

    df_out = df.filter(
        (c("country") != "?") & (c("country") != c("region"))
    ).join(population_with_other, on=["region", "country"], how="left")

    return df_out


def weighted_average(df: pl.DataFrame):
    """
    Calculates population weighted average for a region
    Frequency of special country "other" is used for countries not represented in data
    Weighted error is calculated as weighted average of squared errors
    Input DataFrame should have the following columns:
        - date
        - region
        - country
        - variant
        - freqMi
        - freqLo
        - freqUp
        - country_population
        - region_population
    Output DataFrame will have the following columns:
        - date
        - region
        - variant
        - freqMi
        - freqLo
        - freqUp
    """

    df = (
        df.filter(c("country") != c("region"))
        .with_columns(
            weight=c("country_population") / c("region_population"),
        )
        .with_columns(
            freqMi_pop_product=c("freqMi") * c("weight"),
            freqErr_pop_product=(
                pl.max([c("freqMi") - c("freqLo"), c("freqUp") - c("freqMi")])
                ** 2
            )
            * c("weight"),
        )
        .groupby(["date", "region", "variant"])
        .agg(
            freqMi=c("freqMi_pop_product").sum(),
            freqErr=c("freqErr_pop_product").sum().sqrt(),
            region_population=c("region_population").first(),
        )
        .select(
            ["date", "region", "variant", "freqMi", "region_population"],
            freqLo=pl.max([c("freqMi") - c("freqErr"), 0]),
            freqUp=pl.min([c("freqMi") + c("freqErr"), 1]),
            global_population=(
                df.unique("region").get_column("region_population").sum()
            ),
        )
    )

    global_df = (
        df.with_columns(
            weight=c("region_population") / c("global_population"),
        )
        .with_columns(
            freqMi_pop_product=c("freqMi") * c("weight"),
            freqErr_pop_product=(
                pl.max([c("freqMi") - c("freqLo"), c("freqUp") - c("freqMi")])
                ** 2
            )
            * c("weight"),
        )
        .groupby(["date", "variant"])
        .agg(
            freqMi=c("freqMi_pop_product").sum(),
            freqErr=c("freqErr_pop_product").sum().sqrt(),
        )
        .select(
            ["date", "variant", "freqMi"],
            freqLo=pl.max([c("freqMi") - c("freqErr"), 0]),
            freqUp=pl.min([c("freqMi") + c("freqErr"), 1]),
            region=pl.lit("global"),
        )
    )

    df = pl.concat(
        [
            df.select(pl.exclude(["region_population", "global_population"])),
            global_df,
        ],
        how="diagonal",
    )

    return df


def main(
    _fit_results: Annotated[
        str, Option("--fit-results")
    ] = "results/h3n2/region-country-frequencies.csv",
    _country_to_population: Annotated[
        str, Option("--country-to-population")
    ] = "defaults/iso3_to_pop.tsv",
    _country_to_region: Annotated[
        str, Option("--country-to-region")
    ] = "profiles/flu/iso3_to_region.tsv",
    output_csv: Annotated[
        str, Option()
    ] = "results/h3n2/region-country-frequencies-pop-weighted.csv",
):
    """
    Fit results need to have columns:
    - region (as defined in region_map)
    - country (iso3 or special case "other")
    - date (of bin start)
    - variant
    - count
    - total
    - freqMi
    - freqLo
    - freqUp
    """

    # Filter out unknown regions
    fit_results = pl.read_csv(_fit_results).filter(c("region") != "?")
    country_to_population = read_tsv(_country_to_population).select(
        country=c("iso3"), population=c("population")
    )
    country_to_region = read_tsv(_country_to_region).select(
        country=c("iso3"), region=c("continent")
    )

    # Prepare data
    prepped_data = prepare_data(
        fit_results, country_to_population, country_to_region
    )


    # Calculate weighted average
    weighted = weighted_average(prepped_data).select(
        ["region", c("region").alias("country")], pl.exclude("region")
    )

    pl.Config.set_tbl_cols(12)

    # Add global rows to fit_results with count/total
    global_fit_results = (
        fit_results.filter(c("country") == c("region"))
        .select(
            ["date", "variant", "count", "total"],
        )
        .groupby(["date", "variant"])
        .agg(
            count=c("count").sum(),
            total=c("total").sum(),
            region=pl.lit("global"),
        )
    )

    # Join count and total from original data
    df = weighted.join(
        pl.concat(
            [
                fit_results.filter(c("country") == c("region")).select(
                    ["date", "region", "variant", "count", "total"]
                ),
                global_fit_results,
            ],
            how="diagonal",
        ),
        on=["region", "variant", "date"],
        how="left",
    )

    # Write out the data
    df.write_csv(output_csv, float_precision=5)


if __name__ == "__main__":
    typer.run(main)
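
Outside polars, the weighting in weighted_average reduces to simple per-date, per-variant arithmetic; the sketch below restates that computation with invented numbers for a two-country region.

# Illustrative numbers only: (country, population, freqMi, freqLo, freqUp)
countries = [("DEU", 83_000_000, 0.30, 0.25, 0.35),
             ("FRA", 68_000_000, 0.10, 0.06, 0.14)]
region_population = sum(p for _, p, *_ in countries)

# weight = country_population / region_population
freq_mi = sum(p / region_population * f for _, p, f, _, _ in countries)
# error = sqrt of the population-weighted average of squared half-widths
freq_err = sum(p / region_population * max(f - lo, up - f) ** 2
               for _, p, f, lo, up in countries) ** 0.5

freq_lo = max(freq_mi - freq_err, 0.0)
freq_up = min(freq_mi + freq_err, 1.0)
print(round(freq_mi, 3), round(freq_lo, 3), round(freq_up, 3))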

master/Snakefile, lines 35–38:

shell:
    """
    aws s3 cp {params.s3_path} - | xz -c -d > {output.sequences}
    """

master/Snakefile, lines 54–62:

shell:
    """
    augur parse \
        --sequences {input.sequences} \
        --output-sequences {output.sequences} \
        --output-metadata {output.metadata} \
        --fields {params.fasta_fields} \
        --prettify-fields {params.prettify_fields}
    """

master/Snakefile, lines 72–86:

shell:
    """
    tsv-join -H \
        --filter-file {input.country_to_iso3} \
        --key-fields country \
        --append-fields iso3 \
        --write-all="?" \
        {input.metadata} \
    | tsv-join -H \
        --filter-file {input.iso3_to_region} \
        --key-fields iso3 \
        --append-fields continent \
        --write-all="?" \
        > {output.metadata}
    """

master/Snakefile, lines 93–96:

shell:
    """
    nextclade dataset get -n flu_{wildcards.lineage}_ha --output-dir nextclade/{wildcards.lineage}
    """

master/Snakefile, lines 106–109:

shell:
    """
    nextclade run -j {threads} -D nextclade/{wildcards.lineage} {input.sequences} --quiet --output-tsv {output}
    """

master/Snakefile, lines 120–132:

run:
    import pandas as pd

    # read clade assignments and amino-acid substitutions from the nextclade TSV
    clades = pd.read_csv(input[0], sep="\t", index_col="seqName")[params.col]
    aaSubstitutions = pd.read_csv(input[0], sep="\t", index_col="seqName")[
        "aaSubstitutions"
    ]

    # attach both columns to the sequence metadata (aligned on the strain/seqName index)
    metadata = pd.read_csv(input[1], sep="\t", index_col="strain")
    metadata["clade"] = clades
    metadata["aaSubstitutions"] = aaSubstitutions

    metadata.to_csv(output[0], sep="\t", index=False)

master/Snakefile, lines 142–152:

shell:
    """
    python scripts/fit_single_frequencies.py \
        --metadata {input} \
        --geo-categories continent \
        --frequency-category clade \
        --min-date {params.min_date} \
        --days 14 \
        --inclusive-clades flu \
        --output-csv {output.output_csv}
    """

master/Snakefile, lines 162–171:

shell:
    """
    python scripts/fit_single_frequencies.py \
        --metadata {input} \
        --geo-categories continent \
        --frequency-category mutation-{wildcards.mutation} \
        --min-date {params.min_date} \
        --days 14 \
        --output-csv {output.output_csv}
    """

master/Snakefile, lines 181–191:

shell:
    """
    python scripts/fit_hierarchical_frequencies.py \
        --metadata {input} \
        --geo-categories continent iso3 \
        --frequency-category clade \
        --min-date {params.min_date} \
        --days 14 \
        --inclusive-clades flu \
        --output-csv {output.output_csv}
    """

master/Snakefile, lines 201–208:

shell:
    """
    python scripts/pop_weighted_aggregates.py \
        --fit-results {input.fit_results} \
        --country-to-population {input.iso3_to_pop} \
        --country-to-region {input.iso3_to_region} \
        --output-csv {output.output_csv}
    """

master/Snakefile, lines 218–225:

shell:
    """
    python scripts/plot_region.py \
        --frequencies {input.freqs} \
        --region {wildcards.region:q} \
        --max-freq {params.max_freq} \
        --output {output.plot}
    """

master/Snakefile, lines 235–242:

shell:
    """
    python scripts/plot_region.py \
        --frequencies {input.freqs} \
        --region {wildcards.region:q} \
        --max-freq {params.max_freq} \
        --output {output.plot}
    """

master/Snakefile, lines 252–259:

shell:
    """
    python scripts/plot_region.py \
        --frequencies {input.freqs} \
        --region {wildcards.region:q} \
        --max-freq {params.max_freq} \
        --output {output.plot}
    """

master/Snakefile, lines 269–277:

shell:
    """
    python scripts/plot_country.py \
        --frequencies {input.freqs} \
        --region {wildcards.region:q} \
        --country {wildcards.country:q} \
        --max-freq {params.max_freq} \
        --output {output.plot}
    """

master/Snakefile, lines 297–302:

shell:
    """
    python3 scripts/plot_multi-region.py --frequencies {input.freqs}  \
            --regions {params.regions}  --max-freq {params.max_freq} \
            --output {output.plot}
    """

master/Snakefile, lines 322–325:

shell:
    """
    python3 scripts/plot_multi-region.py --frequencies {input.freqs} --regions {params.regions}  --max-freq {params.max_freq} --output {output.plot}
    """

master/Snakefile, lines 332–333:

shell:
    "rm -rf data/ results/ plots/"

(17 more snippets are not shown here.)

Web app: https://flu-frequencies.vercel.app