Integrating genotypes and phenotypes improves long-term forecasts of seasonal influenza A/H3N2 evolution


John Huddleston1,2, John R. Barnes3, Thomas Rowe3, Xiyan Xu3, Rebecca Kondor3, David E. Wentworth3, Lynne Whittaker4, Burcu Ermetal4, Rodney S. Daniels4, John W. McCauley4, Seiichiro Fujisaki5, Kazuya Nakamura5, Noriko Kishida5, Shinji Watanabe5, Hideki Hasegawa5, Ian Barr6, Kanta Subbarao6, Pierre Barrat-Charlaix7,8, Richard A. Neher7,8 & Trevor Bedford1

1Vaccine and Infectious Disease Division, Fred Hutchinson Cancer Research Center, Seattle, WA, USA, 2Molecular and Cell Biology, University of Washington, Seattle, WA, USA, 3Virology Surveillance and Diagnosis Branch, Influenza Division, National Center for Immunization and Respiratory Diseases (NCIRD), Centers for Disease Control and Prevention (CDC), 1600 Clifton Road, Atlanta, GA 30333, USA, 4WHO Collaborating Centre for Reference and Research on Influenza, Crick Worldwide Influenza Centre, The Francis Crick Institute, London, UK., 5Influenza Virus Research Center, National Institute of Infectious Diseases, Tokyo, Japan, 6The WHO Collaborating Centre for Reference and Research on Influenza, The Peter Doherty Institute for Infection and Immunity, Melbourne, VIC, Australia; Department of Microbiology and Immunology, The University of Melbourne, The Peter Doherty Institute for Infection and Immunity, Melbourne, VIC, Australia., 7Biozentrum, University of Basel, Basel, Switzerland, 8Swiss Institute of Bioinformatics, Basel, Switzerland

DOI: https://doi.org/10.7554/eLife.60067

Contents

  1. Abstract

  2. Installation

  3. Quickstart

  4. Configuration

  5. Workflow structure

  6. Full analysis

Abstract

Seasonal influenza virus A/H3N2 is a major cause of death globally. Vaccination remains the most effective preventative. Rapid mutation of hemagglutinin allows viruses to escape adaptive immunity. This antigenic drift necessitates regular vaccine updates. Effective vaccine strains need to represent H3N2 populations circulating one year after strain selection. Experts select strains based on experimental measurements of antigenic drift and predictions made by models from hemagglutinin sequences. We developed a novel influenza forecasting framework that integrates phenotypic measures of antigenic drift and functional constraint with previously published sequence-only fitness estimates. Forecasts informed by phenotypic measures of antigenic drift consistently outperformed previous sequence-only estimates, while sequence-only estimates of functional constraint surpassed more comprehensive experimentally-informed estimates. Importantly, the best models integrated estimates of both functional constraint and either antigenic drift phenotypes or recent population growth.

Installation

Install miniconda. Clone the forecasting repository.

git clone https://github.com/blab/flu-forecasting.git
cd flu-forecasting

Create and activate a conda environment for the pipeline.

conda env create -f envs/anaconda.python3.yaml
conda activate flu_forecasting

Quickstart

Run the pipeline for sparse simulated data. This will first simulate influenza-like populations and then fit models to those populations. Inspect all steps to be executed by the pipeline with a dryrun.

snakemake --dryrun --use-conda --config active_builds='simulated_sample_1'

Run the pipeline locally with four jobs (or cores) at once.

snakemake --use-conda --config active_builds='simulated_sample_1' -j 4

Always specify a value for -j to limit the number of cores available to the simulator. If no limit is provided, the Java-based simulator will attempt to use all available cores and may cause headaches for you or your cluster's system administrator.

Configuration

Analyses are parameterized by the contents of config/config.json. Models are fit to annotated data frames created for one or more "builds" from one or more "datasets". Datasets and builds are decoupled to allow multiple builds from a single dataset. Builds are split into "simulated" and "natural" categories, where each entry in a category is a dictionary of build settings indexed by build name. The list of active builds is determined by the space-delimited values in the active_builds top-level key of the configuration.
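
The following minimal sketch shows one way to read these settings in Python. Only the active_builds key and the "simulated" and "natural" build categories are named above; the nesting of those categories under a top-level "builds" key is an assumption here and should be checked against config/config.json.

import json

# Load the analysis configuration (path taken from the description above).
with open("config/config.json") as fh:
    config = json.load(fh)

# active_builds holds a space-delimited list of build names to run.
active_builds = config["active_builds"].split()
print("Active builds:", active_builds)

# Assumed layout: build settings grouped by category and indexed by build name.
for category in ("simulated", "natural"):
    for build_name in config.get("builds", {}).get(category, {}):
        status = "active" if build_name in active_builds else "inactive"
        print(f"{category}/{build_name}: {status}")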

Workflow structure

Workflow

The analyses for this paper were produced using a workflow written with Snakemake. The complete graph of the workflow is available as a PDF. This PDF was created with the following Snakemake command.

snakemake --forceall --dag manuscript/flu_forecasting.pdf | dot -Tpdf > full_dag.pdf

Below is a subset of the complete workflow showing how tip attributes are created for a single timepoint (2015-10-01) from the natural populations analysis. This image was created with the following Snakemake command.

snakemake --forceall --dag \
 results/builds/natural/natural_sample_1_with_90_vpm_sliding/timepoints/2015-10-01/tip_attributes.tsv | \
 dot -Tpng > example_dag.png

Example branch of the complete workflow

Inputs

Both simulated and natural population builds depend on the configuration file, config/config.json, described above.

Simulated populations are generated by SANTA-SIM as part of the workflow. SANTA-SIM XML configuration files determine the parameters of the simulations and can be found in the corresponding data directory for a given simulated sample. For example, the densely sampled simulated populations configuration file is data/simulated/simulated_sample_3/influenza_h3n2_ha.xml .

Natural populations are represented by FASTA sequences that are freely available through GISAID. See the instructions on how to download these sequences below. The full analysis for this paper also depends on raw hemagglutination inhibition (HI) and focus-reduction assay (FRA) titer measurements. Although these measurements are not publicly available due to existing data sharing agreements, we provide imputed log2 titer values for each strain produced by the phylogenetic model of Neher et al. 2016. These values are available in the results files named tip_attributes_with_weighted_distances.tsv. For example, the complete set of tip attributes, including imputed titer drops, for the validation period of natural populations is available in results/builds/natural/natural_sample_1_with_90_vpm_sliding/tip_attributes_with_weighted_distances.tsv.
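
As a quick check that these imputed values are present, the tab-delimited tip attributes table can be loaded with pandas, as in the sketch below. The path and the timepoint column are referenced elsewhere in this workflow; the exact names of the titer-related columns are not documented here, so the sketch only searches for them rather than assuming them.

import pandas as pd

# Load the validation-period tip attributes table (tab-delimited).
path = (
    "results/builds/natural/natural_sample_1_with_90_vpm_sliding/"
    "tip_attributes_with_weighted_distances.tsv"
)
tips = pd.read_csv(path, sep="\t", parse_dates=["timepoint"])
print(tips.shape)

# Report any columns that look like imputed titer values.
titer_columns = [column for column in tips.columns if "titer" in column.lower()]
print("Titer-related columns:", titer_columns)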

Outputs

The primary outputs of this workflow are the tables of tip attributes per population that are used to fit models (tip_attributes_with_weighted_distances.tsv), along with the tables of resulting model coefficients (distance_model_coefficients.tsv) and distances to the future (distance_model_errors.tsv). Data for validation figures (e.g., Figures 4 and 7) can be found in validation_figure_clades.tsv and validation_figure_ranks.tsv. Additional outputs include the mapping of individual strains to clades (tips_to_clades.tsv), used to create model validation figures (e.g., comparisons of estimated and observed clade frequency fold changes and absolute forecasting errors). The following outputs are included in this repository and are also created by running the full analysis pipeline.

  • results/

    • distance_model_errors.tsv

    • distance_model_coefficients.tsv

    • validation_figure_clades.tsv

    • validation_figure_ranks.tsv

    • builds/

      • natural/

        • natural_sample_1_with_90_vpm_sliding/

          • tip_attributes_with_weighted_distances.tsv
        • natural_sample_1_with_90_vpm_sliding_test_tree/

          • tip_attributes_with_weighted_distances.tsv
      • simulated/

        • simulated_sample_3/

          • tip_attributes_with_weighted_distances.tsv
        • simulated_sample_3_test_tree/

          • tip_attributes_with_weighted_distances.tsv
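
The sketch below shows one way these collected results might be inspected with pandas. The file names come from the list above, and the validation_timepoint column is used by scripts/annotate_model_tables.py; the "predictors" and "validation_error" column names are assumptions and should be checked against the actual tables.

import pandas as pd

# Load the collected model errors and coefficients (tab-delimited).
errors = pd.read_csv(
    "results/distance_model_errors.tsv",
    sep="\t",
    parse_dates=["validation_timepoint"],
)
coefficients = pd.read_csv("results/distance_model_coefficients.tsv", sep="\t")
print(errors.columns.tolist())
print(coefficients.columns.tolist())

# If the assumed columns exist, summarize mean error per model.
if {"predictors", "validation_error"}.issubset(errors.columns):
    print(errors.groupby("predictors")["validation_error"].mean().sort_values())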

The manuscript and most figures and tables within are also automatically generated by the full analysis workflow. These files can be found in the following paths.

  • manuscript/

    • flu_forecasting.pdf

    • figures/

    • tables/

Full analysis

Inspect sequences for simulated populations

Each SANTA-SIM run and subsequent subsampling of the resulting sequences will produce a different random collection of sequences for the workflow. To ensure reproducibility of results, we have included the specific simulated sequences used for analyses in the manuscript. These sequences and their corresponding metadata are available at the following paths:

  • data/simulated/simulated_sample_3/filtered_sequences.fasta

  • data/simulated/simulated_sample_3/filtered_metadata.tsv

Download sequences for natural populations

All hemagglutinin sequences for natural populations are available through the GISAID database. To get access to the database, register for a free GISAID account. After logging into GISAID, select the "EpiFlu" tab from the navigation bar.

Downloading sequences from GISAID requires manually searching for specific accessions (i.e., sequence identifiers) and downloading the corresponding sequences. The maximum length of the GISAID search field is 1,000 characters, so you cannot search for all 20,000+ sequences at once. To facilitate the download process, we have created batches of accessions no longer than 1,000 characters each in the file data/gisaid_batches.csv. Each of the 216 batches has its own id and expected number of sequences to help you track your progress. Copy and paste the list of accessions from each batch into the "Search patterns" field of the GISAID search and select the "Search" button. An example search is shown below.
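
To keep track of which batches you have downloaded, a small script like the hedged sketch below can compare each downloaded FASTA file against the expected counts in data/gisaid_batches.csv. The column names "batch" and "expected_sequences" and the download directory are assumptions, so adjust them to match the actual CSV header and your file naming.

import csv
import os

# Read the batch definitions (assumed columns: "batch", "expected_sequences").
with open("data/gisaid_batches.csv") as fh:
    batches = list(csv.DictReader(fh))

for row in batches:
    batch_id = row["batch"]
    expected = int(row["expected_sequences"])
    fasta = f"gisaid_downloads/gisaid_epiflu_sequence_{batch_id}.fasta"
    if os.path.exists(fasta):
        # Count FASTA records by counting header lines.
        with open(fasta) as records:
            downloaded = sum(1 for line in records if line.startswith(">"))
        print(f"Batch {batch_id}: {downloaded}/{expected} sequences downloaded")
    else:
        print(f"Batch {batch_id}: not downloaded yet")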

Example GISAID search for a single batch

From the search results display, select the checkbox in the top-left of the search display (above the checkbox for the first row of results). This will select all matching sequences to be downloaded. Click the "Download" button. An example of these search results is shown below.

Example GISAID search results for a single batch

When the download dialog appears, select the "Sequences (DNA) as FASTA" radio button. Click the checkbox near "HA" to only download hemagglutinin sequences. Delete the contents of the "FASTA Header" text field and paste in the following line instead:

Isolate name | Isolate ID | Collection date | Passage details/history | Submitting lab

Leave all other fields at their default values. The download interface should look like the following screenshot.

Example GISAID download for a single batch

Click the "Download" button and name the resulting FASTA file with the same id as your current batch (e.g., gisaid_downloads/gisaid_epiflu_sequence_001.fasta ). This file naming convention will make tracking your progress easier. After the download completes, click the "Go back" button on the download dialog and then again from the search results display. Copy and paste the next batch of ids into the search field and repeat these steps until you have downloaded all batches.

When you have downloaded sequences for all batches, concatenate them together into a single file.

cat gisaid_epiflu_sequences/gisaid_epiflu_sequence_*.fasta > gisaid_downloads.fasta

Some strain names contain characters that IQ-TREE does not allow and which it will convert to underscores in its output trees. For example, the apostrophe in the name "Cote d'Ivoire" will be replaced with an underscore. To avoid mismatches between strain names caused by this IQ-TREE replacement, we replace those characters in the initial FASTA file at the beginning of the analysis using seqkit's replace command.

# Install seqkit. Optionally, use "mamba install" instead of "conda install".
conda install -c conda-forge -c bioconda seqkit
# Replace apostrophes with underscores in the FASTA record names.
seqkit replace -p "(')" -r "_" gisaid_downloads.fasta > gisaid_downloads.renamed.fasta

Use augur to parse out the metadata and sequences into separate files. Store these files in a directory with the same name as the natural samples in this analysis.

# Write out sequences and metadata for the validation sample.
mkdir -p data/natural/natural_sample_1_with_90_vpm
augur parse \
 --sequences gisaid_downloads.renamed.fasta \
 --output-sequences data/natural/natural_sample_1_with_90_vpm/filtered_sequences.fasta \
 --output-metadata data/natural/natural_sample_1_with_90_vpm/strains_metadata.tsv \
 --fields strain accession collection_date passage_category submitting_lab
# Copy the resulting sequences and metadata into the test sample directory.
mkdir -p data/natural/natural_sample_1_with_90_vpm_test_tree
cp data/natural/natural_sample_1_with_90_vpm/*.{fasta,tsv} data/natural/natural_sample_1_with_90_vpm_test_tree/

Now, you should be able to run the pipeline from start to finish. Confirm this is true by running snakemake in dry run mode.

snakemake --dryrun

Inspect derived titer data for natural populations

Due to existing data sharing agreements, we cannot publicly distribute raw titer measurements for hemagglutination inhibition (HI) assays and focus reduction assays (FRAs). As an alternative, we provide the derived titer models produced by the augur titers command using the algorithms described in Neher et al. 2016. For HI assays, we provide two model files per analysis timepoint: the titer "tree model" and "substitution model", named titers-tree-model.json and titers-sub-model.json, respectively. For FRAs, we provide model files for the tree model with names like fra-titers-tree-model.json. Example paths for each of these files are listed below for a single timepoint in the analysis of the most recent A/H3N2 sequences.

  • results/builds/natural/natural_sample_20191001/timepoints/2015-10-01/titers-tree-model.json

  • results/builds/natural/natural_sample_20191001/timepoints/2015-10-01/titers-sub-model.json

  • results/builds/natural/natural_sample_20191001/timepoints/2015-10-01/fra-titers-tree-model.json

These model files contain all information required to fit the HI- and FRA-based forecasting models described in the manuscript.
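
A minimal way to confirm these files are readable is to load one and list its top-level keys, as in the sketch below. The path is taken from the examples above, and no particular JSON schema is assumed beyond what augur titers writes.

import json

# Inspect one of the provided titer tree model files.
path = (
    "results/builds/natural/natural_sample_20191001/"
    "timepoints/2015-10-01/titers-tree-model.json"
)
with open(path) as fh:
    titer_model = json.load(fh)

print("Top-level keys:", sorted(titer_model))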

Run the full analysis

Run the entire pipeline locally with four simultaneous jobs.

snakemake --use-conda -j 4

You can also run just one of the natural builds as follows, to confirm your environment is configured properly.

snakemake --use-conda --config active_builds='natural_sample_1_with_90_vpm_sliding' -j 4

Alternatively, follow the Snakemake documentation to distribute the entire pipeline to your cloud or cluster accounts. The following is an example of how to distribute the pipeline on a SLURM-based cluster using a Snakemake profile.

snakemake --profile profiles/slurm-drmaa

Code Snippets

shell:
    """
    python3 scripts/partition_strains_by_timepoint.py \
        {input.metadata} \
        {wildcards.timepoint} \
        {output} \
        --years-back {params.years_back} \
        {params.reference_strains}
    """
shell:
    """
    python3 scripts/extract_sequences.py \
        --sequences {input.sequences} \
        --samples {input.strains} \
        --output {output}
    """
shell:
    """
    augur align \
        --sequences {input.sequences} \
        --reference-sequence {input.reference} \
        --output {output.alignment} \
        --remove-reference \
        --fill-gaps \
        --nthreads {threads}
    """
shell:
    """
    augur tree \
        --alignment {input.alignment} \
        --output {output.tree} \
        --method iqtree \
        --nthreads {threads} &> {log}
    """
shell:
    """
    augur refine \
        --tree {input.tree} \
        --alignment {input.alignment} \
        --metadata {input.metadata} \
        --output-tree {output.tree} \
        --output-node-data {output.node_data} \
        --timetree \
        --no-covariance \
        {params.clock_rate} \
        {params.clock_std_dev} \
        --coalescent {params.coalescent} \
        --date-confidence \
        --date-inference {params.date_inference} &> {log}
    """
shell:
    """
    augur frequencies \
        --method kde \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --narrow-bandwidth {params.narrow_bandwidth} \
        --wide-bandwidth {params.wide_bandwidth} \
        --proportion-wide {params.proportion_wide} \
        --min-date {params.min_date} \
        --max-date {params.max_date} \
        --pivot-interval {params.pivot_frequency} \
        --output {output}
    """
    shell: """python3 scripts/frequencies.py {input.tree} {input.metadata} {output} \
--narrow-bandwidth {params.narrow_bandwidth} \
--wide-bandwidth {params.wide_bandwidth} \
--proportion-wide {params.proportion_wide} \
--pivot-frequency {params.pivot_frequency} \
--start-date {params.start_date} \
--end-date {wildcards.timepoint} \
--include-internal-nodes &> {log}"""
SnakeMake From line 186 of rules/builds.smk
shell: """augur frequencies \
    --method diffusion \
    --tree {input.tree} \
    --metadata {input.metadata} \
    --output {output} \
    --include-internal-nodes \
    --stiffness {params.stiffness} \
    --inertia {params.inertia} \
    --pivot-interval {params.pivot_frequency} \
    --min-date {params.min_date} \
    --max-date {params.max_date} &> {log}"""
shell:
    """
    python3 scripts/frequencies_to_table.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --method {params.method} \
        --output {output} \
        --annotations timepoint={wildcards.timepoint}
    """
SnakeMake From line 237 of rules/builds.smk
shell:
    """
    python3 scripts/frequencies_to_table.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --method {params.method} \
        --output {output} \
        --annotations timepoint={wildcards.timepoint}
    """
SnakeMake From line 257 of rules/builds.smk
shell:
    """
    augur ancestral \
        --tree {input.tree} \
        --alignment {input.alignment} \
        --output {output.node_data} \
        --inference {params.inference} &> {log}
    """
shell:
    """
    augur translate \
        --tree {input.tree} \
        --ancestral-sequences {input.node_data} \
        --reference-sequence {input.reference} \
        --output {output.node_data} &> {log}
    """
shell:
    """
    augur reconstruct-sequences \
        --tree {input.tree} \
        --mutations {input.node_data} \
        --gene {wildcards.gene} \
        --output {output.aa_alignment} \
        --internal-nodes &> {log}
    """
shell:
    """
    python3 scripts/convert_translations_to_json.py \
        --tree {input.tree} \
        --alignment {input.translations} \
        --gene-names {params.gene_names} \
        --output {output.translations}
    """
SnakeMake From line 340 of rules/builds.smk
shell:
    """
    python3 scripts/nonoverlapping_clades.py \
        --tree {input.tree} \
        --translations {input.translations} \
        --gene-names {params.gene_names} \
        --annotations timepoint={wildcards.timepoint} \
        --output {output.clades} \
        --output-tip-clade-table {output.tip_clade_table} &> {log}
    """
SnakeMake From line 361 of rules/builds.smk
shell:
    """
    python3 scripts/calculate_delta_frequency.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --frequency-method {params.method} \
        --delta-pivots {params.delta_pivots} \
        --output {output.delta_frequency} &> {log}
    """
SnakeMake From line 384 of rules/builds.smk
shell:
    """
    augur traits \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --output {output.node_data} \
        --columns {params.columns} \
        --confidence
    """
shell:
    """
    augur distance \
        --tree {input.tree} \
        --alignment {input.alignments} \
        --gene-names {params.genes} \
        --compare-to {params.comparisons} \
        --attribute-name {params.attribute_names} \
        --map {input.distance_maps} \
        --date-annotations {input.date_annotations} \
        --earliest-date {params.earliest_date} \
        --latest-date {params.latest_date} \
        --output {output}
    """
shell:
    """
    python3 scripts/pairwise_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --alignment {input.alignments} \
        --gene-names {params.genes} \
        --attribute-name {params.attribute_names} \
        --map {input.distance_maps} \
        --date-annotations {input.date_annotations} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """
SnakeMake From line 465 of rules/builds.smk
shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """
SnakeMake From line 505 of rules/builds.smk
shell:
    """
    augur lbi \
        --tree {input.tree} \
        --branch-lengths {input.branch_lengths} \
        --output {output} \
        --attribute-names {params.names} \
        --tau {params.tau} \
        --window {params.window}
    """
shell:
    """
    augur lbi \
        --tree {input.tree} \
        --branch-lengths {input.branch_lengths} \
        --output {output} \
        --attribute-names {params.names} \
        --tau {params.tau} \
        --window {params.window} \
        --no-normalization
    """
shell:
    """
    augur titers sub \
        --titers {input.titers} \
        --alignment {input.alignments} \
        --tree {input.tree} \
        --gene-names {params.genes} \
        --allow-empty-model \
        --output {output.titers_model} &> {log}
    """
shell:
    """
    augur titers tree \
        --titers {input.titers} \
        --tree {input.tree} \
        --allow-empty-model \
        --output {output.titers_model} &> {log}
    """
shell:
    """
    augur titers tree \
        --titers {input.titers} \
        --tree {input.tree} \
        --allow-empty-model \
        --output {output.titers_model} &> {log}
    """
shell:
    """
    python3 scripts/rename_fields_in_fra_titer_models.py \
        --titers-model {input.titers_model} \
        --output {output.titers_model}
    """
SnakeMake From line 640 of rules/builds.smk
shell:
    """
    python3 scripts/titer_model_to_distance_map.py \
        --model {input.model} \
        --output {output}
    """
SnakeMake From line 654 of rules/builds.smk
shell:
    """
    python3 scripts/pairwise_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --alignment {input.alignments} \
        --gene-names {params.genes} \
        --attribute-name {params.attribute_names} \
        --map {input.distance_maps} \
        --date-annotations {input.date_annotations} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """
SnakeMake From line 678 of rules/builds.smk
shell:
    """
    python3 scripts/pairwise_titer_tree_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --model {input.model} \
        --attribute-name {params.attribute_names} \
        --date-annotations {input.date_annotations} \
        --months-back-for-current-samples {params.months_back_for_current_samples} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """
SnakeMake From line 708 of rules/builds.smk
shell:
    """
    python3 scripts/pairwise_titer_tree_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --model {input.model} \
        --model-attribute-name {params.model_attribute_name} \
        --attribute-name {params.attribute_names} \
        --date-annotations {input.date_annotations} \
        --months-back-for-current-samples {params.months_back_for_current_samples} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """
SnakeMake From line 738 of rules/builds.smk
shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """
SnakeMake From line 766 of rules/builds.smk
shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """
SnakeMake From line 792 of rules/builds.smk
shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """
SnakeMake From line 818 of rules/builds.smk
shell:
    """
    python3 scripts/normalize_fitness.py \
        --metadata {input.metadata} \
        --frequencies-table {input.frequencies} \
        --frequency-method {params.preferred_frequency_method} \
        --output {output.fitness}
    """
SnakeMake From line 839 of rules/builds.smk
shell:
    """
    python3 scripts/distance_from_consensus.py \
        --sequences {input.sequences} \
        --frequencies {input.frequencies} \
        --output {output.distances}
    """
SnakeMake From line 856 of rules/builds.smk
shell:
    """
    python3 scripts/node_data_to_table.py \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --jsons {input.node_data} \
        --output {output} \
        {params.excluded_fields_arg} \
        --annotations timepoint={wildcards.timepoint} \
                      lineage={params.lineage} \
                      segment={params.segment}
    """
SnakeMake From line 920 of rules/builds.smk
shell:
    """
    python3 scripts/merge_node_data_and_frequencies.py \
        --node-data {input.node_data} \
        --kde-frequencies {input.kde_frequencies} \
        --diffusion-frequencies {input.diffusion_frequencies} \
        --preferred-frequency-method {params.preferred_frequency_method} \
        --output {output.table}
    """
SnakeMake From line 944 of rules/builds.smk
shell:
    """
    python3 scripts/collect_tables.py \
        --tables {input} \
        --output {output.attributes}
    """
SnakeMake From line 961 of rules/builds.smk
shell:
    """
    python3 scripts/annotate_naive_tip_attribute.py \
        --tip-attributes {input.attributes} \
        --output {output.attributes}
    """
SnakeMake From line 975 of rules/builds.smk
shell:
    """
    python3 scripts/calculate_target_distances.py \
        --tip-attributes {input.attributes} \
        --delta-months {params.delta_months} \
        --sequence-attribute-name {params.sequence_attribute_name} \
        --output {output}
    """
SnakeMake From line 993 of rules/builds.smk
shell:
    """
    python3 src/weighted_distances.py \
        --tip-attributes {input.attributes} \
        --distances {input.distances} \
        --delta-months {params.delta_months} \
        --output {output}
    """
SnakeMake From line 1011 of rules/builds.smk
shell:
    """
    python3 src/fit_model.py \
        --tip-attributes {input.attributes} \
        --training-window {params.training_window} \
        --delta-months {params.delta_months} \
        --predictors {params.predictors} \
        --cost-function {params.cost_function} \
        --l1-lambda {params.l1_lambda} \
        --target distances \
        --distances {input.distances} \
        --errors-by-timepoint {output.errors} \
        --coefficients-by-timepoint {output.coefficients} \
        --include-scores \
        --output {output.model} &> {log}
    """
SnakeMake From line 1038 of rules/builds.smk
shell:
    """
    python3 scripts/extract_minimal_models_by_distances.py \
        --model {input.model} \
        --output {output.model}
    """
SnakeMake From line 1062 of rules/builds.smk
shell:
    """
    python3 scripts/annotate_model_tables.py \
        --tip-attributes {input.attributes} \
        --model {input.model} \
        --errors-by-timepoint {input.errors} \
        --coefficients-by-timepoint {input.coefficients} \
        --annotated-errors-by-timepoint {output.errors} \
        --annotated-coefficients-by-timepoint {output.coefficients} \
        --delta-months {params.delta_months} \
        --annotations type="{wildcards.type}" sample="{wildcards.sample}" error_type="{params.error_type}"
    """
SnakeMake From line 1083 of rules/builds.smk
shell:
    """
    python3 scripts/collect_tables.py \
        --tables {input} \
        --output {output.tip_clade_table}
    """
SnakeMake From line 1103 of rules/builds.smk
shell:
    """
    python3 scripts/select_clades.py \
        --tip-attributes {input.attributes} \
        --tips-to-clades {input.tips_to_clades} \
        --delta-months {params.delta_months} \
        --output {output} &> {log}
    """
SnakeMake From line 1121 of rules/builds.smk
shell:
    """
    python3 src/fit_model.py \
        --tip-attributes {input.attributes} \
        --final-clade-frequencies {input.final_clade_frequencies} \
        --training-window {params.training_window} \
        --delta-months {params.delta_months} \
        --predictors {params.predictors} \
        --cost-function {params.cost_function} \
        --l1-lambda {params.l1_lambda} \
        --pseudocount {params.pseudocount} \
        --target clades \
        --output {output} &> {log}
    """
SnakeMake From line 1147 of rules/builds.smk
shell:
    """
    python3 scripts/plot_tree.py {input} {output} &> {log}
    """
SnakeMake From line 1200 of rules/builds.smk
shell: "gs -dBATCH -dNOPAUSE -q -sDEVICE=pdfwrite -sOutputFile={output} {input}"
SnakeMake From line 1210 of rules/builds.smk
shell:
    """
    python3 scripts/calculate_target_distances.py \
        --tip-attributes {input.attributes} \
        --delta-months {params.delta_months} \
        --output {output}
    """
SnakeMake From line 1221 of rules/builds.smk
shell:
    """
    python3 src/forecast_model.py \
        --tip-attributes {input.attributes} \
        --distances {input.distances} \
        --frequencies {input.frequencies} \
        --model {input.model} \
        --delta-months {params.delta_months} \
        --output-node-data {output.node_data} \
        --output-frequencies {output.frequencies}
    """
SnakeMake From line 1242 of rules/builds.smk
shell:
    """
    augur export \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --node-data {input.node_data} {input.forecasts} \
        --colors {input.colors} \
        --auspice-config {input.auspice_config} \
        --output-tree {output.auspice_tree} \
        --output-meta {output.auspice_metadata} \
        --panels {params.panels} \
        --minify-json
    """
shell:
    """
    python3 src/forecast_model.py \
        --tip-attributes {input.attributes} \
        --distances {input.distances} \
        --model {input.model} \
        --delta-months {params.delta_months} \
        --output-table {output.table}
    """
SnakeMake From line 1294 of rules/builds.smk
shell:
    """
    python3 src/fit_model.py \
        --tip-attributes {input.attributes} \
        --target distances \
        --distances {input.distances} \
        --fixed-model {input.model} \
        --errors-by-timepoint {output.errors} \
        --coefficients-by-timepoint {output.coefficients} \
        --include-scores \
        --output {output.model} &> {log}
    """
SnakeMake From line 1317 of rules/builds.smk
shell:
    """
    python3 scripts/annotate_model_tables.py \
        --tip-attributes {input.attributes} \
        --model {input.model} \
        --errors-by-timepoint {input.errors} \
        --coefficients-by-timepoint {input.coefficients} \
        --annotated-errors-by-timepoint {output.errors} \
        --annotated-coefficients-by-timepoint {output.coefficients} \
        --delta-months {params.delta_months} \
        --annotations type="{wildcards.type}" sample="{params.sample}" error_type="test"
    """
SnakeMake From line 1344 of rules/builds.smk
shell:
    """
    python3 scripts/plot_validation_figure_by_population.py \
        --tip-attributes {input.attributes} \
        --tips-to-clades {input.tips_to_clades} \
        --forecasts {input.forecasts} \
        --model-errors {input.model_errors} \
        --population {wildcards.type} \
        --sample {wildcards.sample} \
        --predictors {wildcards.predictors} \
        --output {output.figure} \
        --output-clades-table {output.clades} \
        --output-ranks-table {output.ranks}
    """
SnakeMake From line 1378 of rules/builds.smk
shell:
    """
    cd data/simulated/{wildcards.sample} && java -jar {SNAKEMAKE_DIR}/dist/santa.jar -seed={params.seed} {SNAKEMAKE_DIR}/{input.simulation_config}
    """
shell:
    """
    augur parse \
        --sequences {input.sequences} \
        --output-sequences {output.sequences} \
        --output-metadata {output.metadata} \
        --fields {params.fasta_fields}
    """
shell:
    """
    python3 scripts/standardize_simulated_sequence_dates.py \
        --metadata {input.metadata} \
        --start-year {params.start_year} \
        --generations-per-year {params.generations_per_year} \
        --output {output.metadata}
    """
shell:
    """
    augur filter \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --min-date {params.min_date} \
        --group-by {params.group_by} \
        --sequences-per-group {params.viruses_per_month} \
        --output {output}
    """
shell:
    """
    python3 scripts/filter_simulated_metadata.py \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --output {output.metadata}
    """
shell:
    """
    python3 {path_to_fauna}/vdb/download.py \
        --database vdb \
        --virus flu \
        --fasta_fields {params.fasta_fields} \
        --resolve_method split_passage \
        --select locus:{params.segment} lineage:seasonal_{params.lineage} \
        --path data/natural/{wildcards.sample} \
        --fstem original_sequences
    """
shell:
    """
    python3 {path_to_fauna}/tdb/download.py \
        --database {params.databases} \
        --virus flu \
        --subtype {params.lineage} \
        --select assay_type:{params.assay} \
        --path data/natural/{wildcards.sample} \
        --fstem complete \
        --ftype json
    """
shell:
    """
    python3 {path_to_fauna}/tdb/download.py \
        --database {params.databases} \
        --virus flu \
        --subtype {params.lineage} \
        --select assay_type:{params.assay} \
        --path data/natural/{wildcards.sample} \
        --fstem complete_fra \
        --ftype json
    """
shell:
    """
    python3 scripts/get_titers_by_passage.py \
        --titers {input.titers} \
        --passage-type {params.passage} \
        --output {output.titers}
    """
shell:
    """
    python3 scripts/get_titers_by_passage.py \
        --titers {input.titers} \
        --passage-type {params.passage} \
        --output {output.titers}
    """
shell:
    """
    augur parse \
        --sequences {input.sequences} \
        --output-sequences {output.sequences} \
        --output-metadata {output.metadata} \
        --fields {params.fasta_fields}
    """
shell:
    """
    augur filter \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --min-length {params.min_length} \
        --exclude {params.exclude} \
        --exclude-where country=? region=? passage=egg \
        --output {output}
    """
shell:
    """
    python3 scripts/filter_strains_with_ambiguous_dates.py \
        --metadata {input.metadata} \
        --output {output.metadata}
    """
shell:
    """
    python3 scripts/select_strains.py \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --segments {params.segment} \
        --include {input.include} \
        --lineage {params.lineage} \
        --time-interval {params.start_date} {params.end_date} \
        --viruses_per_month {params.viruses_per_month} \
        --titers {input.titers} \
        --output {output.strains}
    """
shell:
    """
    python3 scripts/filter_metadata_by_strains.py \
        --metadata {input.metadata} \
        --strains {input.strains} \
        --output {output.metadata}
    """
import argparse
import json
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors and weighted distances to the future")
    parser.add_argument("--model", required=True, help="JSON representing the model fit with training and cross-validation results, beta coefficients for predictors, and summary statistics")
    parser.add_argument("--errors-by-timepoint", help="data frame of cross-validation errors by validation timepoint")
    parser.add_argument("--coefficients-by-timepoint", help="data frame of coefficients by validation timepoint")
    parser.add_argument("--annotated-errors-by-timepoint", help="annotated model errors by timepoint")
    parser.add_argument("--annotated-coefficients-by-timepoint", help="annotated model coefficients by timepoint")
    parser.add_argument("--delta-months", type=int, help="number of months to project clade frequencies into the future")
    parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the output table in the format of 'key=value' pairs")

    args = parser.parse_args()

    # Load tip attributes to calculate within-timepoint diversity.
    tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"])

    # Calculate weighted distance of each timepoint to itself.
    tips["average_distance_to_present"] = tips["weighted_distance_to_present"] * tips["frequency"]
    distances_to_present_per_timepoint = tips.groupby("timepoint")["average_distance_to_present"].sum().reset_index()

    print(distances_to_present_per_timepoint.head())

    # Load the model JSON to get access to projected frequencies for tips.
    with open(args.model, "r") as fh:
        model = json.load(fh)

    # Collect all projected frequencies and weighted distances, to enable
    # calculation of weighted average distances within and between seasons.
    df = pd.concat([
        pd.DataFrame(scores["validation_data"]["y_hat"])
        for scores in model["scores"]
    ])

    # Prepare to calculate weighted distance of each timepoint's projected
    # future to its observed future timepoint.
    df["average_distance_to_future"] = df["weighted_distance_to_future"] * df["projected_frequency"]

    # Sum the scaled weighted distances to get average distances per timepoint.
    distances_per_timepoint = df.groupby("timepoint").aggregate({
        "average_distance_to_future": "sum"
    }).reset_index()

    # Prepare timepoint for joins with model errors.
    distances_per_timepoint["timepoint"] = pd.to_datetime(distances_per_timepoint["timepoint"])

    # Load the original model table output for validation/test errors.
    errors = pd.read_csv(args.errors_by_timepoint, sep="\t", parse_dates=["validation_timepoint"])
    errors["future_timepoint"] = errors["validation_timepoint"] + pd.DateOffset(months=args.delta_months)

    # Annotate information about the present's estimate of the future.
    print("Errors: %s" % str(errors.shape))
    errors = errors.merge(
        distances_per_timepoint.loc[:, ["timepoint", "average_distance_to_future"]].copy(),
        left_on=["validation_timepoint"],
        right_on=["timepoint"]
    ).drop(columns=["timepoint"])
    print("Errors with distance to future: %s" % str(errors.shape))

    # Annotate information about the future's distance to itself for the
    # present timepoints.
    errors = errors.merge(
        distances_to_present_per_timepoint,
        left_on=["future_timepoint"],
        right_on=["timepoint"]
    ).rename(columns={
        "average_distance_to_present": "average_diversity_in_future"
    })
    #drop(columns=["timepoint", "future_timepoint"]).
    print("Errors with diversity in future: %s" % str(errors.shape))

    # Load coefficients to which annotations will be added.
    coefficients = pd.read_csv(args.coefficients_by_timepoint, sep="\t")

    # Add any additional annotations requested by the user in the format of
    # "key=value" pairs where each key becomes a new column with the given
    # value.
    if args.annotations:
        for annotation in args.annotations:
            key, value = annotation.split("=")
            errors[key] = value
            coefficients[key] = value

    # Save annotated tables.
    errors.to_csv(args.annotated_errors_by_timepoint, sep="\t", header=True, index=False)
    coefficients.to_csv(args.annotated_coefficients_by_timepoint, sep="\t", header=True, index=False)
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--tip-attributes", required=True, help="table of tip attributes from one or more timepoints")
    parser.add_argument("--output", required=True, help="table of tip attributes annotated with a 'naive' predictor")

    args = parser.parse_args()

    # Annotate a predictor for a naive model with no growth.
    df = pd.read_csv(args.tip_attributes, sep="\t")
    df["naive"] = 0.0
    df.to_csv(args.output, sep="\t", index=False)
import argparse
from augur.frequency_estimators import TreeKdeFrequencies
from augur.utils import write_json
import Bio.Phylo
from collections import defaultdict
import json
import numpy as np


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Calculate the change in frequency for clades over time",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tree", required=True, help="Newick tree")
    parser.add_argument("--frequencies", required=True, help="frequencies JSON")
    parser.add_argument("--frequency-method", required=True, choices=["kde", "diffusion"], help="method used to estimate frequencies")
    parser.add_argument("--clades", help="JSON of clade annotations for nodes in the given tree")
    parser.add_argument("--delta-pivots", type=int, default=1, help="number of frequency pivots to look back in time for change in frequency calculation")
    parser.add_argument("--output", required=True, help="JSON of delta frequency annotations for nodes in the given tree")

    args = parser.parse_args()

    # Load the tree.
    tree = Bio.Phylo.read(args.tree, "newick")

    # Load frequencies.
    with open(args.frequencies, "r") as fh:
        frequencies_json = json.load(fh)

    if args.frequency_method == "kde":
        kde_frequencies = TreeKdeFrequencies.from_json(frequencies_json)
        frequencies = kde_frequencies.frequencies

        # Load clades.
        with open(args.clades, "r") as fh:
            clades_json = json.load(fh)

        clades_by_node = {
            key: value["clade_membership"]
            for key, value in clades_json["nodes"].items()
        }

        # Calculate the total frequency per clade at the most recent timepoint and
        # requested timepoint in the past using non-zero tip frequencies.
        current_clade_frequencies = defaultdict(float)
        previous_clade_frequencies = defaultdict(float)

        for tip in tree.find_clades(terminal=True):
            # Add tip to current clade frequencies.
            current_clade_frequencies[clades_by_node[tip.name]] += frequencies[tip.name][-1]

            # Add tip to previous clade frequencies.
            previous_clade_frequencies[clades_by_node[tip.name]] += frequencies[tip.name][-(args.delta_pivots + 1)]

        # Determine the total time that elapsed between the current and past timepoint.
        delta_time = kde_frequencies.pivots[-1] - kde_frequencies.pivots[-(args.delta_pivots + 1)]

        # Calculate the change in frequency over time elapsed for each clade.
        delta_frequency_by_clade = {}
        for clade, current_frequency in current_clade_frequencies.items():
            # If the current clade was not observed in the previous timepoint, it
            # will have a zero frequency.
            delta_frequency_by_clade[clade] = (current_frequency - previous_clade_frequencies.get(clade, 0.0)) / delta_time

        # Assign clade delta frequencies to all corresponding tips and internal nodes.
        delta_frequency = {}
        for node in tree.find_clades(terminal=True):
            delta_frequency[node.name] = {
                "delta_frequency": delta_frequency_by_clade.get(clades_by_node[node.name], 0.0)
            }
    else:
        frequencies = frequencies_json

        # Determine the total time that elapsed between the current and past timepoint.
        delta_time = frequencies["pivots"][-1] - frequencies["pivots"][-(args.delta_pivots + 1)]

        delta_frequency = {}
        for node in tree.find_clades(terminal=True):
            delta_frequency[node.name] = {
                "delta_frequency": (frequencies[node.name]["global"][-1] - frequencies[node.name]["global"][-(args.delta_pivots + 1)]) / delta_time
            }

    # Write out the node annotations.
    write_json({"nodes": delta_frequency}, args.output)
import argparse
import numpy as np
import pandas as pd


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Calculate pairwise distances between samples at adjacent timepoints (t and t - delta months)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--tip-attributes", required=True, help="a tab-delimited file describing tip attributes at one or more timepoints")
    parser.add_argument("--delta-months", required=True, nargs="+", type=int, help="number of months between timepoints to be compared")
    parser.add_argument("--output", help="tab-delimited file of pairwise distances between tips in timepoints separate by the given delta time", required=True)
    parser.add_argument("--sequence-attribute-name", default="aa_sequence", help="attribute name of sequences to compare")
    args = parser.parse_args()

    # Load tip attributes.
    tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"])

    # Open output file handle to enable streaming distances to disk instead of
    # storing them in memory.
    output_handle = open(args.output, "w")
    output_handle.write("\t".join(["sample", "other_sample", "distance"]) + "\n")

    # Calculate pairwise distances between all tips within a timepoint and at
    # the next timepoint as defined by the given delta months.
    for timepoint, timepoint_df in tips.groupby("timepoint"):
        current_tips = [
            tuple(values)
            for values in timepoint_df.loc[:, ["strain", args.sequence_attribute_name]].values.tolist()
        ]
        comparison_tips = current_tips

        for delta_month in args.delta_months:
            future_timepoint_df = tips[tips["timepoint"] == (timepoint + pd.DateOffset(months=delta_month))]
            future_tips = [
                tuple(values)
                for values in future_timepoint_df.loc[:, ["strain", args.sequence_attribute_name]].values.tolist()
            ]
            comparison_tips = comparison_tips + future_tips

        comparison_tips = list(set(comparison_tips))
        for current_tip, current_tip_sequence in current_tips:
            current_tip_sequence_array = np.frombuffer(current_tip_sequence.encode(), dtype="S1")

            for future_tip, future_tip_sequence in comparison_tips:
                future_tip_sequence_array = np.frombuffer(future_tip_sequence.encode(), dtype="S1")
                distance = (current_tip_sequence_array != future_tip_sequence_array).sum()
                output_handle.write("\t".join([current_tip, future_tip, str(distance)]) + "\n")

    # Close output.
    output_handle.close()
import argparse
import numpy as np
import pandas as pd
import re
import sys


class BadTransformException(Exception):
    pass


def parse_transform(transform):
    """Parse a transform string into its corresponding new column name, Python function, and original column name.

    Return `None` for the Python function if the requested function is not valid.

    Parameters
    ----------
    transform : str
        transformation definition string (e.g., "log_lbi=log(lbi)")

    Returns
    -------
    str, callable, str
        new column name, transformation function, and original column name

    >>> parse_transform("log_lbi=log(lbi)")
    ('log_lbi', <ufunc 'log'>, 'lbi')
    >>> parse_transform("fake_col=fake(col)")
    Traceback (most recent call last):
        ...
    collect_tables.BadTransformException: the requested function was invalid

    >>> parse_transform("bad_transform")
    Traceback (most recent call last):
        ...
    collect_tables.BadTransformException: the requested transform was malformed

    """
    match = re.match(r"(?P<new_column>\w+)=(?P<function>\w+)\((?P<column>\w+)\)", transform)
    if match is None:
        raise BadTransformException("the requested transform was malformed")

    new_column, function_string, column = match.groups()
    function = getattr(np, function_string, None)
    if function is None:
        raise BadTransformException("the requested function was invalid")

    return new_column, function, column


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Collect two or more data frame tables",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tables", nargs="+", required=True, help="tab-delimited files with the same columns to be collected into a single file")
    parser.add_argument("--transforms", nargs="+", help="a list of new columns to create by transformation of existing columns (e.g., 'log_lbi=log(lbi)')")
    parser.add_argument("--output", required=True, help="tab-delimited output file collecting the given input tables")
    args = parser.parse_args()

    # Concatenate tip attributes across all timepoints.
    df = pd.concat([pd.read_table(table) for table in args.tables], ignore_index=True)

    # Apply transformations.
    if args.transforms:
        for transform in args.transforms:
            try:
                new_column, transform_function, column = parse_transform(transform)
                df[new_column] = transform_function(df[column])
            except BadTransformException as e:
                print(f"Error: Could not apply transformation '{transform}' because {e}", file=sys.stderr)
            except Exception as e:
                print(f"Error: Failed to apply transformation '{transform}' ({e})", file=sys.stderr)

    df.to_csv(args.output, sep="\t", index=False)
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tables", nargs="+", help="tables to concatenate")
    parser.add_argument("--separator", default="\t", help="separator between columns in the given tables")
    parser.add_argument("--output", help="concatenated table")

    args = parser.parse_args()

    # Concatenate tables.
    df = pd.concat([
        pd.read_csv(table_file, sep=args.separator)
        for table_file in args.tables
    ], ignore_index=True, sort=True)

    df.to_csv(args.output, sep=args.separator, header=True, index=False)
import argparse
from augur.reconstruct_sequences import load_alignments
from augur.utils import write_json
import Bio.Phylo


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Convert translation FASTA to a node data JSON",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--tree", required=True, help="Newick file for the tree used to construct the given node data JSONs")
    parser.add_argument("--alignment", nargs="+", help="sequence(s) to be used, supplied as FASTA files", required=True)
    parser.add_argument('--gene-names', nargs="+", type=str, help="names of the sequences in the alignment, same order assumed", required=True)
    parser.add_argument("--output", help="JSON file with translated sequences by node", required=True)
    parser.add_argument("--include-internal-nodes", action="store_true", help="include data associated with internal nodes in the output JSON")
    parser.add_argument("--attribute-name", default="aa_sequence", help="name of attribute to store the complete amino acid sequence of each node")
    args = parser.parse_args()

    # Load tree.
    tree = Bio.Phylo.read(args.tree, "newick")

    # Load sequences.
    alignments = load_alignments(args.alignment, args.gene_names)

    # Concatenate translated sequences into a single sequence indexed by sample name.
    is_node_terminal = {node.name: node.is_terminal() for node in tree.find_clades()}

    translations = {}
    for gene in args.gene_names:
        alignment = alignments[gene]

        for record in alignment:
            if is_node_terminal[record.name] or args.include_internal_nodes:
                # Initialize new samples by name with an empty string.
                if record.name not in translations:
                    translations[record.name] = {args.attribute_name: ""}

                # Append the current gene's amino acid sequence to the current
                # string for this sample.
                translations[record.name][args.attribute_name] += str(record.seq)

    # Write out the node annotations.
    write_json({"nodes": translations}, args.output)
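
The resulting node data JSON has roughly the shape sketched below (strain names and sequences are hypothetical); each included node maps to its concatenated amino acid sequence stored under the requested attribute name.

example_output = {
    "nodes": {
        "A/StrainOne/2016": {"aa_sequence": "MKTIIALSYIFCLALG"},
        "A/StrainTwo/2016": {"aa_sequence": "MKTIIALSHIFCLALG"}
    }
}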
import argparse
from augur.utils import read_node_data, write_json
import Bio.Align
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from collections import Counter
import numpy as np
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Calculate consensus sequence for all non-zero frequency strains and the distance of each strain from the resulting consensus.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--sequences", required=True, help="node data JSON containing sequences to find a consensus for")
    parser.add_argument("--frequencies", required=True, help="table of strain frequencies at the current timepoint")
    parser.add_argument("--sequence-attribute", default="aa_sequence", help="attribute in node data JSON containing the sequence data to use")
    parser.add_argument("--frequency-attribute", default="kde_frequency", help="attribute in frequency table representing the frequency data to use")
    parser.add_argument("--output", required=True, help="node data JSON with consensus sequence and distances from the consensus per strain")

    args = parser.parse_args()

    # Load sequence data from a node data JSON file.
    node_sequences = read_node_data(args.sequences)

    # Load frequency data.
    frequencies = pd.read_csv(args.frequencies, sep="\t")

    # Select names of strains with non-zero frequencies.
    strains = set(frequencies.query(f"{args.frequency_attribute} > 0.0")["strain"].values)

    # Select sequences for strains with non-zero frequencies.
    sequences = Bio.Align.MultipleSeqAlignment([
        SeqRecord(
            Seq(
                record[args.sequence_attribute],
            ),
            id=strain
        )
        for strain, record in node_sequences["nodes"].items()
        if strain in strains
    ])

    # Output will store the consensus sequence and the distance of each strain
    # to the consensus.
    output = {
        "nodes": {}
    }

    # Calculate the consensus sequence using a majority-rule approach where we
    # take the most common value in each column.
    consensus = "".join(
        Counter(sequences[:, i]).most_common(1)[0][0]
        for i in range(sequences.get_alignment_length())
    )
    output["consensus"] = consensus

    # Calculate the distance of each strain sequence from the consensus.
    consensus_array = np.frombuffer(consensus.encode(), dtype="S1")
    for sequence in sequences:
        sequence_array = np.frombuffer(str(sequence.seq).encode(), dtype="S1")
        distance = int((consensus_array != sequence_array).sum())
        output["nodes"][sequence.id] = {
            "distance_from_consensus": distance
        }

    # Output the results.
    write_json(output, args.output)
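
The following is a minimal, self-contained sketch of the majority-rule consensus and Hamming distance calculation used above, with made-up amino acid sequences standing in for the node data JSON.

from collections import Counter

import numpy as np

# Hypothetical amino acid sequences for three extant strains.
sequences = {
    "strain_1": "MKTII",
    "strain_2": "MKTIV",
    "strain_3": "MRTII",
}

# Majority-rule consensus: take the most common residue in each column.
length = len(next(iter(sequences.values())))
consensus = "".join(
    Counter(sequence[i] for sequence in sequences.values()).most_common(1)[0][0]
    for i in range(length)
)

# Hamming distance of each strain from the consensus.
consensus_array = np.frombuffer(consensus.encode(), dtype="S1")
distances = {
    strain: int((np.frombuffer(sequence.encode(), dtype="S1") != consensus_array).sum())
    for strain, sequence in sequences.items()
}
print(consensus, distances)  # "MKTII" and {"strain_1": 0, "strain_2": 1, "strain_3": 1}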
import argparse
import json


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", required=True, help="JSON for a complete fitness model output")
    parser.add_argument("--output", required=True, help="JSON for a minimal fitness model (coefficients only)")

    args = parser.parse_args()

    with open(args.model, "r") as fh:
        model = json.load(fh)

    if "scores" in model:
        del model["scores"]

    with open(args.output, "w") as oh:
        json.dump(model, oh, indent=1)
import argparse
import Bio
import Bio.SeqIO


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract sample sequences by name",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--sequences", required=True, help="FASTA file of all sample sequences")
    parser.add_argument("--samples", required=True, help="text file of sample names with one name per line")
    parser.add_argument("--output", required=True, help="FASTA file of extracted sample sequences")
    args = parser.parse_args()

    with open(args.samples) as infile:
        samples = set([line.strip() for line in infile])

    with open(args.output, 'w') as outfile:
        for seq in Bio.SeqIO.parse(args.sequences, 'fasta'):
            if seq.name in samples:
                Bio.SeqIO.write(seq, outfile, 'fasta')
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--metadata", required=True, help="table of metadata to be filtered to the given list of strains")
    parser.add_argument("--strains", required=True, help="text file with one strain per line that should be included in the output")
    parser.add_argument("--output", required=True, help="table of filtered metadata")

    args = parser.parse_args()

    metadata = pd.read_table(args.metadata)
    strains = pd.read_table(args.strains, header=None, names=["strain"])

    selected_metadata = strains.merge(metadata, how="left", on="strain")
    selected_metadata.to_csv(args.output, sep="\t", index=False)
import argparse
import Bio.SeqIO
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequences", help="simulated sequences that have already been filtered")
    parser.add_argument("--metadata", help="original metadata table for simulated sequences")
    parser.add_argument("--output", help="filtered metadata where only samples present in the given sequences are included")

    args = parser.parse_args()

    # Get a list of all samples that passed the sequence filtering step.
    sequences = Bio.SeqIO.parse(args.sequences, "fasta")
    sample_ids = [sequence.id for sequence in sequences]

    # Load all metadata.
    metadata = pd.read_csv(args.metadata, sep="\t")
    filtered_metadata = metadata[metadata["strain"].isin(sample_ids)].copy()

    # Save only the metadata records that have entries in the filtered sequences.
    filtered_metadata.to_csv(args.output, sep="\t", header=True, index=False)
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--metadata", required=True, help="table of metadata to be filtered based on a date column")
    parser.add_argument("--date-field", default="date", help="name of date column in the metadata")
    parser.add_argument("--output", required=True, help="table of filtered metadata")

    args = parser.parse_args()

    df = pd.read_csv(args.metadata, sep="\t")

    # Exclude strains with ambiguous collection dates.
    df[~df[args.date_field].str.contains("XX")].to_csv(args.output, sep="\t", header=True, index=False)
import argparse
from augur.utils import read_metadata, get_numerical_dates
from augur.frequencies import TreeKdeFrequencies
import Bio
import Bio.Phylo
import datetime
import json
import numpy as np
import os
import sys


def get_time_interval_as_floats(time_interval):
    """
    Convert the given (end, start) datetime interval to start and end floats as decimal years.

    Returns:
        start_date (float): the start of the given time interval
        end_date (float): the end of the given time interval
    """
    start_date = time_interval[1].year + (time_interval[1].month - 1) / 12.0
    end_date = time_interval[0].year + (time_interval[0].month - 1) / 12.0
    return start_date, end_date


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("tree", help="Newick tree")
    parser.add_argument("metadata", help="tab-delimited metadata for tips in the given tree including a date field")
    parser.add_argument("frequencies", help="JSON file to which frequencies estimated from the given tree will be written")
    parser.add_argument("--narrow-bandwidth", type=float, default=1 / 12.0, help="the bandwidth for the narrow KDE")
    parser.add_argument("--wide-bandwidth", type=float, default=3 / 12.0, help="the bandwidth for the wide KDE")
    parser.add_argument("--proportion-wide", type=float, default=0.2, help="the proportion of the wide bandwidth to use in the KDE mixture model")
    parser.add_argument("--pivot-frequency", type=int, default=1, help="number of months between pivots")
    parser.add_argument("--start-date", help="the start of the interval to estimate frequencies across")
    parser.add_argument("--end-date", help="the end of the interval to estimate frequencies across")
    parser.add_argument("--include-internal-nodes", action="store_true", help="calculate frequencies for internal nodes as well as tips")
    parser.add_argument("--weights", help="a dictionary of key/value mappings in JSON format used to weight tip frequencies")
    parser.add_argument("--weights-attribute", help="name of the attribute on each tip whose values map to the given weights dictionary")

    parser.add_argument("--precision", type=int, default=6, help="number of decimal places to retain in frequency estimates")
    parser.add_argument("--censored", action="store_true", help="calculate censored frequencies at each pivot")

    args = parser.parse_args()

    # Load tree.
    tree = Bio.Phylo.read(args.tree, "newick")

    # Load metadata.
    metadata, columns = read_metadata(args.metadata)
    dates = get_numerical_dates(metadata, fmt='%Y-%m-%d')

    # Annotate tree with dates and other metadata.
    for tip in tree.find_clades(terminal=True):
        tip.attr = {"num_date": np.mean(dates[tip.name])}

        # Annotate tips with metadata to enable filtering and weighting of
        # frequencies by metadata attributes.
        for key, value in metadata[tip.name].items():
            tip.attr[key] = value

    # Convert start and end dates to floats from time interval format.
    if args.start_date is not None and args.end_date is not None:
        # Convert the string time interval to a datetime instance and then to floats.
        time_interval = [
            datetime.datetime.strptime(time, "%Y-%m-%d")
            for time in (args.end_date, args.start_date)
        ]

        start_date, end_date = get_time_interval_as_floats(time_interval)
    else:
        start_date = end_date = None

    # Load weights if they have been provided.
    if args.weights:
        with open(args.weights, "r") as fh:
            weights = json.load(fh)

        weights_attribute = args.weights_attribute
    else:
        weights = None
        weights_attribute = None

    # Estimate frequencies.
    frequencies = TreeKdeFrequencies(
        sigma_narrow=args.narrow_bandwidth,
        sigma_wide=args.wide_bandwidth,
        proportion_wide=args.proportion_wide,
        pivot_frequency=args.pivot_frequency,
        start_date=start_date,
        end_date=end_date,
        weights=weights,
        weights_attribute=weights_attribute,
        include_internal_nodes=args.include_internal_nodes,
        censored=args.censored
    )
    frequencies.estimate(tree)

    # Export frequencies to JSON.
    json_frequencies = frequencies.to_json()

    # Set precision of frequency estimates.
    for clade in json_frequencies["data"]["frequencies"]:
        json_frequencies["data"]["frequencies"][clade] = np.around(
            np.array(
                json_frequencies["data"]["frequencies"][clade]
            ),
            args.precision
        ).tolist()

    with open(args.frequencies, "w") as oh:
        json.dump(json_frequencies, oh, indent=1, sort_keys=True)
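
The heavy lifting here happens inside augur's TreeKdeFrequencies. As a rough, illustrative sketch of the kind of two-bandwidth mixture that the --narrow-bandwidth, --wide-bandwidth, and --proportion-wide options describe, the example below assumes each tip contributes a mixture of a narrow and a wide Gaussian centered on its collection date and that tip frequencies at a pivot are the normalized mixture densities; the dates and pivot are hypothetical and the real estimator should be treated as authoritative.

import numpy as np

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Defaults from the argument parser above.
narrow_bandwidth = 1 / 12.0
wide_bandwidth = 3 / 12.0
proportion_wide = 0.2

# Hypothetical collection dates (decimal years) for three tips and one pivot.
tip_dates = np.array([2016.25, 2016.5, 2016.75])
pivot = 2016.8

# Mixture density of each tip evaluated at the pivot, normalized to sum to one.
densities = (
    (1 - proportion_wide) * gaussian_pdf(pivot, tip_dates, narrow_bandwidth)
    + proportion_wide * gaussian_pdf(pivot, tip_dates, wide_bandwidth)
)
frequencies_at_pivot = densities / densities.sum()
print(frequencies_at_pivot)  # three frequencies that sum to 1.0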
import argparse
import Bio.Phylo
import json
import pandas as pd


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Convert frequencies JSON to a data frame",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tree", required=True, help="Newick file for the tree used to estimate the given frequencies")
    parser.add_argument("--frequencies", required=True, help="frequencies JSON")
    parser.add_argument("--method", required=True, choices=["kde", "diffusion"], help="method used to estimate frequencies")
    parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the output table in the format of 'key=value' pairs")
    parser.add_argument("--output", required=True, help="tab-delimited file with frequency per node at the last available timepoint")
    parser.add_argument("--include-internal-nodes", action="store_true", help="include data associated with internal nodes in the output table")
    parser.add_argument("--minimum-frequency", type=float, default=1e-5, help="minimum frequency to keep below which values will be zeroed and all others renormalized to sum to one")
    args = parser.parse_args()

    # Load tree.
    tree = Bio.Phylo.read(args.tree, "newick")

    # Load frequencies.
    with open(args.frequencies, "r") as fh:
        frequencies_json = json.load(fh)

    if args.method == "kde":
        frequencies = frequencies_json["data"]["frequencies"]
    else:
        frequencies = {
            node_name: region_frequencies["global"]
            for node_name, region_frequencies in frequencies_json.items()
            if node_name not in ["pivots", "counts", "generated_by"]
        }

    # Collect the last frequency for each node keeping only terminal nodes
    # (tips) unless internal nodes are also requested.
    frequency_key = "%s_frequency" % args.method
    records = [
        {
            "strain": node.name,
            frequency_key: float(frequencies[node.name][-1]),
            "is_terminal": node.is_terminal()
        }
        for node in tree.find_clades()
        if args.include_internal_nodes or node.is_terminal()
    ]

    # Convert frequencies data into a data frame.
    df = pd.DataFrame(records)

    # Replace records whose frequency values are below the requested minimum
    # with zeros and renormalize the remaining records to sum to one.
    to_zero = df[frequency_key] < args.minimum_frequency
    not_to_zero = ~to_zero
    df.loc[to_zero, frequency_key] = 0.0
    df.loc[not_to_zero, frequency_key] = df.loc[not_to_zero, frequency_key] / df.loc[not_to_zero, frequency_key].sum()

    # Add any additional annotations requested by the user in the format of
    # "key=value" pairs where each key becomes a new column with the given
    # value.
    if args.annotations:
        for annotation in args.annotations:
            key, value = annotation.split("=")
            df[key] = value

    # Save the table.
    df.to_csv(args.output, sep="\t", float_format="%.6f", index=False, header=True)
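
As a quick illustration of the zero-and-renormalize step above, the sketch below applies the same logic to a hypothetical frequency column.

import pandas as pd

df = pd.DataFrame({
    "strain": ["strain_1", "strain_2", "strain_3"],
    "kde_frequency": [0.7, 0.3, 1e-7],
})
minimum_frequency = 1e-5

to_zero = df["kde_frequency"] < minimum_frequency
not_to_zero = ~to_zero
df.loc[to_zero, "kde_frequency"] = 0.0
df.loc[not_to_zero, "kde_frequency"] = df.loc[not_to_zero, "kde_frequency"] / df.loc[not_to_zero, "kde_frequency"].sum()
print(df)  # strain_3 is zeroed; strain_1 and strain_2 are rescaled to sum to one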
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--titers", required=True, help="JSON of complete titer records to be filtered by passage type")
    parser.add_argument("--passage-type", required=True, help="type of passage for viruses used in titer assays")
    parser.add_argument("--output", required=True, help="table of filtered titer records by passage type")

    args = parser.parse_args()

    df = pd.read_json(args.titers)
    passaged = (df["serum_passage_category"] == args.passage_type)
    tdb_passaged = df["index"].apply(lambda index: isinstance(index, list) and args.passage_type in index)
    tsv_fields = [
        "virus_strain",
        "serum_strain",
        "serum_id",
        "source",
        "titer",
        "assay_type"
    ]

    titers_df = df.loc[(passaged | tdb_passaged), tsv_fields]
    titers_df.to_csv(args.output, sep="\t", header=False, index=False)
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--node-data", required=True, help="table of node data from one or more timepoints")
    parser.add_argument("--kde-frequencies", required=True, help="table of KDE frequencies by strain name and timepoint")
    parser.add_argument("--diffusion-frequencies", required=True, help="table of diffusion frequencies by strain name and timepoint")
    parser.add_argument("--preferred-frequency-method", choices=["kde", "diffusion"], help="specify which frequency method should be used for the primary frequency column")
    parser.add_argument("--output", required=True, help="table of merged node data and frequencies")

    args = parser.parse_args()

    node_data = pd.read_table(args.node_data)
    kde_frequencies = pd.read_table(args.kde_frequencies)
    diffusion_frequencies = pd.read_table(args.diffusion_frequencies)
    df = node_data.merge(
        kde_frequencies,
        how="inner",
        on=["strain", "timepoint", "is_terminal"]
    ).merge(
        diffusion_frequencies,
        how="inner",
        on=["strain", "timepoint", "is_terminal"]
    )

    # Annotate frequency by the preferred method if there isn't already a
    # frequency column defined.
    if "frequency" not in df.columns:
        df["frequency"] = df["%s_frequency" % args.preferred_frequency_method]

    df = df[df["frequency"] > 0.0].copy()
    df.to_csv(args.output, sep="\t", index=False, header=True)
import argparse
from augur.utils import read_node_data
import Bio.Phylo
import pandas as pd


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Convert node data JSONs to a data frame",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tree", required=True, help="Newick file for the tree used to construct the given node data JSONs")
    parser.add_argument("--metadata", help="file with metadata associated with viral sequences, one for each segment")
    parser.add_argument("--jsons", nargs="+", required=True, help="node data JSON(s) from augur")
    parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the output table in the format of 'key=value' pairs")
    parser.add_argument("--excluded-fields", nargs="+", help="names of columns to omit from output table")
    parser.add_argument("--output", required=True, help="tab-delimited file collecting all given node data")
    parser.add_argument("--include-internal-nodes", action="store_true", help="include data associated with internal nodes in the output table")
    args = parser.parse_args()

    # Load tree.
    tree = Bio.Phylo.read(args.tree, "newick")

    # Load metadata for samples.
    metadata = pd.read_csv(args.metadata, sep="\t")

    # Load one or more node data JSONs into a single dictionary indexed by node name.
    node_data = read_node_data(args.jsons)

    # Convert node data into a data frame.
    # Data are initially loaded with one column per node.
    # Transposition converts the table to the expected one row per node format.
    df = pd.DataFrame(node_data["nodes"]).T.rename_axis("strain").reset_index()

    # Annotate node data with per sample metadata.
    df = df.merge(metadata, on="strain", suffixes=["", "_metadata"])

    # Remove excluded fields if they have been requested and are present in the data frame.
    if args.excluded_fields:
        df = df.drop(columns=[field for field in args.excluded_fields if field in df.columns])

    # Annotate the tip/internal status of each node using the tree.
    node_terminal_status_by_name = {node.name: node.is_terminal() for node in tree.find_clades()}
    df["is_terminal"] = df["strain"].map(node_terminal_status_by_name)

    # Eliminate internal nodes if they have not been requested.
    if not args.include_internal_nodes:
        df = df[df["is_terminal"]].copy()

    # Add any additional annotations requested by the user in the format of
    # "key=value" pairs where each key becomes a new column with the given
    # value.
    if args.annotations:
        for annotation in args.annotations:
            key, value = annotation.split("=")
            df[key] = value

    # Save the table.
    df.to_csv(args.output, sep="\t", index=False, header=True)
import argparse
from augur.frequency_estimators import TreeKdeFrequencies
from augur.reconstruct_sequences import load_alignments
from augur.utils import annotate_parents_for_tree, write_json
import Bio.Phylo
import Bio.SeqIO
import hashlib
import json
import pandas as pd

# Magic number of maximum length of SHA hash to keep for each clade.
MAX_HASH_LENGTH = 7


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Find clades in a tree by distinct amino acid haplotypes",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tree", required=True, help="Newick tree to identify clades in")
    parser.add_argument("--translations", required=True, nargs="+", help="FASTA file(s) of amino acid sequences per node")
    parser.add_argument("--gene-names", required=True, nargs="+", help="gene names corresponding to translations provided")
    parser.add_argument("--output", required=True, help="JSON of clade annotations for nodes in the given tree")
    parser.add_argument("--output-tip-clade-table", help="optional table of all clades per tip in the tree")
    parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the tip clade output table in the format of 'key=value' pairs")

    args = parser.parse_args()

    # Load the tree.
    tree = Bio.Phylo.read(args.tree, "newick")
    tree = annotate_parents_for_tree(tree)

    # Load translations for nodes in the given tree and index them by gene name and node name.
    translations = load_alignments(args.translations, args.gene_names)
    translations_by_gene_name = {}
    for gene in translations:
        translations_by_gene_name[gene] = {}
        for seq in translations[gene]:
            translations_by_gene_name[gene][seq.name] = str(seq.seq)

    clades = {}
    for node in tree.find_clades(order="preorder", terminal=False):
        # Assign the current node a clade id based on the hash of its
        # full-length amino acid sequence.
        node_sequence = "".join([translations_by_gene_name[gene][node.name] for gene in args.gene_names])
        clades[node.name] = {"clade_membership": hashlib.sha256(node_sequence.encode()).hexdigest()[:MAX_HASH_LENGTH]}

        # Assign the current node's clade id to all of its terminal children.
        for child in node.clades:
            if child.is_terminal():
                clades[child.name] = clades[node.name]

    # Count unique clade groups.
    distinct_clades = {clade["clade_membership"] for clade in clades.values()}
    print("Found %i distinct clades" % len(distinct_clades))

    # Write out the node annotations.
    write_json({"nodes": clades}, args.output)

    # Output the optional tip-to-clade table, if requested.
    if args.output_tip_clade_table:
        records = []
        for tip in tree.find_clades(terminal=True):
            # Note the tip's own clade assignment which may be distinct from its
            # parent's.
            depth = 0
            records.append([tip.name, clades[tip.name]["clade_membership"], depth])

            parent = tip.parent
            depth += 1
            while True:
                records.append([tip.name, clades[parent.name]["clade_membership"], depth])

                if parent == tree.root:
                    break

                parent = parent.parent
                depth += 1

        df = pd.DataFrame(records, columns=["tip", "clade_membership", "depth"])
        df = df.drop_duplicates(subset=["tip", "clade_membership"])

        # Add any additional annotations requested by the user in the format of
        # "key=value" pairs where each key becomes a new column with the given
        # value.
        if args.annotations:
            for annotation in args.annotations:
                key, value = annotation.split("=")
                df[key] = value

        df.to_csv(args.output_tip_clade_table, sep="\t", index=False)
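
The clade labels produced above come from a simple haplotype hash; the sketch below shows the same idea on a made-up concatenated amino acid sequence.

import hashlib

MAX_HASH_LENGTH = 7

# Hypothetical concatenation of per-gene amino acid sequences for one node.
node_sequence = "MKTII" + "GNLLV"
clade_id = hashlib.sha256(node_sequence.encode()).hexdigest()[:MAX_HASH_LENGTH]
print(clade_id)  # identical haplotypes always map to the same clade id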
import argparse
from augur.utils import write_json
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Normalize fitness by the maximum fitness among samples with nonzero frequencies at the current timepoint in simulated populations.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--metadata", required=True, help="file with metadata associated with viral sequences, one for each segment")
    parser.add_argument("--frequencies-table", required=True, help="frequencies table for the current timepoint")
    parser.add_argument("--frequency-method", required=True, choices=["kde", "diffusion"], help="method used to estimate frequencies")
    parser.add_argument("--output", required=True, help="JSON of normalized fitness per sample")

    args = parser.parse_args()

    # Load metadata.
    metadata = pd.read_csv(args.metadata, sep="\t")

    # Load frequencies.
    frequencies = pd.read_csv(args.frequencies_table, sep="\t")

    # Filter samples to those with nonzero frequencies at the current timepoint.
    nonzero_frequencies = frequencies[frequencies["%s_frequency" % args.frequency_method] > 0].copy()

    # Merge extant sample frequencies with metadata containing fitnesses.
    nonzero_metadata = nonzero_frequencies.merge(
        metadata,
        on="strain"
    )

    # Normalize fitness by maximum fitness.
    nonzero_metadata["normalized_fitness"] = nonzero_metadata["fitness"] / nonzero_metadata["fitness"].max()

    # Prepare dictionary of normalized fitnesses by sample.
    normalized_fitness = {
        strain: {"normalized_fitness": fitness}
        for strain, fitness in nonzero_metadata.loc[:, ["strain", "normalized_fitness"]].values
    }

    print("Raw fitness: %.2f +/- %.2f" % (nonzero_metadata["fitness"].mean(),
                                          nonzero_metadata["fitness"].std()))
    print("Normalized fitness: %.2f +/- %.2f" % (nonzero_metadata["normalized_fitness"].mean(),
                                                 nonzero_metadata["normalized_fitness"].std()))

    # Save normalized fitness as a node data JSON.
    write_json({"nodes": normalized_fitness}, args.output)
import argparse

from augur.distance import read_distance_map, get_distance_between_nodes
from augur.frequency_estimators import TreeKdeFrequencies
from augur.reconstruct_sequences import load_alignments
from augur.utils import annotate_parents_for_tree, read_node_data, write_json

import Bio.Phylo
from collections import defaultdict
import json


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tree", help="Newick tree", required=True)
    parser.add_argument("--frequencies", help="frequencies JSON", required=True)
    parser.add_argument("--alignment", nargs="+", help="sequence(s) to be used, supplied as FASTA files", required=True)
    parser.add_argument('--gene-names', nargs="+", type=str, help="names of the sequences in the alignment, same order assumed", required=True)
    parser.add_argument("--attribute-name", nargs="+", help="name to store distances associated with the given distance map; multiple attribute names are linked to corresponding positional comparison method and distance map arguments", required=True)
    parser.add_argument("--map", nargs="+", help="JSON providing the distance map between sites and, optionally, sequences present at those sites; the distance map JSON minimally requires a 'default' field defining a default numeric distance and a 'map' field defining a dictionary of genes and one-based coordinates", required=True)
    parser.add_argument("--date-annotations", help="JSON of branch lengths and date annotations from augur refine for samples in the given tree; required for comparisons to earliest or latest date", required=True)
    parser.add_argument("--years-back-to-compare", type=int, help="number of years prior to the current season to search for samples to calculate pairwise comparisons with", required=True)
    parser.add_argument("--output", help="JSON file with calculated distances stored by node name and attribute name", required=True)

    args = parser.parse_args()

    # Load tree and annotate parents.
    tree = Bio.Phylo.read(args.tree, "newick")
    tree = annotate_parents_for_tree(tree)

    # Load frequencies.
    with open(args.frequencies, "r") as fh:
        frequencies_json = json.load(fh)

    frequencies = TreeKdeFrequencies.from_json(frequencies_json)
    pivots = frequencies.pivots

    # Identify pivots that belong within our search window for past samples.
    past_pivot_indices = (pivots < pivots[-1]) & (pivots >= pivots[-1] - args.years_back_to_compare)

    # Load sequences.
    alignments = load_alignments(args.alignment, args.gene_names)

    # Index sequences by node name and gene.
    sequences_by_node_and_gene = defaultdict(dict)
    for gene, alignment in alignments.items():
        for record in alignment:
            sequences_by_node_and_gene[record.name][gene] = str(record.seq)

    # Load date annotations and annotate tree with them.
    date_annotations = read_node_data(args.date_annotations)
    for node in tree.find_clades():
        node.attr = date_annotations["nodes"][node.name]
        node.attr["num_date"] = node.attr["numdate"]

    # Identify samples to compare including those in the current timepoint
    # (pivot) and those in previous timepoints.
    current_samples = []
    past_samples = []
    date_by_sample = {}
    for tip in tree.find_clades(terminal=True):
        # Samples with nonzero frequencies in the last timepoint are current
        # samples. Those with one or more nonzero frequencies in the search
        # window of the past timepoints are past samples.
        if frequencies.frequencies[tip.name][-1] > 0:
            current_samples.append(tip.name)
        elif (frequencies.frequencies[tip.name][past_pivot_indices] > 0).sum() > 0:
            past_samples.append(tip.name)

        date_by_sample[tip.name] = tip.attr["numdate"]

    print("Expecting %i comparisons" % (len(current_samples) * len(past_samples) * len(args.attribute_name)))

    distances_by_node = {}
    distance_map_names = []
    comparisons = 0
    for attribute, distance_map_file in zip(args.attribute_name, args.map):
        # Load the given distance map.
        distance_map = read_distance_map(distance_map_file)
        distance_map_names.append(distance_map.get("name", distance_map_file))

        for current_sample in current_samples:
            if not current_sample in distances_by_node:
                distances_by_node[current_sample] = {}

            if not attribute in distances_by_node[current_sample]:
                distances_by_node[current_sample][attribute] = {}

            for past_sample in past_samples:
                # The past is in the past.
                comparisons += 1
                if date_by_sample[past_sample] < date_by_sample[current_sample]:
                    distances_by_node[current_sample][attribute][past_sample] = get_distance_between_nodes(
                        sequences_by_node_and_gene[past_sample],
                        sequences_by_node_and_gene[current_sample],
                        distance_map
                    )

    print("Calculated %i comparisons" % comparisons)
    # Prepare params for export.
    params = {
        "attribute": args.attribute_name,
        "map_name": distance_map_names,
        "years_back_to_compare": args.years_back_to_compare
    }

    # Export distances to JSON.
    write_json({"params": params, "nodes": distances_by_node}, args.output)
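
The --map help text above describes the minimal distance map structure: a 'default' numeric distance plus a 'map' of genes and one-based coordinates. The example below sketches one plausible shape with a hypothetical name, gene, positions, and weights; consult the augur distance documentation for the authoritative format.

import json

distance_map = {
    "name": "example_epitope_map",  # hypothetical map name
    "default": 0,
    "map": {
        "HA1": {
            "145": 1,
            "189": 1
        }
    }
}
print(json.dumps(distance_map, indent=2))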
import argparse
from augur.frequency_estimators import float_to_datestring, timestamp_to_float
from augur.utils import annotate_parents_for_tree, read_node_data, write_json
import Bio.Phylo
import json
import numpy as np
import pandas as pd


def get_titer_distance_between_nodes(tree, past_node, current_node, titer_attr="dTiter"):
    # Find MRCA of tips from one tip up. Sum the titer attribute of interest
    # while walking up to the MRCA, to avoid an additional pass later. The loop
    # below stops when the past node is found in the list of the candidate
    # MRCA's terminals. This test should always evaluate to true when the MRCA
    # is the root node, so we should not have to worry about trying to find the
    # parent of the root.
    current_node_branch_sum = 0.0
    mrca = current_node
    while past_node.name not in mrca.terminals:
        current_node_branch_sum += mrca.attr[titer_attr]
        mrca = mrca.parent

    # Sum the node weights for the other tip from the bottom up until we reach
    # the MRCA. The value of the MRCA is intentionally excluded here, as it
    # would represent the branch leading to the MRCA and would be outside the
    # path between the two tips.
    past_node_branch_sum = 0.0
    current_node = past_node
    while current_node != mrca:
        past_node_branch_sum += current_node.attr[titer_attr]
        current_node = current_node.parent

    final_sum = past_node_branch_sum + current_node_branch_sum
    return final_sum


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tree", help="Newick tree", required=True)
    parser.add_argument("--frequencies", help="frequencies JSON", required=True)
    parser.add_argument("--model-attribute-name", help="name of attribute to use from titer model file", default="dTiter")
    parser.add_argument("--attribute-name", help="name to store distances", required=True)
    parser.add_argument("--model", help="JSON providing the titer tree model", required=True)
    parser.add_argument("--date-annotations", help="JSON of branch lengths and date annotations from augur refine for samples in the given tree; required for comparisons to earliest or latest date", required=True)
    parser.add_argument("--months-back-for-current-samples", type=int, help="number of months prior to the last date with estimated frequencies to include samples as current", required=True)
    parser.add_argument("--years-back-to-compare", type=int, help="number of years prior to the current season to search for samples to calculate pairwise comparisons with", required=True)
    parser.add_argument("--max-past-samples", type=int, default=200, help="maximum number of past samples to randomly select for comparison to current samples")
    parser.add_argument("--min-frequency", type=float, default=0.0, help="minimum frequency to consider a sample alive")
    parser.add_argument("--output", help="JSON file with calculated distances stored by node name and attribute name", required=True)

    args = parser.parse_args()

    # Load tree and annotate parents.
    tree = Bio.Phylo.read(args.tree, "newick")
    tree = annotate_parents_for_tree(tree)

    # Make a single pass through the tree in postorder to store a set of all
    # terminals descending from each node. This uses more memory, but it allows
    # faster identification of MRCAs between any pair of tips in the tree and
    # speeds up pairwise distance calculations by orders of magnitude.
    for node in tree.find_clades(order="postorder"):
        node.terminals = set()
        for child in node.clades:
            if child.is_terminal():
                node.terminals.add(child.name)
            else:
                node.terminals.update(child.terminals)

    # Load frequencies.
    with open(args.frequencies, "r") as fh:
        frequencies = json.load(fh)

    pivots = np.array(frequencies.pop("pivots"))

    # Identify pivots that belong within our search window for past samples.
    # First, calculate dates associated with the interval for current samples
    # based on the number of months back requested. Then, calculate interval for
    # past samples with an upper bound based on the earliest current samples and
    # a lower bound based on the years back requested.
    last_pivot_datetime = pd.to_datetime(float_to_datestring(pivots[-1]))
    last_current_datetime = last_pivot_datetime - pd.DateOffset(months=args.months_back_for_current_samples)
    last_past_datetime = last_pivot_datetime - pd.DateOffset(years=args.years_back_to_compare)

    # Find the pivot indices that correspond to the current and past pivots.
    current_pivot_indices = np.array([
        pd.to_datetime(float_to_datestring(pivot)) > last_current_datetime
        for pivot in pivots
    ])
    past_pivot_indices = np.array([
        ((pd.to_datetime(float_to_datestring(pivot)) >= last_past_datetime) &
         (pd.to_datetime(float_to_datestring(pivot)) <= last_current_datetime))
        for pivot in pivots
    ])

    # Load date and titer model annotations and annotate tree with them.
    annotations = read_node_data([args.date_annotations, args.model])
    for node in tree.find_clades():
        node.attr = annotations["nodes"][node.name]
        node.attr["num_date"] = node.attr["numdate"]

    # Identify samples to compare including those in the current timepoint
    # (pivot) and those in previous timepoints.
    current_samples = []
    past_samples = []
    date_by_sample = {}
    tips_by_sample = {}
    for tip in tree.find_clades(terminal=True):
        # Samples with nonzero frequencies in the last timepoint are current
        # samples. Those with one or more nonzero frequencies in the search
        # window of the past timepoints are past samples.
        frequencies[tip.name]["frequencies"] = np.array(frequencies[tip.name]["frequencies"])
        if (frequencies[tip.name]["frequencies"][current_pivot_indices] > args.min_frequency).sum() > 0:
            current_samples.append(tip.name)
            tips_by_sample[tip.name] = tip
        elif (frequencies[tip.name]["frequencies"][past_pivot_indices] > args.min_frequency).sum() > 0:
            past_samples.append(tip.name)
            tips_by_sample[tip.name] = tip

        date_by_sample[tip.name] = tip.attr["numdate"]

    print("Expecting %i comparisons for %i current and %i past samples" % (len(current_samples) * len(past_samples), len(current_samples), len(past_samples)))
    distances_by_node = {}
    comparisons = 0

    for current_sample in current_samples:
        if not current_sample in distances_by_node:
            distances_by_node[current_sample] = {}

        if not args.attribute_name in distances_by_node[current_sample]:
            distances_by_node[current_sample][args.attribute_name] = {}

        for past_sample in past_samples:
            # The past is in the past.
            if date_by_sample[past_sample] < date_by_sample[current_sample]:
                distances_by_node[current_sample][args.attribute_name][past_sample] = np.around(get_titer_distance_between_nodes(
                    tree,
                    tips_by_sample[past_sample],
                    tips_by_sample[current_sample],
                    args.model_attribute_name
                ), 4)

                comparisons += 1
                if comparisons % 10000 == 0:
                    print("Completed", comparisons, "comparisons, with last distance of", distances_by_node[current_sample][args.attribute_name][past_sample], flush=True)

    print("Calculated %i comparisons" % comparisons)
    # Prepare params for export.
    params = {
        "attribute": args.attribute_name,
        "years_back_to_compare": args.years_back_to_compare
    }

    # Export distances to JSON.
    write_json({"params": params, "nodes": distances_by_node}, args.output, indent=None)
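
The titer distance above relies on two ideas: a postorder pass that records the terminal names below each node, and an MRCA walk that sums a branch attribute along the path between two tips. The sketch below reproduces those ideas on a tiny hand-built tree with a minimal Node class and hypothetical branch weights, so it runs without Bio.Phylo or a titer model file.

class Node:
    def __init__(self, name, weight=0.0, children=None):
        self.name = name
        self.weight = weight  # stands in for the dTiter branch attribute
        self.children = children or []
        self.parent = None
        for child in self.children:
            child.parent = self

    def is_terminal(self):
        return len(self.children) == 0


# Tiny tree: root -> (A, internal -> (B, C)); weights are branch values to each parent.
tip_a = Node("A", 0.5)
tip_b = Node("B", 1.0)
tip_c = Node("C", 0.25)
internal = Node("internal", 2.0, [tip_b, tip_c])
root = Node("root", 0.0, [tip_a, internal])


def annotate_terminals(node):
    # Postorder pass storing the set of terminal names descending from each node.
    node.terminals = set()
    for child in node.children:
        annotate_terminals(child)
        if child.is_terminal():
            node.terminals.add(child.name)
        else:
            node.terminals.update(child.terminals)


annotate_terminals(root)


def path_weight(past_tip, current_tip):
    # Walk up from the current tip until the first ancestor whose terminal set
    # contains the past tip (the MRCA), summing branch weights along the way.
    total = 0.0
    mrca = current_tip
    while past_tip.name not in mrca.terminals:
        total += mrca.weight
        mrca = mrca.parent

    # Walk up from the past tip to the same MRCA, excluding the MRCA itself.
    node = past_tip
    while node is not mrca:
        total += node.weight
        node = node.parent

    return total


print(path_weight(tip_a, tip_c))  # 0.25 + 2.0 + 0.5 = 2.75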
import argparse
from augur.utils import get_numerical_dates, read_metadata
import numpy as np
import pandas as pd
from treetime.utils import numeric_date

from select_strains import read_strain_list


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Partition strains into timepoints",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("metadata", help="tab-delimited metadata with columns for strain and date")
    parser.add_argument("timepoint", help="date for which strains should be partitioned")
    parser.add_argument("output", help="text file into which strains should be written for the given timepoint")
    parser.add_argument("--years-back", type=int, help="Number of years prior to the given timepoint to limit strains to")
    parser.add_argument("--additional-years-back-for-references", type=int, default=5, help="Additional number of years prior to the given timepoint to allow reference strains")
    parser.add_argument("--reference-strains", help="text file containing list of reference strains that should be included from the original strains even if they were sampled prior to the minimum date determined by the requested number of years before the given timepoint")
    args = parser.parse_args()

    # Convert date string to a datetime instance.
    timepoint = pd.to_datetime(args.timepoint)
    numeric_timepoint = np.around(numeric_date(timepoint), 2)

    # Load metadata with strain names and dates.
    metadata, columns = read_metadata(args.metadata)

    # Convert string dates with potential ambiguity (e.g., 2010-05-XX) into
    # floating point dates.
    dates = get_numerical_dates(metadata, fmt="%Y-%m-%d")

    # Setup reference strains.
    if args.reference_strains:
        reference_strains = read_strain_list(args.reference_strains)
    else:
        reference_strains = []

    # If a given number of years back has been requested, determine what the
    # earliest date to accept for strains is.
    if args.years_back is not None:
        earliest_timepoint = timepoint - pd.DateOffset(years=args.years_back)
        numeric_earliest_timepoint = np.around(numeric_date(earliest_timepoint), 2)

        # If reference strains are provided, calculate the earliest date to
        # accept those strains.
        if len(reference_strains) > 0:
            earliest_reference_timepoint = earliest_timepoint - pd.DateOffset(years=args.additional_years_back_for_references)
            numeric_earliest_reference_timepoint = np.around(numeric_date(earliest_reference_timepoint), 2)

    # Find strains sampled prior to the current timepoint. Strains may have
    # multiple numerical dates, so we filter on the latest (maximum) observed
    # date per strain. If a requested number of years back is provided, use the
    # corresponding earliest dates for non-reference and reference strains to
    # determine whether they are included in the current timepoint.
    timepoint_strains = []
    for strain, strain_dates in dates.items():
        strain_date = np.max(strain_dates)
        if (strain_date <= numeric_timepoint and
            ((args.years_back is None) or
             (strain_date >= numeric_earliest_timepoint) or
             (strain in reference_strains and strain_date >= numeric_earliest_reference_timepoint))):
            timepoint_strains.append(strain)

    timepoint_strains = sorted(timepoint_strains)

    # Write sorted list of strains to disk.
    with open(args.output, "w") as oh:
        for strain in timepoint_strains:
            oh.write(f"{strain}\n")
import argparse
from augur.utils import json_to_tree, read_tree
import Bio.Phylo
import json
import matplotlib as mpl
mpl.use("Agg")
from matplotlib import gridspec
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import numpy as np
import pandas as pd
import sys


def timestamp_to_float(time):
    """Convert a pandas timestamp to a floating point date.
    """
    return time.year + ((time.month - 1) / 12.0)


def plot_tree(tree, figure_name, color_by_trait, initial_branch_width, tip_size,
              start_date, end_date, include_color_bar):
    """Plot a BioPython Phylo tree in the BALTIC-style.
    """
    # Plot H3N2 tree in BALTIC style from Bio.Phylo tree.
    mpl.rcParams['savefig.dpi'] = 120
    mpl.rcParams['figure.dpi'] = 100

    mpl.rcParams['font.weight']=300
    mpl.rcParams['axes.labelweight']=300
    mpl.rcParams['font.size']=14

    yvalues = [node.yvalue for node in tree.find_clades()]
    y_span = max(yvalues)
    y_unit = y_span / float(len(yvalues))

    # Setup colors.
    trait_name = color_by_trait
    traits = [k.attr[trait_name] for k in tree.find_clades()]
    norm = mpl.colors.Normalize(min(traits), max(traits))
    cmap = mpl.cm.viridis

    #
    # Setup the figure grid.
    #

    if include_color_bar:
        fig = plt.figure(figsize=(8, 6), facecolor='w')
        gs = gridspec.GridSpec(2, 1, height_ratios=[14, 1], width_ratios=[1], hspace=0.1, wspace=0.1)
        ax = fig.add_subplot(gs[0])
        colorbar_ax = fig.add_subplot(gs[1])
    else:
        fig = plt.figure(figsize=(8, 4), facecolor='w')
        gs = gridspec.GridSpec(1, 1)
        ax = fig.add_subplot(gs[0])

    L=len([k for k in tree.find_clades() if k.is_terminal()])

    # Setup arrays for tip and internal node coordinates.
    tip_circles_x = []
    tip_circles_y = []
    tip_circles_color = []
    tip_circle_sizes = []
    node_circles_x = []
    node_circles_y = []
    node_circles_color = []
    node_line_widths = []
    node_line_segments = []
    node_line_colors = []
    branch_line_segments = []
    branch_line_widths = []
    branch_line_colors = []
    branch_line_labels = []

    for k in tree.find_clades(): ## iterate over objects in tree
        x=k.attr["num_date"] ## or from x position determined earlier
        y=k.yvalue ## get y position from .drawTree that was run earlier, but could be anything else

        if k.parent is None:
            xp = None
        else:
            xp=k.parent.attr["num_date"] ## get x position of current object's parent

        if x==None: ## matplotlib won't plot Nones, like root
            x=0.0
        if xp==None:
            xp=x

        c = 'k'
        if trait_name in k.attr:
            c = cmap(norm(k.attr[trait_name]))

        branchWidth=2
        if k.is_terminal(): ## if leaf...
            s = tip_size ## tip size can be fixed

            tip_circle_sizes.append(s)
            tip_circles_x.append(x)
            tip_circles_y.append(y)
            tip_circles_color.append(c)
        else: ## if node...
            k_leaves = [child
                        for child in k.find_clades()
                        if child.is_terminal()]

            # Scale branch widths by the number of tips.
            branchWidth += initial_branch_width * len(k_leaves) / float(L)

            if len(k.clades)==1:
                node_circles_x.append(x)
                node_circles_y.append(y)
                node_circles_color.append(c)

            ax.plot([x,x],[k.clades[-1].yvalue, k.clades[0].yvalue], lw=branchWidth, color=c, ls='-', zorder=9, solid_capstyle='round')

        branch_line_segments.append([(xp, y), (x, y)])
        branch_line_widths.append(branchWidth)
        branch_line_colors.append(c)

    branch_lc = LineCollection(branch_line_segments, zorder=9)
    branch_lc.set_color(branch_line_colors)
    branch_lc.set_linewidth(branch_line_widths)
    branch_lc.set_label(branch_line_labels)
    branch_lc.set_linestyle("-")
    ax.add_collection(branch_lc)

    # Add circles for tips and internal nodes.
    tip_circle_sizes = np.array(tip_circle_sizes)
    ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes, facecolor=tip_circles_color, edgecolor='none',zorder=11) ## plot circle for every tip
    ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes*2, facecolor='k', edgecolor='none', zorder=10) ## plot black circle underneath
    ax.scatter(node_circles_x, node_circles_y, facecolor=node_circles_color, s=50, edgecolor='none', zorder=10, lw=2, marker='|') ## mark every node in the tree to highlight that it's a multitype tree

    #ax.set_ylim(-10, y_span - 300)

    ax.spines['top'].set_visible(False) ## no axes
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)

    ax.grid(axis='x',ls='-',color='grey')
    ax.tick_params(axis='y',size=0)
    ax.set_yticklabels([])

    if start_date:
        # Always add a buffer to the left edge of the plot so data from the
        # given start date onward can be clearly seen.
        ax.set_xlim(left=timestamp_to_float(pd.to_datetime(start_date)) - 2.0)

    if end_date:
        # Always add a buffer of 3 months to the right edge of the plot so data
        # up to the given end date can be clearly seen.
        ax.set_xlim(right=timestamp_to_float(pd.to_datetime(end_date)) + 0.25)

    if include_color_bar:
        cb1 = mpl.colorbar.ColorbarBase(
            colorbar_ax,
            cmap=cmap,
            norm=norm,
            orientation='horizontal'
        )
        cb1.set_label(color_by_trait)

    gs.tight_layout(fig)
    plt.savefig(figure_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("tree", help="auspice tree JSON or Newick tree")
    parser.add_argument("output", help="plotted tree figure")
    parser.add_argument("--colorby", help="trait in tree to color by", default="num_date")
    parser.add_argument("--branch_width", help="initial branch width", type=int, default=10)
    parser.add_argument("--tip_size", help="tip size", type=int, default=10)
    parser.add_argument("--start-date", help="earliest date to show on the x-axis")
    parser.add_argument("--end-date", help="latest date to show on the x-axis")
    parser.add_argument("--include-color-bar", action="store_true", help="display a color bar for the color by option at the bottom of the plot")
    args = parser.parse_args()

    if args.tree.endswith(".json"):
        with open(args.tree, "r") as json_fh:
            json_tree = json.load(json_fh)

        # Convert JSON tree layout to a Biopython Clade instance.
        tree = json_to_tree(json_tree)

        # Plot the tree.
        plot_tree(
            tree,
            args.output,
            args.colorby,
            args.branch_width,
            args.tip_size,
            args.start_date,
            args.end_date,
            args.include_color_bar
        )
    else:
        tree = read_tree(args.tree)
        tree.ladderize()
        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        Bio.Phylo.draw(tree, axes=ax, label_func=lambda node: "", show_confidence=False)
        plt.savefig(args.output)
import argparse
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr
import seaborn as sns
import statsmodels.api as sm


np.random.seed(314159)

PLOT_THEME_ATTRIBUTES = {
    "axes.labelsize": 14,
    "font.size": 18,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "figure.figsize": [6.0, 4.0],
    "savefig.dpi": 200,
    "figure.dpi": 200,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "text.usetex": False
}


def matthews_correlation_coefficient(tp, tn, fp, fn):
    """Return Matthews correlation coefficient for values from a confusion matrix.
    Implementation is based on the definition from wikipedia:

    https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
    """
    numerator = (tp * tn) - (fp * fn)
    denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denominator == 0:
        denominator = 1

    return float(numerator) / denominator


def get_matthews_correlation_coefficient_for_data_frame(freq_df, return_confusion_matrix=False):
    """Calculate the Matthews correlation coefficient from a given pandas data frame
    with columns for initial, observed, and predicted frequencies.
    """
    observed_growth = (freq_df["frequency_final"] > freq_df["frequency"])
    predicted_growth = (freq_df["projected_frequency"] > freq_df["frequency"])
    true_positives = ((observed_growth) & (predicted_growth)).sum()
    false_positives = ((~observed_growth) & (predicted_growth)).sum()

    observed_decline = (freq_df["frequency_final"] < freq_df["frequency"])
    predicted_decline = (freq_df["projected_frequency"] < freq_df["frequency"])
    true_negatives = ((observed_decline) & (predicted_decline)).sum()
    false_negatives = ((~observed_decline) & (predicted_decline)).sum()

    mcc = matthews_correlation_coefficient(
        true_positives,
        true_negatives,
        false_positives,
        false_negatives
    )

    if return_confusion_matrix:
        confusion_matrix = {
            "tp": true_positives,
            "tn": true_negatives,
            "fp": false_positives,
            "fn": false_negatives
        }

        return mcc, confusion_matrix
    else:
        return mcc


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors and weighted distances to the future")
    parser.add_argument("--tips-to-clades", required=True, help="tab-delimited file of all clades per tip and timepoint from a single tree that includes all tips in the given tip attributes table")
    parser.add_argument("--forecasts", required=True, help="table of forecasts for the given tips")
    parser.add_argument("--model-errors", required=True, help="annotated validation errors for the model used to make the given forecasts")
    parser.add_argument("--bootstrap-samples", type=int, default=100, help="number of bootstrap samples to generate for confidence intervals around absolute forecast errors")
    parser.add_argument("--population", help="the population being analyzed (e.g., simulated or natural)")
    parser.add_argument("--sample", help="sample name for population being analyzed")
    parser.add_argument("--predictors", help="predictors being analyzed")
    parser.add_argument("--output", required=True, help="validation figure")
    parser.add_argument("--output-clades-table", help="table of clade frequencies used in left panels")
    parser.add_argument("--output-ranks-table", help="table of strain ranks used in right panels")

    args = parser.parse_args()

    # Define constants for frequency analyses below.
    min_clade_frequency = 0.15
    precision = 4
    pseudofrequency = 0.001
    number_of_bootstrap_samples = args.bootstrap_samples

    sns.set_style("white")
    mpl.rcParams.update(PLOT_THEME_ATTRIBUTES)

    # Load validation errors for the model used to produce the given forecasts
    # table. These errors are used to identify the first validation timepoint.
    model_errors = pd.read_csv(
        args.model_errors,
        sep="\t",
        parse_dates=["validation_timepoint"]
    )
    first_validation_timepoint = model_errors["validation_timepoint"].min().strftime("%Y-%m-%d")

    # Load tip attributes to be associated with clades and used to calculate
    # clade frequencies.
    tips = pd.read_csv(
        args.tip_attributes,
        sep="\t",
        parse_dates=["timepoint"],
        usecols=["strain", "timepoint", "frequency", "aa_sequence"]
    )
    tips = tips.query("timepoint >= '%s'" % first_validation_timepoint).copy()
    distinct_tips_with_sequence = tips.groupby(["timepoint", "aa_sequence"]).first().reset_index()

    # Load mapping of tips to clades based on a single tree that included all of
    # the tips in the given tip attributes table.
    tips_to_clades = pd.read_csv(
        args.tips_to_clades,
        sep="\t",
        usecols=["tip", "clade_membership", "depth"]
    )
    tips_to_clades = tips_to_clades.rename(columns={"tip": "strain"})

    # Load forecasts for all tips by the model associated with the given model
    # errors. First, load only a subset of the forecast information to simplify
    # downstream data frames.
    forecasts = pd.read_csv(
        args.forecasts,
        sep="\t",
        parse_dates=["timepoint"],
        usecols=["timepoint", "strain", "frequency", "projected_frequency"]
    )

    # Next, load the complete forecasts data frame for ranking of estimated and
    # observed closest strains.
    full_forecasts = pd.read_csv(
        args.forecasts,
        sep="\t",
        parse_dates=["timepoint", "future_timepoint"]
    )
    full_forecasts = full_forecasts.query("timepoint >= '%s'" % first_validation_timepoint).copy()

    # Map tip attributes to all corresponding clades.
    clade_tip_initial_frequencies = tips_to_clades.merge(
        tips,
        on=["strain"]
    )
    clade_tip_initial_frequencies["future_timepoint"] = clade_tip_initial_frequencies["timepoint"] + pd.DateOffset(months=12)

    # Calculate the initial frequency of each clade per timepoint.
    initial_clade_frequencies = clade_tip_initial_frequencies.groupby([
        "timepoint", "future_timepoint", "clade_membership"
    ])["frequency"].sum().reset_index()

    # Merge clade frequencies between adjacent years.
    initial_and_observed_clade_frequencies = initial_clade_frequencies.merge(
        initial_clade_frequencies,
        left_on=["future_timepoint", "clade_membership"],
        right_on=["timepoint", "clade_membership"],
        suffixes=["", "_final"]
    ).groupby(["timepoint", "clade_membership", "frequency"])["frequency_final"].sum().reset_index()

    # Select clades with an initial frequency above the defined threshold.
    large_clades = initial_and_observed_clade_frequencies.query("frequency > %s" % min_clade_frequency).copy()

    # Find estimated future frequencies of large clades.
    clade_tip_estimated_frequencies = tips_to_clades.merge(
        forecasts,
        on=["strain"]
    )
    estimated_clade_frequencies = clade_tip_estimated_frequencies.groupby(
        ["timepoint", "clade_membership"]
    ).aggregate({"projected_frequency": "sum"}).reset_index()

    # Annotate initial and observed clade frequencies with the estimated future
    # values.
    complete_clade_frequencies = large_clades.merge(
        estimated_clade_frequencies,
        on=["timepoint", "clade_membership"],
        suffixes=["", "_other"]
    )

    # Reduce precision of frequency estimates to a reasonable value and
    # eliminate entries where the clade frequency did not change between the
    # initial and final timepoints (these are primarily clades that have already
    # fixed at 100%).
    complete_clade_frequencies = np.round(complete_clade_frequencies, 2)
    complete_clade_frequencies = complete_clade_frequencies.query("frequency != frequency_final").copy()

    # Calculate accuracy of growth and decline classifications.
    mcc, confusion_matrix = get_matthews_correlation_coefficient_for_data_frame(complete_clade_frequencies, True)
    growth_accuracy = confusion_matrix["tp"] / float(confusion_matrix["tp"] + confusion_matrix["fp"])
    decline_accuracy = confusion_matrix["tn"] / float(confusion_matrix["tn"] + confusion_matrix["fn"])

    # Calculate the observed and estimated log growth rates for all clades.
    complete_clade_frequencies["log_observed_growth_rate"] = (
        np.log10((complete_clade_frequencies["frequency_final"] + pseudofrequency) / (complete_clade_frequencies["frequency"] + pseudofrequency))
    )
    complete_clade_frequencies["log_estimated_growth_rate"] = (
        np.log10((complete_clade_frequencies["projected_frequency"] + pseudofrequency) / (complete_clade_frequencies["frequency"] + pseudofrequency))
    )

    # Calculate the bounds for the clade growth rate display based on values in
    # observed and estimated rates.
    log_lower_limit = complete_clade_frequencies.loc[:, ["log_observed_growth_rate", "log_estimated_growth_rate"]].min().min() - 0.1
    log_upper_limit = np.ceil(complete_clade_frequencies.loc[:, ["log_observed_growth_rate", "log_estimated_growth_rate"]].max().max()) + 0.1

    # Calculate the Pearson's correlation between observed and estimated log
    # growth rates.
    r, p = pearsonr(
        complete_clade_frequencies["log_observed_growth_rate"],
        complete_clade_frequencies["log_estimated_growth_rate"]
    )

    # Use observed forecasting errors to inspect the accuracy of one-year
    # lookaheads based on the initial frequency of each clade.
    complete_clade_frequencies["clade_error"] = complete_clade_frequencies["frequency_final"] - complete_clade_frequencies["projected_frequency"]
    complete_clade_frequencies["absolute_clade_error"] = np.abs(complete_clade_frequencies["clade_error"])

    # Estimate uncertainty of the mean absolute clade error by initial clade
    # frequency with LOESS fits to bootstraps from the complete data frame.
    bootstrap_samples = []
    for i in range(number_of_bootstrap_samples):
        complete_clade_frequencies_sample = complete_clade_frequencies.sample(frac=1.0, replace=True).copy()
        z = sm.nonparametric.lowess(
            complete_clade_frequencies_sample["absolute_clade_error"].values * 100,
            complete_clade_frequencies_sample["frequency"].values * 100
        )

        # Track both the initial frequency and the LOESS fits for each bootstrap
        # sample. This ensures that the summary statistics calculated downstream
        # per initial frequency are based on the correct LOESS values.
        bootstrap_samples.append(
            pd.DataFrame({
                "initial_frequency": z[:, 0],
                "loess": z[:, 1]}
            )
        )

    bootstrap_df = pd.concat(bootstrap_samples)

    # Calculate the mean and 95% CIs from bootstraps.
    bootstrap_summary = bootstrap_df.groupby("initial_frequency")["loess"].agg(
        lower=lambda group: np.percentile(group, 2.5),
        mean=np.mean,
        upper=lambda group: np.percentile(group, 97.5)
    ).reset_index()

    initial_frequency = bootstrap_summary["initial_frequency"].values
    mean_lowess_fit = bootstrap_summary["mean"].values
    upper_lowess_fit = bootstrap_summary["upper"].values
    lower_lowess_fit = bootstrap_summary["lower"].values

    # For each timepoint, calculate the percentile rank of each strain based on
    # both its observed and estimated distance to the future.
    sorted_df = full_forecasts.dropna().sort_values(
        ["timepoint"]
    ).copy()

    # Filter sorted records by strains with distinct amino acid sequences.
    sorted_df = sorted_df.merge(
        distinct_tips_with_sequence,
        on=["timepoint", "strain"]
    )

    # First, calculate the rank per strain by observed distance to the future.
    sorted_df["timepoint_rank"] = sorted_df.groupby("timepoint")["weighted_distance_to_future"].rank(pct=True)

    # Then, calculate the rank by estimated distance to the future.
    sorted_df["timepoint_estimated_rank"] = sorted_df.groupby("timepoint")["y"].rank(pct=True)

    # Calculate the Spearman correlation of ranks, to get a measure of the model
    # fit.
    rank_rho, rank_p = spearmanr(
        sorted_df["timepoint_rank"],
        sorted_df["timepoint_estimated_rank"]
    )

    # Select the observed rank of the estimated closest strain to the future per
    # timepoint.
    best_fitness_rank_by_timepoint_df = sorted_df.sort_values(
        ["timepoint", "y"],
        ascending=True
    ).groupby("timepoint")["timepoint_rank"].first().reset_index()

    #
    # Summarize model fit by clade frequencies and strain ranks.
    #

    fig = plt.figure(figsize=(10, 10), facecolor='w')
    gs = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1], wspace=0.1)

    ticks = np.array([0, 0.2, 0.4, 0.6, 0.8, 1.0])

    #
    # Top-left: Clade growth rate correlations
    #

    clade_ax = fig.add_subplot(gs[0])
    clade_ax.plot(
        complete_clade_frequencies["log_observed_growth_rate"],
        complete_clade_frequencies["log_estimated_growth_rate"],
        "o",
        alpha=0.4
    )

    clade_ax.axhline(color="#cccccc", zorder=-5)
    clade_ax.axvline(color="#cccccc", zorder=-5)

    if p < 0.001:
        p_value = "$p value$ < 0.001"
    else:
        p_value = "$p$ = %.3f" % p

    clade_ax.text(
        0.02,
        0.15,
        "Growth accuracy = %.2f\nDecline accuracy = %.2f\nPearson $R^2$ = %.2f\nN = %s" % (
            growth_accuracy,
            decline_accuracy,
            r ** 2,
            complete_clade_frequencies.shape[0]
        ),
        fontsize=12,
        horizontalalignment="left",
        verticalalignment="center",
        transform=clade_ax.transAxes
    )

    clade_ax.set_xlabel("Observed $log_{10}$ fold change")
    clade_ax.set_ylabel("Estimated $log_{10}$ fold change")

    growth_rate_ticks = np.arange(-6, 4, 1)
    clade_ax.set_xticks(growth_rate_ticks)
    clade_ax.set_yticks(growth_rate_ticks)

    clade_ax.set_xlim(log_lower_limit, log_upper_limit)
    clade_ax.set_ylim(log_lower_limit, log_upper_limit)
    clade_ax.set_aspect("equal")

    #
    # Top-right: Estimated closest strain to the future ranking
    #

    rank_ax = fig.add_subplot(gs[1])

    median_best_rank = best_fitness_rank_by_timepoint_df["timepoint_rank"].median()

    rank_ax.hist(best_fitness_rank_by_timepoint_df["timepoint_rank"], bins=np.arange(0, 1.01, 0.05), label=None)
    rank_ax.axvline(
        median_best_rank,
        color="orange",
        label="median = %i%%" % round(median_best_rank * 100, 0)
    )
    rank_ax.set_xticks(ticks)
    rank_ax.set_xticklabels(['{:3.0f}%'.format(x*100) for x in ticks])
    rank_ax.set_xlim(0, 1)

    rank_ax.legend(
        frameon=False
    )
    rank_ax.set_xlabel("Percentile rank by distance\nfor estimated closest strain")
    rank_ax.set_ylabel("Number of timepoints")

    #
    # Bottom-left: Absolute clade forecast errors with uncertainty.
    #

    forecast_error_ax = fig.add_subplot(gs[2])
    forecast_error_ax.plot(
        complete_clade_frequencies["frequency"].values * 100,
        complete_clade_frequencies["absolute_clade_error"].values * 100,
        "o",
        alpha=0.2
    )

    forecast_error_ax.fill_between(
        initial_frequency,
        lower_lowess_fit,
        upper_lowess_fit,
        alpha=0.1,
        color="black"
    )
    forecast_error_ax.plot(
        initial_frequency,
        mean_lowess_fit,
        alpha=0.75,
        color="black"
    )

    forecast_error_ax.set_xlabel("Initial clade frequency")
    forecast_error_ax.set_ylabel("Absolute forecast error")

    forecast_error_ax.set_xticks(ticks * 100)
    forecast_error_ax.set_yticks(ticks * 100)
    forecast_error_ax.set_xticklabels(['{:3.0f}%'.format(x * 100) for x in ticks])
    forecast_error_ax.set_yticklabels(['{:3.0f}%'.format(x * 100) for x in ticks])

    forecast_error_ax.set_aspect("equal")

    #
    # Bottom-right: Observed vs. estimated percentile rank for all strains at all timepoints.
    #

    all_rank_ax = fig.add_subplot(gs[3])

    if rank_p < 0.001:
        rank_p_value = "$p$ < 0.001"
    else:
        rank_p_value = "$p$ = %.3f" % rank_p

    all_rank_ax.plot(
        sorted_df["timepoint_rank"],
        sorted_df["timepoint_estimated_rank"],
        "o",
        alpha=0.05
    )

    all_rank_ax.text(
        0.45,
        0.05,
        "Spearman $\\rho^2$ = %.2f" % (rank_rho ** 2,),
        fontsize=12,
        horizontalalignment="left",
        verticalalignment="center",
        transform=all_rank_ax.transAxes
    )

    all_rank_ax.set_xticks(ticks)
    all_rank_ax.set_yticks(ticks)
    all_rank_ax.set_xticklabels(['{:3.0f}%'.format(x * 100) for x in ticks])
    all_rank_ax.set_yticklabels(['{:3.0f}%'.format(x * 100) for x in ticks])

    all_rank_ax.set_xlabel("Observed percentile rank")
    all_rank_ax.set_ylabel("Estimated percentile rank")
    all_rank_ax.set_aspect("equal")

    # Annotate panel labels.
    panel_labels_dict = {
        "weight": "bold",
        "size": 14
    }
    plt.figtext(0.0, 0.97, "A", **panel_labels_dict)
    plt.figtext(0.5, 0.97, "B", **panel_labels_dict)
    plt.figtext(0.0, 0.47, "C", **panel_labels_dict)
    plt.figtext(0.5, 0.47, "D", **panel_labels_dict)

    gs.tight_layout(fig)
    plt.savefig(args.output)

    timepoints_better_than_20th_percentile = (best_fitness_rank_by_timepoint_df["timepoint_rank"] <= 0.2).sum()
    total_timepoints = best_fitness_rank_by_timepoint_df.shape[0]
    print(
        "Estimated strain was in the top 20th percentile at %s of %s (%s%%) timepoints" % (
            timepoints_better_than_20th_percentile,
            total_timepoints,
            int(np.round((timepoints_better_than_20th_percentile / float(total_timepoints)) * 100))
        )
    )

    if args.output_clades_table:
        complete_clade_frequencies = complete_clade_frequencies.rename(columns={
            "frequency": "initial_frequency",
            "frequency_final": "observed_future_frequency",
            "projected_frequency": "estimated_future_frequency"
        })
        complete_clade_frequencies["population"] = args.population
        complete_clade_frequencies["predictors"] = args.predictors
        complete_clade_frequencies["error_type"] = "test" if "test" in args.sample else "validation"

        complete_clade_frequencies.to_csv(
            args.output_clades_table,
            sep="\t",
            header=True,
            index=False
        )

    if args.output_ranks_table:
        sorted_df["observed_distance_to_future"] = sorted_df["weighted_distance_to_future"]
        sorted_df["estimated_distance_to_future"] = sorted_df["y"]
        sorted_df["observed_rank"] = sorted_df["timepoint_rank"]
        sorted_df["estimated_rank"] = sorted_df["timepoint_estimated_rank"]

        sorted_df["population"] = args.population
        sorted_df["sample"] = args.sample
        sorted_df["predictors"] = args.predictors
        sorted_df["error_type"] = "test" if "test" in args.sample else "validation"
        sorted_df = np.around(sorted_df, 2)

        sorted_df.to_csv(
            args.output_ranks_table,
            sep="\t",
            header=True,
            index=False,
            columns=[
                "population",
                "error_type",
                "predictors",
                "timepoint",
                "strain",
                "observed_distance_to_future",
                "estimated_distance_to_future",
                "observed_rank",
                "estimated_rank"
            ]
        )
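
To make the growth and decline classification concrete, the following self-contained sketch applies the same confusion-matrix logic and Matthews correlation coefficient to a toy clade table. All frequencies are made up for illustration.

import numpy as np
import pandas as pd

# Toy clade table with initial, observed future, and projected frequencies.
clades = pd.DataFrame({
    "frequency":           [0.20, 0.30, 0.25, 0.25],
    "frequency_final":     [0.40, 0.10, 0.30, 0.20],
    "projected_frequency": [0.35, 0.15, 0.20, 0.22],
})

# Classify growth and decline exactly as in the script above.
observed_growth = clades["frequency_final"] > clades["frequency"]
predicted_growth = clades["projected_frequency"] > clades["frequency"]
tp = (observed_growth & predicted_growth).sum()
fp = (~observed_growth & predicted_growth).sum()

observed_decline = clades["frequency_final"] < clades["frequency"]
predicted_decline = clades["projected_frequency"] < clades["frequency"]
tn = (observed_decline & predicted_decline).sum()
fn = (~observed_decline & predicted_decline).sum()

# Matthews correlation coefficient from the confusion matrix.
numerator = (tp * tn) - (fp * fn)
denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
if denominator == 0:
    denominator = 1

print(tp, tn, fp, fn, round(float(numerator) / denominator, 2))  # 1 2 0 1 0.58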
import argparse
import json


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--titers-model", required=True, help="titer model JSON from augur titers tree")
    parser.add_argument("--output", required=True, help="titer model JSON with renamed fields for FRA data")

    args = parser.parse_args()

    with open(args.titers_model, "r") as fh:
        titers_json = json.load(fh)

    for sample in titers_json["nodes"].keys():
        titers_json["nodes"][sample]["fra_cTiter"] = titers_json["nodes"][sample]["cTiter"]
        titers_json["nodes"][sample]["fra_dTiter"] = titers_json["nodes"][sample]["dTiter"]
        del titers_json["nodes"][sample]["cTiter"]
        del titers_json["nodes"][sample]["dTiter"]

    with open(args.output, "w") as oh:
        json.dump(titers_json, oh, indent=1)
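
A toy example of the field renaming performed by this script; the strain name and titer values are made up, and dict.pop stands in for the assign-then-delete pattern above.

# Rename HI-style titer fields to FRA-specific fields so both titer models can
# be attached to the same tree without overwriting each other.
nodes = {
    "A/ExampleStrain/2018": {"cTiter": 2.1, "dTiter": 0.3}
}

for sample, values in nodes.items():
    values["fra_cTiter"] = values.pop("cTiter")
    values["fra_dTiter"] = values.pop("dTiter")

print(nodes)
# {'A/ExampleStrain/2018': {'fra_cTiter': 2.1, 'fra_dTiter': 0.3}}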
import argparse
import numpy as np
import pandas as pd


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Standardize predictors",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints")
    parser.add_argument("--tips-to-clades", required=True, help="tab-delimited file of all clades per tip and timepoint")
    parser.add_argument("--delta-months", required=True, type=int, help="number of months to project clade frequencies into the future")
    parser.add_argument("--output", required=True, help="tab-delimited file of clades per timepoint and their corresponding tips and tip frequencies at the given delta time in the future")
    args = parser.parse_args()

    delta_time_offset = pd.DateOffset(months=args.delta_months)

    # Load tip attributes, subsetting to relevant frequency and time information.
    tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"])
    tips = tips.loc[:, ["strain", "clade_membership", "timepoint", "frequency"]].copy()

    # Confirm tip frequencies sum to 1 per timepoint.
    summed_tip_frequencies = tips.groupby("timepoint")["frequency"].sum()
    print(summed_tip_frequencies)
    assert all([
        np.isclose(total, 1.0, atol=1e-3)
        for total in summed_tip_frequencies
    ])

    # Identify distinct clades per timepoint.
    clades = tips.loc[:, ["timepoint", "clade_membership"]].drop_duplicates().copy()
    clades = clades.rename(columns={"timepoint": "initial_timepoint"})

    # Annotate future timepoint.
    clades["final_timepoint"] = clades["initial_timepoint"] + delta_time_offset

    # Load mapping of tips to all possible clades at each timepoint.
    tips_to_clades = pd.read_csv(args.tips_to_clades, sep="\t", parse_dates=["timepoint"])
    tips_to_clades = tips_to_clades.loc[:, ["tip", "clade_membership", "depth", "timepoint"]].copy()

    # Get all tip-clade combinations by timepoint for the distinct clades.
    future_tips_by_clades = clades.merge(
        tips_to_clades,
        how="inner",
        left_on=["final_timepoint", "clade_membership"],
        right_on=["timepoint", "clade_membership"]
    )

    # Drop redundant columns.
    future_tips_by_clades = future_tips_by_clades.drop(
        columns=["timepoint"]
    )

    # Get the closest clade to each tip by timepoint. This relies on records
    # being sorted by depth of clade from tip.
    future_tips_by_clades = future_tips_by_clades.sort_values(["initial_timepoint", "tip", "depth"]).groupby(["initial_timepoint", "tip"]).first().reset_index()

    # Get frequencies of future tips associated with current clades.
    future_clade_frequencies = future_tips_by_clades.merge(tips, how="inner", left_on=["tip", "final_timepoint"], right_on=["strain", "timepoint"], suffixes=["", "_tip"])
    future_clade_frequencies = future_clade_frequencies.drop(
        columns=[
            "tip",
            "depth",
            "clade_membership_tip",
            "timepoint"
        ]
    )

    # Confirm that future frequencies sum to 1.
    print(future_clade_frequencies.groupby("initial_timepoint")["frequency"].sum())

    # Confirm the future frequencies of individual clades.
    print(future_clade_frequencies.groupby(["initial_timepoint", "clade_membership"])["frequency"].sum())

    # Left join original clades table with the future tip frequencies to enable
    # assessment of all current clades including those without future tips.
    final_clade_frequencies = clades.merge(
        future_clade_frequencies,
        how="left",
        on=["initial_timepoint", "final_timepoint", "clade_membership"]
    )

    # Fill frequency of clades without any future tips with zeros to enable a
    # simple groupby in the future to get observed future frequencies of all
    # clades.
    final_clade_frequencies["frequency"] = final_clade_frequencies["frequency"].fillna(0.0)

    # Confirm that future frequencies sum to 1.
    print(final_clade_frequencies.groupby("initial_timepoint")["frequency"].sum())

    # Save clade future tip frequencies by timepoint.
    final_clade_frequencies.to_csv(args.output, sep="\t", na_rep="N/A", index=False)
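
The core of this script is a date-offset merge: clades observed at one timepoint are joined to the tips circulating the requested number of months later, and those tips' frequencies are summed per clade. A minimal sketch with made-up clades and frequencies:

import pandas as pd

delta = pd.DateOffset(months=12)

# Clades present at the initial timepoint.
clades = pd.DataFrame({
    "initial_timepoint": pd.to_datetime(["2015-10-01", "2015-10-01"]),
    "clade_membership": ["A1", "A2"],
})
clades["final_timepoint"] = clades["initial_timepoint"] + delta

# Tips circulating at the future timepoint, annotated with their clades.
future_tips = pd.DataFrame({
    "timepoint": pd.to_datetime(["2016-10-01"] * 3),
    "clade_membership": ["A1", "A1", "A2"],
    "frequency": [0.5, 0.3, 0.2],
})

# Join current clades to future tips and sum tip frequencies per clade,
# filling clades without future tips with zero frequency.
observed = clades.merge(
    future_tips,
    how="left",
    left_on=["final_timepoint", "clade_membership"],
    right_on=["timepoint", "clade_membership"],
)
observed["frequency"] = observed["frequency"].fillna(0.0)
print(observed.groupby(["initial_timepoint", "clade_membership"])["frequency"].sum())
# A1 -> 0.8, A2 -> 0.2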
import argparse, sys, os
from augur.utils import read_metadata, get_numerical_dates
import Bio
import Bio.SeqIO
from collections import defaultdict
from datetime import datetime, timedelta, date
import numpy as np
from treetime.utils import numeric_date


vpm_dict = {
    2: 3,
    3: 2,
    6: 2,
    12: 1,
}

regions = [
    ('africa',            "",   1.02),
    ('europe',            "EU", 0.74),
    ('north_america',     "NA", 0.54),
    ('china',             "AS", 1.36),
    ('south_asia',        "AS", 1.45),
    ('japan_korea',       "AS", 0.20),
    ('oceania',           "OC", 0.04),
    ('south_america',     "SA", 0.41),
    ('southeast_asia',    "AS", 0.62),
    ('west_asia',         "AS", 0.75)
]

subcats = [r[0] for r in regions]

def read_strain_list(fname):
    """
    read strain names from a file assuming there is one strain name per line

    Parameters:
    -----------
    fname : str
        file name

    Returns:
    --------
    strain_list : list
        strain names

    """
    if os.path.isfile(fname):
        with open(fname, 'r') as fh:
            strain_list = [x.strip() for x in fh.readlines() if x[0]!='#']
    else:
        print("ERROR: file %s containing strain list not found"%fname)
        sys.exit(1)

    return strain_list


def count_titer_measurements(fname):
    """
    read how many titer measurements exist for each virus

    Parameters:
    -----------
    fname : str
        file name

    Returns:
    --------
    titer_count : defaultdict(int)
        dictionary with titer count for each strain
    """
    titer_count = defaultdict(int)
    if os.path.isfile(fname):
        with open(fname, 'r') as fh:
            for line in fh:
                titer_count[line.split()[0]] += 1
    else:
        print("ERROR: file %s containing strain list not found"%fname)
        sys.exit(1)

    return titer_count


def populate_categories(metadata):
    super_category = lambda x: (x['year'],
                                x['month'])

    category = lambda x: (x['region'],
                          x['year'],
                          x['month'])

    virus_by_category = defaultdict(list)
    virus_by_super_category = defaultdict(list)
    for v in metadata:
        virus_by_category[category(metadata[v])].append(v)
        virus_by_super_category[super_category(metadata[v])].append(v)

    return virus_by_super_category, virus_by_category


def flu_subsampling(metadata, viruses_per_month, time_interval, titer_fname=None):
    # Filter metadata by date using the given time interval. Using numeric dates
    # here allows users to define time intervals to the day and filter viruses
    # at that same level of precision.
    time_interval_start = round(numeric_date(time_interval[1]), 2)
    time_interval_end = round(numeric_date(time_interval[0]), 2)
    metadata = {
        strain: record
        for strain, record in metadata.items()
        if time_interval_start <= record["num_date"] <= time_interval_end
    }

    #### DEFINE THE PRIORITY
    if titer_fname:
        HI_titer_count = count_titer_measurements(titer_fname)
        def priority(strain):
            return HI_titer_count[strain]
    else:
        print("No titer counts provided - using random priorities")
        def priority(strain):
            return np.random.random()

    subcat_threshold = int(np.ceil(1.0*viruses_per_month/len(subcats)))

    virus_by_super_category, virus_by_category = populate_categories(metadata)
    def threshold_fn(x):
        #x is the subsampling category, in this case a tuple of (region, year, month)

        # if there are not enough viruses by super category, take everything
        if len(virus_by_super_category[x[1:]]) < viruses_per_month:
            return viruses_per_month

        # otherwise, sort sub categories by strain count
        sub_counts = sorted([(r, virus_by_category[(r, x[1], x[2])]) for r in subcats],
                            key=lambda y:len(y[1]))

        # if all (the smallest) subcat has more strains than the threshold, return threshold
        if len(sub_counts[0][1]) > subcat_threshold:
            return subcat_threshold


        strains_selected = 0
        tmp_subcat_threshold = subcat_threshold
        for ri, (r, strains) in enumerate(sub_counts):
            current_threshold = int(np.ceil(1.0*(viruses_per_month-strains_selected)/(len(subcats)-ri)))
            if r==x[0]:
                return current_threshold
            else:
                strains_selected += min(len(strains), current_threshold)
        return subcat_threshold
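
    # Worked example of the threshold logic above (numbers are illustrative):
    # with viruses_per_month = 15 and 10 regions, subcat_threshold = ceil(15 / 10) = 2.
    # If every region has more than 2 strains for a given month, each region
    # contributes 2 strains. When some regions have fewer strains than that,
    # the loop above redistributes their unused quota: regions are visited from
    # smallest to largest, and each remaining region's threshold is recomputed
    # as ceil((viruses_per_month - strains_selected) / regions_remaining), so
    # well-sampled regions absorb the slack left by sparsely sampled ones.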

    selected_strains = []
    for cat, val in virus_by_category.items():
        val.sort(key=priority, reverse=True)
        selected_strains.extend(val[:threshold_fn(cat)])

    return selected_strains


def determine_time_interval(time_interval, resolution):
    # determine date range to include strains from
    if time_interval: # explicitly specified
        datetime_interval = sorted([datetime.strptime(x, '%Y-%m-%d').date() for x in time_interval], reverse=True)
    else: # derived from resolution arguments (explicit takes precedence)
        if resolution:
            years_back = int(resolution[:-1])
        else:
            years_back = 3
        datetime_interval = [datetime.today().date(), (datetime.today()  - timedelta(days=365.25 * years_back)).date()]
    return datetime_interval

def parse_metadata(segments, metadata_files):
    metadata = {}
    for segment, fname in zip(segments, metadata_files):
        tmp_meta, columns = read_metadata(fname)

        numerical_dates = get_numerical_dates(tmp_meta, fmt='%Y-%m-%d')
        for x in tmp_meta:
            tmp_meta[x]['num_date'] = np.mean(numerical_dates[x])
            tmp_meta[x]['year'] = int(tmp_meta[x]['num_date'])

            # Extract month values starting at January == 1 for comparison with
            # datetime objects.
            tmp_meta[x]['month'] = int((tmp_meta[x]['num_date'] % 1) * 12) + 1
        metadata[segment] = tmp_meta
    return metadata

def parse_sequences(segments, sequence_files):
    """Load sequence names into a dictionary of sets indexed by segment.
    """
    sequences = {}
    for segment, filename in zip(segments, sequence_files):
        sequence_set = Bio.SeqIO.parse(filename, "fasta")
        sequences[segment] = set()
        for seq in sequence_set:
            sequences[segment].add(seq.name)

    return sequences


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Select strains for downstream analysis",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('-v', '--viruses_per_month', type = int, default=15,
                        help='Subsample this many viruses per region per month. Set to 0 to disable subsampling.')
    parser.add_argument('--sequences', nargs='+', help="FASTA file with viral sequences, one for each segment")
    parser.add_argument('--metadata', nargs='+', help="file with metadata associated with viral sequences, one for each segment")
    parser.add_argument('--output', help="name of the file to write selected strains to")
    parser.add_argument('--verbose', action="store_true", help="turn on verbose reporting")

    parser.add_argument('-l', '--lineage', choices=['h3n2', 'h1n1pdm', 'vic', 'yam'], default='h3n2', type=str, help="single lineage to include (default: h3n2)")
    parser.add_argument('-r', '--resolution',default='3y', type = str,  help = "single resolution to include (default: 3y)")
    parser.add_argument('-s', '--segments', default=['ha'], nargs='+', type = str,  help = "list of segments to include (default: ha)")
    parser.add_argument('--sampling', default = 'even', type=str,
                        help='sample evenly over regions (even) (default), or prioritize one region (region name), otherwise sample randomly')
    parser.add_argument('--time-interval', nargs=2, help="explicit time interval to use -- overrides resolutions"
                                                                     " expects YYYY-MM-DD YYYY-MM-DD")
    parser.add_argument('--titers', help="a text file of titers. This is only used to count how many titer measurements are available for each virus"
                                          " and to use this count as a priority for inclusion during subsampling.")
    parser.add_argument('--include', help="a text file containing strains (one per line) that will be included regardless of subsampling")
    parser.add_argument('--max-include-range', type=float, default=5, help="number of years prior to the lower date limit for reference strain inclusion")
    parser.add_argument('--exclude', help="a text file containing strains (one per line) that will be excluded")

    args = parser.parse_args()
    time_interval = determine_time_interval(args.time_interval, args.resolution)

    # derive additional lower inclusion date for "force-included strains"
    lower_reference_cutoff = date(year = time_interval[1].year - args.max_include_range, month=1, day=1)
    upper_reference_cutoff = time_interval[0]

    # read strains to exclude
    excluded_strains = read_strain_list(args.exclude) if args.exclude else []
    # read strains to include
    included_strains = read_strain_list(args.include) if args.include else []

    # read in sequence names to determine which sequences already passed upstream filters
    sequence_names_by_segment = parse_sequences(args.segments, args.sequences)

    # read in meta data, parse numeric dates
    metadata = parse_metadata(args.segments, args.metadata)

    # eliminate all metadata entries that do not have sequences
    filtered_metadata = {}
    for segment in metadata:
        filtered_metadata[segment] = {}
        for name in metadata[segment]:
            if name in sequence_names_by_segment[segment]:
                filtered_metadata[segment][name] = metadata[segment][name]

    # filter down to strains with sequences for all required segments
    guide_segment = args.segments[0]
    strains_with_all_segments = set.intersection(*(set(filtered_metadata[x].keys()) for x in args.segments))
    # exclude outlier strains
    strains_with_all_segments.difference_update(set(excluded_strains))
    # subsample by region, month, year
    selected_strains = flu_subsampling({x:filtered_metadata[guide_segment][x] for x in strains_with_all_segments},
                                  args.viruses_per_month, time_interval, titer_fname=args.titers)

    # add strains that need to be included
    for strain in included_strains:
        if strain in strains_with_all_segments and strain not in selected_strains:
            # Do not include strains sampled too far in the past or strains
            # sampled from the future relative to the requested build interval.
            if (filtered_metadata[guide_segment][strain]['year'] >= lower_reference_cutoff.year and
                filtered_metadata[guide_segment][strain]['num_date'] <= numeric_date(upper_reference_cutoff)):
                selected_strains.append(strain)

    # Confirm that none of the selected strains were sampled outside of the
    # requested interval.
    for strain in selected_strains:
        assert filtered_metadata[guide_segment][strain]['num_date'] <= numeric_date(upper_reference_cutoff)

    # write the list of selected strains to file
    with open(args.output, 'w') as ofile:
        ofile.write('\n'.join(selected_strains))
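
Within each (region, year, month) category, strains are sorted by priority (the number of titer measurements when a titer file is given) and the top strains up to that category's threshold are kept. A simplified, self-contained sketch of this selection step, with hypothetical strain names and a fixed threshold in place of threshold_fn:

from collections import defaultdict

# Hypothetical titer counts per strain; strains with more measurements get
# higher priority during subsampling.
titer_count = defaultdict(int, {"strain_a": 5, "strain_b": 0, "strain_c": 2})

# Strains grouped into a single (region, year, month) category.
virus_by_category = {
    ("north_america", 2016, 3): ["strain_b", "strain_a", "strain_c"],
}

threshold = 2  # fixed stand-in for the per-category threshold_fn result

selected_strains = []
for category, strains in virus_by_category.items():
    strains.sort(key=lambda strain: titer_count[strain], reverse=True)
    selected_strains.extend(strains[:threshold])

print(selected_strains)  # ['strain_a', 'strain_c']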
import argparse
import datetime
import pandas as pd


def float_to_datestring(time):
    """Convert a floating point date from TreeTime `numeric_date` to a date string
    """
    # Extract the year and remainder from the floating point date.
    year = int(time)
    remainder = time - year

    # Calculate the day of the year (out of 365 + 0.25 for leap years).
    tm_yday = int(remainder * 365.25)
    if tm_yday == 0:
        tm_yday = 1

    # Construct a date object from the year and day of the year.
    date = datetime.datetime.strptime("%s-%s" % (year, tm_yday), "%Y-%j")

    # Build the date string with zero-padded months and days.
    date_string = "%s-%.2i-%.2i" % (date.year, date.month, date.day)

    return date_string


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata", help="metadata for simulated sequences")
    parser.add_argument("--start-year", default=2000.0, type=float, help="year to start simulated dates from")
    parser.add_argument("--generations-per-year", default=200.0, type=float, help="number of generations to map to a single yeasr")
    parser.add_argument("--output", help="metadata with standardized dates and nonzero fitness records")

    args = parser.parse_args()

    df = pd.read_csv(args.metadata, sep="\t")
    df["num_date"] = args.start_year + (df["generation"] / args.generations_per_year)
    df["date"] = df["num_date"].apply(float_to_datestring)
    df["year"]  = pd.to_datetime(df["date"]).dt.year
    df["month"]  = pd.to_datetime(df["date"]).dt.month

    # Omit records with a fitness of zero.
    df[df["fitness"] > 0].to_csv(args.output, header=True, index=False, sep="\t")
import argparse
import json


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Convert titer substitution model to distance map",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--model", required=True, help="JSON from titer substitution model")
    parser.add_argument("--output", required=True, help="distance map JSON")
    args = parser.parse_args()

    # Load titer model.
    with open(args.model, "r") as fh:
        model = json.load(fh)

    # Prepare a distance map for the model.
    distance_map = {
        "name": "titer_substitution_model",
        "default": 0.0,
        "map": {}
    }

    # Convert values like:
    # "HA1:E173K": 0.4656
    # to distance map format.
    for substitution, weight in model["substitution"].items():
        gene, mutation = substitution.split(":")
        ancestral = mutation[0]
        derived = mutation[-1]
        position = mutation[1:-1]

        if ancestral != "X" and derived != "X":
            if gene not in distance_map["map"]:
                distance_map["map"][gene] = {}

            if position not in distance_map["map"][gene]:
                distance_map["map"][gene][position] = []

            distance_map["map"][gene][position].append({
                "from": ancestral,
                "to": derived,
                "weight": weight
            })

    # Save the distance map.
    with open(args.output, "w") as oh:
        json.dump(distance_map, oh, sort_keys=True, indent=1)
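
For example, the single substitution effect shown in the comment above expands into the nested distance map structure produced by this script like so:

# Expand one titer model substitution effect into the distance map format.
substitution, weight = "HA1:E173K", 0.4656

gene, mutation = substitution.split(":")   # "HA1", "E173K"
ancestral = mutation[0]                    # "E"
derived = mutation[-1]                     # "K"
position = mutation[1:-1]                  # "173"

distance_map = {"name": "titer_substitution_model", "default": 0.0, "map": {}}
distance_map["map"].setdefault(gene, {}).setdefault(position, []).append({
    "from": ancestral,
    "to": derived,
    "weight": weight,
})

print(distance_map)
# {'name': 'titer_substitution_model', 'default': 0.0,
#  'map': {'HA1': {'173': [{'from': 'E', 'to': 'K', 'weight': 0.4656}]}}}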
# Shell command excerpts from master/Snakefile:

# From line 327:
shell: "echo Environment built"

# From lines 341-346:
shell:
    """
    python3 scripts/concatenate_tables.py \
        --tables {input.errors} \
        --output {output.errors}
    """

# From lines 354-359:
shell:
    """
    python3 scripts/concatenate_tables.py \
        --tables {input.coefficients} \
        --output {output.coefficients}
    """

# From lines 373-378:
shell:
    """
    python3 scripts/collect_tables.py \
      --tables {input} \
      --output {output.clades}
    """

# From lines 386-391:
shell:
    """
    python3 scripts/collect_tables.py \
      --tables {input} \
      --output {output.ranks}
    """

# From line 397:
shell: "gs -dBATCH -dNOPAUSE -q -sDEVICE=pdfwrite -sOutputFile={output} {input}"

# From lines 522-528:
shell:
    """
    while read original_name new_name
    do
        ln manuscript/figures/$original_name manuscript/$new_name
    done < {input.figure_names}
    """

# From lines 547-554:
shell:
    """
    cd manuscript
    pdflatex -draftmode {params.title}
    bibtex {params.title}
    pdflatex -draftmode {params.title}
    pdflatex {params.title}
    """
import argparse
from augur.utils import read_node_data
import json

from forecast.fitness_predictors import inverse_cross_immunity_amplitude, cross_immunity_cost


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Calculate cross-immunity",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--frequencies", required=True, help="JSON of frequencies per sample")
    parser.add_argument("--distances", required=True, help="JSON of distances between samples")
    parser.add_argument("--date-annotations", required=True, help="JSON of branch lengths and date annotations from augur refine for samples in the given tree")
    parser.add_argument("--distance-attributes", nargs="+", required=True, help="names of attributes to use from the given distances JSON")
    parser.add_argument("--immunity-attributes", nargs="+", required=True, help="names of attributes to use for the calculated cross-immunities")
    parser.add_argument("--decay-factors", nargs="+", required=True, type=float, help="list of decay factors (d_0) for each given immunity attribute")
    parser.add_argument("--years-to-wane", type=int, help="number of years after which immunity wanes completely")
    parser.add_argument("--output", required=True, help="cross-immunities calculated from the given distances and frequencies")
    args = parser.parse_args()

    # Load frequencies.
    with open(args.frequencies, "r") as fh:
        frequencies = json.load(fh)

    # Identify maximum frequency per sample.
    max_frequency_per_sample = {
        sample: float(max(sample_frequencies["frequencies"]))
        for sample, sample_frequencies in frequencies.items()
        if sample not in ["pivots", "generated_by"] and not sample.startswith("count")
    }
    current_timepoint = frequencies["pivots"][-1]

    # Load distances.
    with open(args.distances, "r") as fh:
        distances = json.load(fh)

    distances = distances["nodes"]

    # Load date annotations and annotate tree with them.
    date_annotations = read_node_data(args.date_annotations)
    date_by_node_name = {}
    for node, annotations in date_annotations["nodes"].items():
        date_by_node_name[node] = annotations["numdate"]

    """
  "A/Acre/15093/2010": {
   "ep": 9,
   "ne": 8,
   "rb": 3
  },
    """
    if args.years_to_wane is not None:
        print("Waning effect with max years of %i" % args.years_to_wane)
    else:
        print("No waning effect")

    # Calculate cross-immunity for distances defined by the given attributes.
    cross_immunities = {}
    for sample, sample_distances in distances.items():
        for distance_attribute, immunity_attribute, decay_factor in zip(args.distance_attributes, args.immunity_attributes, args.decay_factors):
            if distance_attribute not in sample_distances:
                continue

            if sample not in cross_immunities:
                cross_immunities[sample] = {}

            # Calculate cross-immunity cost from all distances to the current
            # sample. This negative value increases for samples that are
            # increasingly distant from previous samples.
            cross_immunity = 0.0
            for past_sample, distance in sample_distances[distance_attribute].items():
                # Calculate effect of waning immunity.
                if args.years_to_wane is not None:
                    waning_effect = max(1 - ((current_timepoint - date_by_node_name[past_sample]) / args.years_to_wane), 0)
                else:
                    waning_effect = 1.0

                # Calculate cost of cross-immunity with waning.
                if waning_effect > 0:
                    cross_immunity += waning_effect * max_frequency_per_sample[past_sample] * cross_immunity_cost(
                        distance,
                        decay_factor
                    )

            cross_immunities[sample][immunity_attribute] = -1 * cross_immunity

    # Export cross-immunities to JSON.
    with open(args.output, "w") as oh:
        json.dump({"nodes": cross_immunities}, oh, indent=1, sort_keys=True)
import argparse
import csv
import cv2
import json
import numpy as np
import pandas as pd
from scipy.optimize import minimize
import sys
import time

from forecast.fitness_model import get_train_validate_timepoints
from forecast.metrics import add_pseudocounts_to_frequencies, negative_information_gain
from forecast.metrics import mean_absolute_error, sum_of_squared_errors, root_mean_square_error
from weighted_distances import get_distances_by_sample_names, get_distance_matrix_by_sample_names

MAX_PROJECTED_FREQUENCY = 1e3
FREQUENCY_TOLERANCE = 1e-3

np.random.seed(314159)


def sum_of_differences(observed, estimated, y_diff, **kwargs):
    """
    Calculates the sum of differences between observed and estimated values.

    Parameters
    ----------
    observed : numpy.ndarray
        observed values

    estimated : numpy.ndarray
        estimated values

    y_diff : numpy.ndarray
        differences between observed and estimated values

    Returns
    -------
    float :
        sum of differences between estimated and observed future values
    """
    return np.sum(y_diff)


class ExponentialGrowthModel(object):
    def __init__(self, predictors, delta_time, l1_lambda, cost_function):
        """Construct an empty exponential growth model instance.

        Parameters
        ----------
        predictors : list
            a list of predictors to estimate coefficients for

        delta_time : float
            number of years into the future to project frequencies

        l1_lambda : float
            hyperparameter to scale L1 regularization penalty for non-zero coefficients

        cost_function : callable
            function returning the error to be minimized between observed and estimated values

        Returns
        -------
        ExponentialGrowthModel
        """
        self.predictors = predictors
        self.delta_time = delta_time
        self.l1_lambda = l1_lambda
        self.cost_function = cost_function

    def calculate_mean_stds(self, X, predictors):
        """Calculate mean standard deviations of predictors by timepoints prior to
        fitting.

        Parameters
        ----------
        X : pandas.DataFrame
            standardized tip attributes by timepoint

        predictors : list
            names of predictors for which mean standard deviations should be calculated

        Returns
        -------
        ndarray :
            mean standard deviation per predictor across all timepoints

        """
        # Note that the pandas standard deviation method ignores missing data
        # whereas numpy requires the use of specific NaN-aware functions (nanstd).
        return X.loc[:, ["timepoint"] + predictors].groupby("timepoint").std().mean().values

    def standardize_predictors(self, predictors, mean_stds, initial_frequencies):
        """Standardize the values for the given predictors by centering on the mean of
        each predictor and scaling by the mean standard deviation provided.

        Parameters
        ----------
        predictors : ndarray
            matrix of values per sample (rows) and predictor (columns)

        mean_stds : ndarray
            mean standard deviations of predictors across all training
            timepoints

        initial_frequencies : ndarray
            initial frequencies of samples corresponding to each row of the
            given predictors

        Returns
        -------
        ndarray :
            standardized predictor values

        """
        means = np.average(predictors, weights=initial_frequencies, axis=0)
        variances = np.average((predictors - means) ** 2, weights=initial_frequencies, axis=0)
        stds = np.sqrt(variances)

        nonzero_stds = np.where(stds)[0]

        if len(nonzero_stds) == 0:
            return predictors

        standardized_predictors = predictors
        standardized_predictors[:, nonzero_stds] = (predictors[:, nonzero_stds] - means[nonzero_stds]) / stds[nonzero_stds]

        return standardized_predictors

    def get_fitnesses(self, coefficients, predictors):
        """Apply the coefficients to the predictors and sum them to get strain
        fitnesses.

        Parameters
        ----------
        coefficients : ndarray or list
            coefficients for given predictors

        predictors : ndarray
            predictor values per sample (n x p matrix for p predictors and n samples)

        Returns
        -------
        ndarray :
            fitnesses per sample
        """
        return np.sum(predictors * coefficients, axis=-1)

    def project_frequencies(self, initial_frequencies, fitnesses, delta_time):
        """Project the given initial frequencies into the future by the given delta time
        based on the given fitnesses.

        Returns the projected frequencies normalized to sum to 1.

        Parameters
        ----------
        initial_frequencies : ndarray
            floating point frequencies for all samples in a timepoint

        fitnesses : ndarray
            floating point fitnesses for all samples in same order as given frequencies

        delta_time : float
            number of years to project into the future

        Returns
        -------
        ndarray :
            projected and normalized frequencies
        """
        # Exponentiate the fitnesses and multiply them by strain frequencies.
        projected_frequencies = initial_frequencies * np.exp(fitnesses * self.delta_time)

        # Replace infinite values with a very large number that can still be
        # summed across all timepoints. This addresses buffer overflows in the
        # exponentiation, which can produce infinite values.
        projected_frequencies[np.isinf(projected_frequencies)] = MAX_PROJECTED_FREQUENCY

        # Sum the projected frequencies.
        total_projected_frequencies = projected_frequencies.sum()

        # Normalize the projected frequencies.
        projected_frequencies = projected_frequencies / total_projected_frequencies

        # Confirm that projected frequencies sum to 1.
        assert np.isclose(projected_frequencies.sum(), np.ones(1), atol=FREQUENCY_TOLERANCE)

        # Confirm that all projected frequencies are proper numbers.
        assert np.isnan(projected_frequencies).sum() == 0

        return projected_frequencies

    def _fit(self, coefficients, X, y, use_l1_penalty=True):
        """Calculate the error between observed and estimated values for the given
        parameters and data.

        Parameters
        ----------
        coefficients : ndarray
            coefficients for each of the model's predictors

        X : pandas.DataFrame
            standardized tip attributes by timepoint

        y : pandas.DataFrame
            final clade frequencies at delta time in the future from each
            timepoint in the given tip attributes table

        Returns
        -------
        float :
            error between estimated values using the given coefficients and
            input data and the observed values
        """
        # Estimate final frequencies.
        y_hat = self.predict(X, coefficients)

        # Merge estimated and observed frequencies. The left join enables
        # tracking of clades that die in the future and are therefore not
        # observed in the future frequencies data frame.
        frequencies = y_hat.merge(
            y,
            how="left",
            on=["timepoint", "clade_membership"],
            suffixes=["_estimated", "_observed"]
        )
        frequencies["frequency_observed"] = frequencies["frequency_observed"].fillna(0.0)

        # Calculate initial frequencies for use by cost function.
        initial_frequencies = X.groupby([
            "timepoint",
            "clade_membership"
        ])["frequency"].sum().reset_index()

        # Annotate future frequencies with initial frequencies.
        frequencies = frequencies.merge(
            initial_frequencies,
            how="inner",
            on=["timepoint", "clade_membership"]
        )

        # Calculate the error between the observed and estimated frequencies.
        error = self.cost_function(
            frequencies["frequency_observed"],
            frequencies["frequency_estimated"],
            initial=frequencies["frequency"]
        )

        if use_l1_penalty:
            l1_penalty = self.l1_lambda * np.abs(coefficients).sum()
        else:
            l1_penalty = 0.0

        return error + l1_penalty

    def fit(self, X, y):
        """Fit a model to the given input data, producing beta coefficients for each of
        the model's predictors.

        Coefficients are stored in the `coef_` attribute, after the pattern of
        scikit-learn models.

        Parameters
        ----------
        X : pandas.DataFrame
            standardized tip attributes by timepoint

        y : pandas.DataFrame
            final clade frequencies at delta time in the future from each
            timepoint in the given tip attributes table

        Returns
        -------
        float :
            model training error

        """
        # Calculate mean standard deviations of predictors by timepoints prior
        # to fitting.
        self.mean_stds_ = self.calculate_mean_stds(X, self.predictors)

        # Find coefficients that minimize the model's cost function.
        if hasattr(self, "coef_"):
            # Use the previous coefficients +/- a small random offset (+/- 0.05)
            # to prevent getting stuck in local minima.
            initial_coefficients = self.coef_ + (0.1 * np.random.random(len(self.predictors)) - 0.05)
        else:
            # If no previous coefficients exist, sample random values between -0.5 and 0.5.
            initial_coefficients = np.random.random(len(self.predictors)) - 0.5

        results = minimize(
            self._fit,
            initial_coefficients,
            args=(X, y),
            method="Nelder-Mead",
            options={"disp": False}
        )
        self.coef_ = results.x

        training_error = self.score(X, y)

        return training_error

    def predict(self, X, coefficients=None, mean_stds=None):
        """Calculate the estimate final frequencies of all clades in the given tip
        attributes data frame using previously calculated beta coefficients.

        Parameters
        ----------
        X : pandas.DataFrame
            standardized tip attributes by timepoint

        coefficients : ndarray
            optional coefficients to use for each of the model's predictors
            instead of the model's currently defined coefficients

        mean_stds : ndarray
            optional mean standard deviations of predictors across all training
            timepoints

        Returns
        -------
        pandas.DataFrame
            estimated final clade frequencies at delta time in the future for
            each clade from each timepoint in the given tip attributes table

        """
        # Use model coefficients, if none are provided.
        if coefficients is None:
            coefficients = self.coef_

        if mean_stds is None:
            mean_stds = self.mean_stds_

        estimated_frequencies = []
        for timepoint, timepoint_df in X.groupby("timepoint"):
            # Select frequencies from timepoint.
            initial_frequencies = timepoint_df["frequency"].values

            # Select predictors from the timepoint.
            predictors = timepoint_df.loc[:, self.predictors].values

            # Standardize predictors for this timepoint by centering on the
            # frequency-weighted means and scaling by the frequency-weighted
            # standard deviations.
            standardized_predictors = self.standardize_predictors(predictors, mean_stds, initial_frequencies)

            # Calculate fitnesses.
            fitnesses = self.get_fitnesses(coefficients, standardized_predictors)

            # Project frequencies.
            projected_frequencies = self.project_frequencies(
                initial_frequencies,
                fitnesses,
                self.delta_time
            )

            # Sum the estimated frequencies by clade.
            projected_timepoint_df = timepoint_df[["timepoint", "clade_membership"]].copy()
            projected_timepoint_df["frequency"] = projected_frequencies
            projected_clade_frequencies = projected_timepoint_df.groupby([
                "timepoint",
                "clade_membership"
            ])["frequency"].sum().reset_index()

            estimated_frequencies.append(projected_clade_frequencies)

        # Collect all estimated frequencies by timepoint.
        estimated_frequencies = pd.concat(estimated_frequencies)
        return estimated_frequencies

    def score(self, X, y):
        """Calculate model error between the estimated final clade frequencies for the
        given tip attributes, `X`, and the observed final clade frequencies in
        `y`.

        Parameters
        ----------
        X : pandas.DataFrame
            standardized tip attributes by timepoint

        y : pandas.DataFrame
            final clade frequencies at delta time in the future from each
            timepoint in the given tip attributes table

        Returns
        -------
        float :
            model error
        """
        return self._fit(self.coef_, X, y, use_l1_penalty=False)


class DistanceExponentialGrowthModel(ExponentialGrowthModel):
    def __init__(self, predictors, delta_time, l1_lambda, cost_function, distances):
        super().__init__(predictors, delta_time, l1_lambda, cost_function)
        self.distances = distances

    def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distance=False):
        """Calculate the error between observed and estimated values for the given
        parameters and data.

        Parameters
        ----------
        coefficients : ndarray
            coefficients for each of the model's predictors

        X : pandas.DataFrame
            standardized tip attributes by timepoint

        y : pandas.DataFrame
            final weighted distances at delta time in the future from each
            timepoint in the given tip attributes table

        Returns
        -------
        float :
            error between estimated values using the given coefficients and
            input data and the observed values
        """
        # Estimate target values.
        y_hat = self.predict(X, coefficients)

        # Calculate EMD for each timepoint in the estimated values and sum that
        # distance across all timepoints.
        error = 0.0
        count = 0
        for timepoint, timepoint_df in y_hat.groupby("timepoint"):
            samples_a = timepoint_df["strain"]
            sample_a_initial_frequencies = timepoint_df["frequency"].values.astype(np.float32)
            sample_a_frequencies = timepoint_df["projected_frequency"].values.astype(np.float32)

            future_timepoint_df = y[y["timepoint"] == timepoint]
            assert future_timepoint_df.shape[0] > 0

            samples_b = future_timepoint_df["strain"]
            sample_b_frequencies = future_timepoint_df["frequency"].values.astype(np.float32)

            distance_matrix = get_distance_matrix_by_sample_names(
                samples_a,
                samples_b,
                self.distances
            ).astype(np.float32)

            # Calculate the optimal distance to the future timepoint by mapping
            # the frequency of each future strain to the closest strain in the
            # current timepoint.
            if calculate_optimal_distance:
                # For each strain in the future timepoint, identify the closest
                # strain in the current timepoint. This is an array of current
                # strain indices (one index per future strain).
                closest_strain_to_future = np.argmin(distance_matrix, axis=0)

                # Sum the frequencies of the future strains across each closest
                # strain in the current timepoint. This can and will often
                # result in a few current strains accruing most of the future
                # frequencies.
                estimated_frequencies = np.zeros_like(sample_a_frequencies)
                for i in range(sample_b_frequencies.shape[0]):
                    estimated_frequencies[closest_strain_to_future[i]] += sample_b_frequencies[i]

                # Calculate earth mover's distance to the future based on this
                # optimal (or, at least, greedy) mapping of strains between
                # timepoints. The resulting EMD value should be the best any
                # model can hope to perform and establishes a lower bound for
                # all models.
                self.optimal_model_emd, _, optimal_model_flow = cv2.EMD(
                    estimated_frequencies,
                    sample_b_frequencies,
                    cv2.DIST_USER,
                    cost=distance_matrix
                )

            # Estimate the distance between the model's estimated future and the
            # observed future populations.
            model_emd, _, self.model_flow = cv2.EMD(
                sample_a_frequencies,
                sample_b_frequencies,
                cv2.DIST_USER,
                cost=distance_matrix
            )

            error += model_emd
            count += 1

        error = error / float(count)

        if use_l1_penalty:
            l1_penalty = self.l1_lambda * np.abs(coefficients).sum()
        else:
            l1_penalty = 0.0

        return error + l1_penalty

    def _fit_distance(self, coefficients, X, y, use_l1_penalty=True):
        """Calculate the error between observed and estimated values for the given
        parameters and data.

        Parameters
        ----------
        coefficients : ndarray
            coefficients for each of the model's predictors

        X : pandas.DataFrame
            standardized tip attributes by timepoint

        y : pandas.DataFrame
            final weighted distances at delta time in the future from each
            timepoint in the given tip attributes table

        Returns
        -------
        float :
            error between estimated values using the given coefficients and
            input data and the observed values
        """
        # Estimate target values.
        y_hat = self.predict(X, coefficients)

        # Calculate weighted distance to the future for each timepoint in the
        # estimated values and sum that distance across all timepoints.
        error = 0.0
        null_error = 0.0
        count = 0
        for timepoint, timepoint_df in y_hat.groupby("timepoint"):
            samples_a = timepoint_df["strain"]
            sample_a_initial_frequencies = timepoint_df["frequency"].values
            sample_a_frequencies = timepoint_df["projected_frequency"].values
            sample_a_weighted_distance_to_future = timepoint_df["weighted_distance_to_future"].values

            future_timepoint_df = y[y["timepoint"] == timepoint]
            assert future_timepoint_df.shape[0] > 0

            samples_b = future_timepoint_df["strain"]
            sample_b_frequencies = future_timepoint_df["frequency"].values
            sample_b_weighted_distance_to_present = future_timepoint_df["weighted_distance_to_present"].values

            d_t_u = (sample_a_initial_frequencies * sample_a_weighted_distance_to_future).sum()
            d_u_hat_u = (sample_a_frequencies * sample_a_weighted_distance_to_future).sum()
            d_u_u = (sample_b_frequencies * sample_b_weighted_distance_to_present).sum()

            null_error += d_t_u

            error += (d_u_hat_u - d_u_u) / d_t_u
            count += 1

        null_error = null_error / float(count)
        error = error / float(count)

        if use_l1_penalty:
            l1_penalty = self.l1_lambda * np.abs(coefficients).sum()
        else:
            l1_penalty = 0.0

        return error + l1_penalty

    def predict(self, X, coefficients=None, mean_stds=None):
        """Calculate the estimated final weighted distance between tips at each
        timepoint and at that timepoint plus delta months in the future.

        Parameters
        ----------
        X : pandas.DataFrame
            standardized tip attributes by timepoint

        coefficients : ndarray
            optional coefficients to use for each of the model's predictors
            instead of the model's currently defined coefficients

        mean_stds : ndarray
            optional mean standard deviations of predictors across all training
            timepoints

        Returns
        -------
        pandas.DataFrame
            estimated weighted distances at delta time in the future for
            each tip from each timepoint in the given tip attributes table

        """
        # Use model coefficients, if none are provided.
        if coefficients is None:
            coefficients = self.coef_
            model_is_fit = True
        else:
            model_is_fit = False

        if mean_stds is None:
            mean_stds = self.mean_stds_

        estimated_targets = []
        for timepoint, timepoint_df in X.groupby("timepoint"):
            # Select frequencies from timepoint.
            initial_frequencies = timepoint_df["frequency"].values

            # Select predictors from the timepoint.
            predictors = timepoint_df.loc[:, self.predictors].values

            # Standardize predictors for this timepoint by centering on the
            # frequency-weighted means and scaling by the frequency-weighted
            # standard deviations.
            mean_stds = timepoint_df.loc[:, self.predictors].std().values
            standardized_predictors = self.standardize_predictors(predictors, mean_stds, initial_frequencies)

            # Calculate fitnesses.
            fitnesses = self.get_fitnesses(coefficients, standardized_predictors)

            # Project frequencies.
            projected_frequencies = self.project_frequencies(
                initial_frequencies,
                fitnesses,
                self.delta_time
            )

            # Calculate observed distance between current tips and the future
            # using projected frequencies and weighted distances to the future.
            columns_to_extract = ["timepoint", "strain", "frequency"]
            optional_columns = ["weighted_distance_to_present", "weighted_distance_to_future"]
            for column in optional_columns:
                if column in timepoint_df.columns:
                    columns_to_extract.append(column)

            projected_timepoint_df = timepoint_df[columns_to_extract].copy()
            projected_timepoint_df["fitness"] = fitnesses
            projected_timepoint_df["projected_frequency"] = projected_frequencies

            if model_is_fit:
                # Calculate the estimated distance between current tips and future tips
                # based on projections of current tips.
                estimated_weighted_distance_to_future = []
                for current_tip, current_tip_frequency in projected_timepoint_df.loc[:, ["strain", "frequency"]].values:
                    weighted_distance_to_future = 0.0
                    for other_tip, other_tip_projected_frequency in projected_timepoint_df.loc[:, ["strain", "projected_frequency"]].values:
                        weighted_distance_to_future += other_tip_projected_frequency * self.distances[current_tip][other_tip]

                    estimated_weighted_distance_to_future.append(weighted_distance_to_future)

                projected_timepoint_df["y"] = np.array(estimated_weighted_distance_to_future)
            else:
                projected_timepoint_df["y"] = np.nan

            estimated_targets.append(projected_timepoint_df)

        # Collect all estimated targets by timepoint.
        estimated_targets = pd.concat(estimated_targets, ignore_index=True)
        return estimated_targets


def cross_validate(model_class, model_kwargs, data, targets, train_validate_timepoints, coefficients=None, group_by="clade_membership",
                   include_attributes=False):
    """Calculate cross-validation scores for the given data and targets across the
    given train/validate timepoints.

    Parameters
    ----------
    model_class : type
        model class to instantiate for each training window (e.g.,
        ExponentialGrowthModel or DistanceExponentialGrowthModel)

    model_kwargs : dict
        keyword arguments defining the model's hyperparameters, including the
        list of predictors to use for fitting

    data : pandas.DataFrame
        standardized input attributes to use for model fitting

    targets : pandas.DataFrame
        observed outputs to fit the model to

    train_validate_timepoints : list
        a list of dictionaries of lists indexed by "train" and "validate" keys
        and containing timepoints to use for model training and validation,
        respectively

    coefficients : ndarray
        an optional array of fixed coefficients for the given model's predictors
        to use when calculating cross-validation error for specific models
        (e.g., naive forecasts)

    group_by : string
        column of the tip attributes by which they should be grouped to
        calculate the total number of samples in the model (e.g., group by clade
        or strain)

    include_attributes : boolean
        specifies whether tip attribute data used to train/validate models
        should be included in the output per training window

    Returns
    -------
    list
        a list of dictionaries containing cross-validation results with scores,
        training and validation results, and beta coefficients per timepoint

    """
    results = []
    differences_of_model_and_naive_errors = []
    previous_coefficients = None

    for timepoints in train_validate_timepoints:
        model = model_class(**model_kwargs)

        if previous_coefficients is not None:
            model.coef_ = previous_coefficients

        # Get training and validation timepoints.
        training_timepoints = pd.to_datetime(timepoints["train"])
        validation_timepoint = pd.to_datetime(timepoints["validate"])

        # Get training data by timepoints.
        training_X = data[data["timepoint"].isin(training_timepoints)].copy()
        training_y = targets[targets["timepoint"].isin(training_timepoints)].copy()

        # Fit a model to the training data.
        if coefficients is None:
            start_time = time.time()
            training_error = model.fit(training_X, training_y)
            end_time = time.time()
            previous_coefficients = model.coef_
            null_training_error = model._fit(np.zeros_like(model.coef_), training_X, training_y)
        else:
            start_time = end_time = time.time()
            model.coef_ = coefficients
            model.mean_stds_ = model.calculate_mean_stds(training_X, model.predictors)
            training_error = model.score(training_X, training_y)
            null_training_error = training_error

        # Get validation data by timepoints.
        validation_X = data[data["timepoint"] == validation_timepoint].copy()
        validation_y = targets[targets["timepoint"] == validation_timepoint].copy()

        # Calculate the model score for the validation data.
        validation_error = model.score(validation_X, validation_y)
        null_validation_error = model._fit(np.zeros_like(model.coef_), validation_X, validation_y, calculate_optimal_distance=True)
        optimal_validation_error = model.optimal_model_emd
        differences_of_model_and_naive_errors.append(validation_error - null_validation_error)
        print(
            "%s\t%s\t%.2f\t%.2f\t%.2f\t%.2f\t%.2f\t%s\t%.2f\t%.2f" % (
                training_timepoints[-1].strftime("%Y-%m"),
                validation_timepoint.strftime("%Y-%m"),
                training_error,
                null_training_error,
                validation_error,
                null_validation_error,
                optimal_validation_error,
                model.coef_,
                (np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors)),
                end_time - start_time
            ),
            flush=True
        )

        # Get the estimated frequencies for training and validation sets to export.
        training_y_hat = model.predict(training_X)
        validation_y_hat = model.predict(validation_X)

        # Convert timestamps to a serializable format.
        for df in [training_X, training_y, training_y_hat, validation_X, validation_y, validation_y_hat]:
            for column in ["timepoint", "future_timepoint"]:
                if column in df.columns:
                    df[column] = df[column].dt.strftime("%Y-%m-%d")

        # Store training results, beta coefficients, and validation results.
        result = {
            "predictors": model.predictors,
            "training_data": {
                "y": training_y.to_dict(orient="records"),
                "y_hat": training_y_hat.to_dict(orient="records")
            },
            "training_n": training_X[group_by].unique().shape[0],
            "training_error": training_error,
            "coefficients": model.coef_.tolist(),
            "mean_stds": model.mean_stds_.tolist(),
            "validation_data": {
                "y": validation_y.to_dict(orient="records"),
                "y_hat": validation_y_hat.to_dict(orient="records")
            },
            "validation_n": validation_X[group_by].unique().shape[0],
            "validation_error": validation_error,
            "null_validation_error": null_validation_error,
            "optimal_validation_error": optimal_validation_error,
            "last_training_timepoint": training_timepoints[-1].strftime("%Y-%m-%d"),
            "validation_timepoint": validation_timepoint.strftime("%Y-%m-%d")
        }

        # Include tip attributes, if requested.
        if include_attributes:
            result["training_data"]["X"] = training_X.to_dict(orient="records")
            result["validation_data"]["X"] = validation_X.to_dict(orient="records")

        results.append(result)

    # Return results for all validation timepoints.
    print("Mean difference between model and naive: %.4f" % (sum(differences_of_model_and_naive_errors) / len(differences_of_model_and_naive_errors)), flush=True)
    print("Proportion of timepoints when model < naive: %.2f" % ((np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors))))
    return results


def test(model_class, model_kwargs, data, targets, timepoints, coefficients=None, group_by="clade_membership",
                   include_attributes=False):
    """Calculate test scores for the given data and targets across the given
    timepoints.

    Parameters
    ----------
    model_class : type
        model class to instantiate for each test timepoint (e.g.,
        DistanceExponentialGrowthModel)

    model_kwargs : dict
        keyword arguments defining the model's hyperparameters, including the
        list of predictors to use for fitting

    data : pandas.DataFrame
        standardized input attributes to use for model fitting

    targets : pandas.DataFrame
        observed outputs to test the model with

    timepoints : list
        a list of timepoint strings in YYYY-MM-DD format

    coefficients : ndarray
        an array of fixed coefficients for the given model's predictors

    group_by : string
        column of the tip attributes by which they should be grouped to
        calculate the total number of samples in the model (e.g., group by clade
        or strain)

    include_attributes : boolean
        specifies whether tip attribute data used to test models should be
        included in the output per timepoint

    Returns
    -------
    list
        a list of dictionaries containing test results with scores per timepoint

    """
    results = []
    differences_of_model_and_naive_errors = []

    for timepoint in timepoints:
        model = model_class(**model_kwargs)
        model.coef_ = coefficients
        model.mean_stds_ = np.zeros_like(coefficients)

        # Get training and validation timepoints.
        test_timepoint = pd.to_datetime(timepoint)

        # Get test data by timepoints.
        test_X = data[data["timepoint"] == test_timepoint].copy()
        test_y = targets[targets["timepoint"] == test_timepoint].copy()

        # Calculate the model score for the validation data.
        test_error = model.score(test_X, test_y)
        null_test_error = model._fit(np.zeros_like(model.coef_), test_X, test_y, calculate_optimal_distance=True)
        optimal_test_error = model.optimal_model_emd
        differences_of_model_and_naive_errors.append(test_error - null_test_error)
        print(
            "%s\t%.2f\t%.2f\t%.2f\t%s\t%.2f" % (
                test_timepoint.strftime("%Y-%m"),
                test_error,
                null_test_error,
                optimal_test_error,
                model.coef_,
                (np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors))
            ),
            flush=True
        )

        # Get the estimated frequencies for test sets to export.
        test_y_hat = model.predict(test_X)

        # Convert timestamps to a serializable format.
        for df in [test_X, test_y, test_y_hat]:
            for column in ["timepoint", "future_timepoint"]:
                if column in df.columns:
                    df[column] = df[column].dt.strftime("%Y-%m-%d")

        # Store test results and beta coefficients.
        result = {
            "predictors": model.predictors,
            "coefficients": model.coef_.tolist(),
            "mean_stds": model.mean_stds_.tolist(),
            "validation_data": {
                "y": test_y.to_dict(orient="records"),
                "y_hat": test_y_hat.to_dict(orient="records")
            },
            "validation_n": test_X[group_by].unique().shape[0],
            "validation_error": test_error,
            "null_validation_error": null_test_error,
            "optimal_validation_error": optimal_test_error,
            "validation_timepoint": test_timepoint.strftime("%Y-%m-%d")
        }

        # Include tip attributes, if requested.
        if include_attributes:
            result["validation_data"]["X"] = test_X.to_dict(orient="records")

        results.append(result)

    # Return results for all validation timepoints.
    print("Mean difference between model and naive: %.4f" % (sum(differences_of_model_and_naive_errors) / len(differences_of_model_and_naive_errors)), flush=True)
    print("Proportion of timepoints when model < naive: %.2f" % ((np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors))))
    return results


def summarize_scores(scores, include_scores=False):
    """Summarize model errors across timepoints.

    Parameters
    ----------
    scores : list
        a list of cross-validation results including training errors,
        cross-validation errors, and beta coefficients OR a list of test errors

    include_scores : boolean
        specifies whether cross-validation scores should be included in the
        output per timepoint

    Returns
    -------
    dict :
        a dictionary of all cross-validation results plus summary statistics for
        training, cross-validation, and beta coefficients OR test results

    """
    summary = {
        "predictors": scores[0]["predictors"]
    }

    if include_scores:
        summary["scores"] = scores

    validation_errors = [score["validation_error"] for score in scores]
    summary["cv_error_mean"] = np.mean(validation_errors)
    summary["cv_error_std"] = np.std(validation_errors)

    coefficients = np.array([
        np.array(score["coefficients"])
        for score in scores
    ])
    summary["coefficients_mean"] = coefficients.mean(axis=0).tolist()
    summary["coefficients_std"] = coefficients.std(axis=0).tolist()

    mean_stds = np.array([
        np.array(score["mean_stds"])
        for score in scores
    ])
    summary["mean_stds_mean"] = mean_stds.mean(axis=0).tolist()
    summary["mean_stds_std"] = mean_stds.std(axis=0).tolist()

    return summary


def get_errors_by_timepoint(scores):
    """Convert cross-validation errors into a data frame by timepoint and predictors.

    Parameters
    ----------
    scores : list
        a list of cross-validation results including training errors,
        cross-validation errors, and beta coefficients

    Returns
    -------
    pandas.DataFrame
    """
    predictors = "-".join(scores[0]["predictors"])
    errors_by_time = []
    for score in scores:
        errors_by_time.append({
            "predictors": predictors,
            "validation_timepoint": pd.to_datetime(score["validation_timepoint"]),
            "validation_error": score["validation_error"],
            "null_validation_error": score["null_validation_error"],
            "optimal_validation_error": score["optimal_validation_error"],
            "validation_n": score["validation_n"]
        })

    return pd.DataFrame(errors_by_time)


def get_coefficients_by_timepoint(scores):
    """Convert model coefficients into a data frame by timepoint and predictors.

    Parameters
    ----------
    scores : list
        a list of cross-validation results including training errors,
        cross-validation errors, and beta coefficients

    Returns
    -------
    pandas.DataFrame
    """
    predictors = "-".join(scores[0]["predictors"])
    coefficients_by_time = []
    for score in scores:
        for predictor, coefficient in zip(score["predictors"], score["coefficients"]):
            coefficients_by_time.append({
                "predictors": predictors,
                "predictor": predictor,
                "coefficient": coefficient,
                "validation_timepoint": pd.to_datetime(score["validation_timepoint"])
            })

    return pd.DataFrame(coefficients_by_time)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors")
    parser.add_argument("--output", required=True, help="JSON representing the model fit with training and cross-validation results, beta coefficients for predictors, and summary statistics")
    parser.add_argument("--predictors", nargs="+", help="tip attribute columns to use as predictors of final clade frequencies; optional if a fixed model is provided")
    parser.add_argument("--delta-months", type=int, help="number of months to project clade frequencies into the future")
    parser.add_argument("--target", required=True, choices=["clades", "distances"], help="target for models to fit")
    parser.add_argument("--final-clade-frequencies", help="tab-delimited file of clades per timepoint and their corresponding tips and tip frequencies at the given delta time in the future")
    parser.add_argument("--distances", help="tab-delimited file of distances between pairs of samples")
    parser.add_argument("--training-window", type=int, default=4, help="number of years required for model training")
    parser.add_argument("--l1-lambda", type=float, default=0.0, help="L1 regularization lambda")
    parser.add_argument("--cost-function", default="sse", choices=["sse", "rmse", "mae", "information_gain", "diffsum"], help="name of the function that returns the error between observed and estimated values")
    parser.add_argument("--pseudocount", type=float, help="pseudocount numerator to adjust all frequencies by, enabling some information theoretic metrics like information gain")
    parser.add_argument("--include-attributes", action="store_true", help="include attribute data used to train/validate models in the cross-validation output")
    parser.add_argument("--include-scores", action="store_true", help="include score data resulting from cross-validation output")
    parser.add_argument("--errors-by-timepoint", help="optional data frame of cross-validation errors by validation timepoint")
    parser.add_argument("--coefficients-by-timepoint", help="optional data frame of coefficients by validation timepoint")
    parser.add_argument("--fixed-model", help="optional model JSON to use as a fixed model for calculation of test error in forecasts")

    args = parser.parse_args()

    cost_functions_by_name = {
        "sse": sum_of_squared_errors,
        "rmse": root_mean_square_error,
        "mae": mean_absolute_error,
        "information_gain": negative_information_gain,
        "diffsum": sum_of_differences
    }

    # Load standardized tip attributes subsetting to tip name, clade, frequency,
    # and requested predictors.
    tips = pd.read_csv(
        args.tip_attributes,
        sep="\t",
        parse_dates=["timepoint"]
    )

    if args.target == "clades":
        # Load final clade tip frequencies.
        final_clade_tip_frequencies = pd.read_csv(
            args.final_clade_frequencies,
            sep="\t",
            parse_dates=["initial_timepoint", "final_timepoint"]
        )

        # If a pseudocount numerator has been provided, update the given tip
        # frequencies both from current and future timepoints.
        if args.pseudocount is not None and args.pseudocount > 0.0:
            tips = add_pseudocounts_to_frequencies(tips, args.pseudocount)
            print("Sum of tip frequencies by timepoint: ",
                  tips.groupby("timepoint")["frequency"].sum())
            final_clade_tip_frequencies = add_pseudocounts_to_frequencies(
                final_clade_tip_frequencies,
                args.pseudocount,
                timepoint_column="initial_timepoint"
            )
            print("Sum of tip frequencies by timepoint: ",
                  final_clade_tip_frequencies.groupby("initial_timepoint")["frequency"].sum())

        # Aggregate final clade frequencies.
        final_clade_frequencies = final_clade_tip_frequencies.groupby([
            "initial_timepoint",
            "clade_membership"
        ])["frequency"].sum().reset_index()

        # Rename initial timepoint column for comparison with tip attribute data.
        targets = final_clade_frequencies.rename(
            columns={"initial_timepoint": "timepoint"}
        )
        model_class = ExponentialGrowthModel
        model_kwargs = {}
        group_by_attribute = "clade_membership"
    elif args.target == "distances":
        # Use each tip's weighted distance to future populations as the
        # model's target value.
        tips["y"] = tips["weighted_distance_to_future"]

        # Use strain frequencies and weighted distances per timepoint as model
        # targets; the target timepoints are shifted back by delta months below
        # so that they align with the timepoint each forecast is made from.
        targets = tips.loc[:, ["strain", "timepoint", "frequency", "weighted_distance_to_present", "weighted_distance_to_future", "y"]].copy()
        targets["future_timepoint"] = targets["timepoint"]

        model_class = DistanceExponentialGrowthModel

        with open(args.distances, "r") as fh:
            print("Read distances", flush=True)
            reader = csv.DictReader(fh, delimiter="\t")
            print("Get distances by sample names", flush=True)
            distances_by_sample_names = get_distances_by_sample_names(reader)
            print("Data loaded", flush=True)

        model_kwargs = {"distances": distances_by_sample_names}
        group_by_attribute = "strain"

    # Identify all available timepoints from tip attributes.
    timepoints = tips["timepoint"].dt.strftime("%Y-%m-%d").unique()

    # If a fixed model is provided, calculate test errors. Otherwise, calculate
    # cross-validation errors.
    if args.fixed_model is not None:
        # Load model details and extract mean coefficients.
        with open(args.fixed_model, "r") as fh:
            model_json = json.load(fh)

        coefficients = np.array(model_json["coefficients_mean"])
        delta_months = model_json["delta_months"]
        delta_time = delta_months / 12.0
        l1_lambda = model_json["l1_lambda"]
        training_window = model_json["training_window"]
        cost_function_name = model_json["cost_function"]
        cost_function = cost_functions_by_name[cost_function_name]
        model_kwargs.update({
            "predictors": model_json["predictors"],
            "delta_time": delta_time,
            "l1_lambda": l1_lambda,
            "cost_function": cost_function
        })

        # Find the latest timepoint we can project from based on the given delta
        # months.
        latest_timepoint = pd.to_datetime(timepoints[-1]) - pd.DateOffset(months=delta_months)
        test_timepoints = [
            timepoint
            for timepoint in timepoints
            if pd.to_datetime(timepoint) <= latest_timepoint
        ]

        # Calculate test errors/scores for the given coefficients and data at
        # the identified test timepoints.
        targets["timepoint"] = targets["timepoint"] - pd.DateOffset(months=delta_months)
        scores = test(
            model_class,
            model_kwargs,
            tips,
            targets,
            test_timepoints,
            coefficients,
            group_by=group_by_attribute,
            include_attributes=args.include_attributes
        )
    else:
        # First, confirm that all predictors are defined in the given tip
        # attributes.
        if not all([predictor in tips.columns for predictor in args.predictors]):
            print("ERROR: Not all predictors could be found in the given tip attributes table.", file=sys.stderr)
            sys.exit(1)

        # Select the cost function.
        cost_function_name = args.cost_function
        cost_function = cost_functions_by_name[cost_function_name]

        # Identify train/validate splits from timepoints.
        training_window = args.training_window
        train_validate_timepoints = get_train_validate_timepoints(
            timepoints,
            args.delta_months,
            training_window
        )

        # For each train/validate split, fit a model to the training data, and
        # evaluate the model with the validation data, storing the training results,
        # beta parameters, and validation results.
        delta_months = args.delta_months
        delta_time = delta_months / 12.0
        l1_lambda = args.l1_lambda
        model_kwargs.update({
            "predictors": args.predictors,
            "delta_time": delta_time,
            "l1_lambda": l1_lambda,
            "cost_function": cost_function
        })

        # If this is a naive model, set the coefficients to zero so cross-validation
        # can run under naive model conditions.
        if "naive" in args.predictors:
            coefficients = np.zeros(len(args.predictors))
        else:
            coefficients = None

        targets["timepoint"] = targets["timepoint"] - pd.DateOffset(months=delta_months)
        scores = cross_validate(
            model_class,
            model_kwargs,
            tips,
            targets,
            train_validate_timepoints,
            coefficients,
            group_by=group_by_attribute,
            include_attributes=args.include_attributes
        )

    # Summarize model errors including in-sample training errors, out-of-sample
    # errors by cross-validation, and beta parameters across timepoints.
    model_results = summarize_scores(scores, args.include_scores)

    # Annotate parameters used to produce models.
    model_results["cost_function"] = cost_function_name
    model_results["l1_lambda"] = l1_lambda
    model_results["delta_months"] = delta_months
    model_results["training_window"] = training_window
    model_results["pseudocount"] = args.pseudocount

    # Save model fitting hyperparameters, raw results, and summary of results to
    # JSON.
    with open(args.output, "w") as fh:
        json.dump(model_results, fh, indent=1)

    # Save errors by timepoint, if requested.
    if args.errors_by_timepoint:
        errors_by_timepoint_df = get_errors_by_timepoint(scores)
        errors_by_timepoint_df.to_csv(args.errors_by_timepoint, sep="\t", header=True, index=False)

    # Save coefficients by timepoint, if requested.
    if args.coefficients_by_timepoint:
        coefficients_by_timepoint_df = get_coefficients_by_timepoint(scores)
        coefficients_by_timepoint_df.to_csv(args.coefficients_by_timepoint, sep="\t", header=True, index=False)
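
The model classes above share the exponential growth projection implemented in project_frequencies: each strain's current frequency is multiplied by the exponential of its fitness times the projection horizon in years, and the result is renormalized to sum to one. The minimal sketch below illustrates that calculation with assumed toy values (the frequencies, fitnesses, and horizon are for illustration only and are not taken from the study).

import numpy as np

# Assumed toy values: three strains, their fitnesses (weighted sums of
# standardized predictors), and a one-year projection horizon.
initial_frequencies = np.array([0.5, 0.3, 0.2])
fitnesses = np.array([0.1, -0.2, 0.4])
delta_time = 1.0

# Grow each strain's frequency by exp(fitness * delta_time), then renormalize.
projected_frequencies = initial_frequencies * np.exp(fitnesses * delta_time)
projected_frequencies = projected_frequencies / projected_frequencies.sum()
print(projected_frequencies)  # sums to 1
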
import argparse
import csv
import json
import numpy as np
import pandas as pd
import sys

from fit_model import DistanceExponentialGrowthModel
from weighted_distances import get_distances_by_sample_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors")
    parser.add_argument("--distances", help="tab-delimited file of distances between pairs of samples")
    parser.add_argument("--frequencies", help="JSON representing historical frequencies to project from")
    parser.add_argument("--model", required=True, help="JSON representing the model fit with training and cross-validation results, beta coefficients for predictors, and summary statistics")
    parser.add_argument("--delta-months", required=True, type=int, nargs="+", help="number of months to project clade frequencies into the future")
    parser.add_argument("--output-node-data", help="node data JSON of forecasts for the given tips")
    parser.add_argument("--output-frequencies", help="frequencies JSON extended with forecasts for the given tips")
    parser.add_argument("--output-table", help="table of forecasts for the given tips")

    args = parser.parse_args()

    # Confirm that at least one output file has been specified.
    outputs = [
        args.output_node_data,
        args.output_frequencies,
        args.output_table
    ]
    outputs_missing = [output is None for output in outputs]
    if all(outputs_missing):
        print("ERROR: No output files were specified", file=sys.stderr)
        sys.exit(1)

    # Load standardized tip attributes subsetting to tip name, clade, frequency,
    # and requested predictors.
    tips = pd.read_csv(
        args.tip_attributes,
        sep="\t",
        parse_dates=["timepoint"]
    )

    # Load distances.
    with open(args.distances, "r") as fh:
        reader = csv.DictReader(fh, delimiter="\t")
        distances_by_sample_names = get_distances_by_sample_names(reader)

    # Load model details
    with open(args.model, "r") as fh:
        model_json = json.load(fh)

    predictors = model_json["predictors"]
    cost_function = model_json["cost_function"]
    l1_lambda = model_json["l1_lambda"]
    coefficients = np.array(model_json["coefficients_mean"])
    mean_stds = np.array(model_json["mean_stds_mean"])

    delta_month = args.delta_months[-1]
    delta_time = delta_month / 12.0
    delta_offset = pd.DateOffset(months=delta_month)

    model = DistanceExponentialGrowthModel(
        predictors=predictors,
        delta_time=delta_time,
        cost_function=cost_function,
        l1_lambda=l1_lambda,
        distances=distances_by_sample_names
    )
    model.coef_ = coefficients
    model.mean_stds_ = mean_stds

    # collect fitness and projection
    forecasts_df = model.predict(tips)
    forecasts_df["weighted_distance_to_future_by_%s" % "-".join(predictors)] = forecasts_df["y"]
    forecasts_df["future_timepoint"] = forecasts_df["timepoint"] + delta_offset

    # collect dicts from dataframe
    strain_to_fitness = {}
    strain_to_future_timepoint = {}
    strain_to_projected_frequency = {}
    strain_to_weighted_distance_to_future = {}
    for index, row in forecasts_df.iterrows():
        strain_to_fitness[row['strain']] = row['fitness']
        strain_to_future_timepoint[row['strain']] = row["future_timepoint"].strftime("%Y-%m-%d")
        strain_to_projected_frequency[row['strain']] = row['projected_frequency']
        strain_to_weighted_distance_to_future[row['strain']] = row['y']

    # output to file
    if args.output_node_data:
        # populate node data
        node_data = {}
        strains = list(tips['strain'])
        for strain in strains:
            node_data[strain] = {
                "fitness": strain_to_fitness[strain],
                "future_timepoint": strain_to_future_timepoint[strain],
                "projected_frequency": strain_to_projected_frequency[strain],
                "weighted_distance_to_future": strain_to_weighted_distance_to_future[strain]
            }

        with open(args.output_node_data, "w") as jsonfile:
            json.dump({"nodes": node_data}, jsonfile, indent=1)

    # load historic frequencies
    if args.frequencies:
        with open(args.frequencies, "r") as fh:
            frequencies = json.load(fh)

        pivots = frequencies.pop("pivots")
        projection_pivot = pivots[-1]
    else:
        frequencies = None

    forecasts = []
    for delta_month in args.delta_months:
        delta_time = delta_month / 12.0
        delta_offset = pd.DateOffset(months=delta_month)

        model = DistanceExponentialGrowthModel(
            predictors=predictors,
            delta_time=delta_time,
            cost_function=cost_function,
            l1_lambda=l1_lambda,
            distances=distances_by_sample_names
        )
        model.coef_ = coefficients
        model.mean_stds_ = mean_stds

        # collect fitness and projection
        forecasts_df = model.predict(tips)
        forecasts_df["future_timepoint"] = forecasts_df["timepoint"] + delta_offset

        # collect dicts from dataframe
        strain_to_projected_frequency = {}
        for index, row in forecasts_df.iterrows():
            strain_to_projected_frequency[row['strain']] = row['projected_frequency']

        if frequencies is not None:
            # extend frequencies
            for strain in frequencies.keys():
                trajectory = frequencies[strain]['frequencies']
                if strain in strain_to_projected_frequency:
                    trajectory.append(strain_to_projected_frequency[strain])
                else:
                    trajectory.append(0.0)

            # extend pivots
            pivots.append(projection_pivot + delta_time)

        # Collect forecast data frames, if requested.
        if args.output_table:
            forecasts.append(forecasts_df)

    # reconnect pivots and label projection pivot
    if frequencies is not None:
        frequencies['pivots'] = pivots
        frequencies['projection_pivot'] = projection_pivot

    # output to file
    if args.output_frequencies:
        with open(args.output_frequencies, "w") as jsonfile:
            json.dump(frequencies, jsonfile, indent=1)

    # Save forecasts table, if requested.
    if args.output_table:
        all_forecasts = pd.concat(forecasts, ignore_index=True)
        all_forecasts["model"] = "-".join(predictors)
        all_forecasts.to_csv(args.output_table, sep="\t", index=False, header=True, na_rep="N/A")
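
The forecasting script above reuses DistanceExponentialGrowthModel from fit_model purely for prediction, assigning previously fitted coefficients instead of calling fit. The sketch below shows that prediction-only usage with hypothetical strain names, frequencies, a single illustrative predictor column ("lbi"), and made-up coefficients.

import numpy as np
import pandas as pd

from fit_model import DistanceExponentialGrowthModel

# Hypothetical inputs: two strains at one timepoint with one predictor column.
tips = pd.DataFrame({
    "timepoint": pd.to_datetime(["2015-10-01", "2015-10-01"]),
    "strain": ["A/example/1", "A/example/2"],
    "frequency": [0.6, 0.4],
    "lbi": [0.2, 0.8],
})
distances = {
    "A/example/1": {"A/example/1": 0, "A/example/2": 5},
    "A/example/2": {"A/example/1": 5, "A/example/2": 0},
}

# Build a prediction-only model from fixed (made-up) coefficients, as the
# forecasting script does with coefficients loaded from a fitted model JSON.
model = DistanceExponentialGrowthModel(
    predictors=["lbi"],
    delta_time=1.0,
    l1_lambda=0.0,
    cost_function=None,  # not used by predict
    distances=distances,
)
model.coef_ = np.array([1.0])
model.mean_stds_ = np.array([1.0])

forecasts = model.predict(tips)
print(forecasts[["strain", "fitness", "projected_frequency", "y"]])
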
import argparse
from collections import defaultdict
import csv
import numpy as np
import pandas as pd
import sys


def get_distances_by_sample_names(distances):
    """Return a dictionary of distances by pairs of sample names.

    Parameters
    ----------
    distances : iterator
        an iterator of dictionaries with keys of distance, sample, and other_sample

    Returns
    -------
    dict :
        dictionary of distances by pairs of sample names
    """
    distances_by_sample_names = defaultdict(dict)
    for record in distances:
        sample_a = record["sample"]
        sample_b = record["other_sample"]
        distance = int(record["distance"])
        distances_by_sample_names[sample_a][sample_b] = distance

    return distances_by_sample_names


def get_distance_matrix_by_sample_names(samples_a, samples_b, distances):
    """Return a matrix of distances between pairs of given sample sets.

    Parameters
    ----------
    samples_a, samples_b : list
        names of samples whose pairwise distances should populate the matrix
        with the first samples in rows and the second samples in columns

    distances : dict
        dictionary of distances by pairs of sample names

    Returns
    -------
    ndarray :
        matrix of pairwise distances between the given samples


    >>> samples_a = ["a", "b"]
    >>> samples_b = ["c", "d"]
    >>> distances = {"a": {"c": 1, "d": 2}, "b": {"c": 3, "d": 4}}
    >>> get_distance_matrix_by_sample_names(samples_a, samples_b, distances)
    array([[1., 2.],
           [3., 4.]])
    >>> get_distance_matrix_by_sample_names(samples_b, samples_a, distances)
    array([[1., 3.],
           [2., 4.]])
    """
    matrix = np.zeros((len(samples_a), len(samples_b)))
    for i, sample_a in enumerate(samples_a):
        for j, sample_b in enumerate(samples_b):
            try:
                matrix[i, j] = distances[sample_a][sample_b]
            except KeyError:
                matrix[i, j] = distances[sample_b][sample_a]

    return matrix


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Annotated weighted distances between viruses",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tip-attributes", required=True, help="a tab-delimited file describing tip attributes at one or more timepoints")
    parser.add_argument("--distances", required=True, help="tab-delimited file with pairwise distances between samples")
    parser.add_argument("--delta-months", required=True, type=int, help="number of months to project clade frequencies into the future")
    parser.add_argument("--output", required=True, help="tab-delimited output file with mean and standard deviation used to standardize each predictor")
    args = parser.parse_args()

    # Load tip attributes.
    tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"])

    # Load distances.
    with open(args.distances, "r") as fh:
        reader = csv.DictReader(fh, delimiter="\t")

        # Map distances by sample names.
        distances_by_sample_names = get_distances_by_sample_names(reader)

    # Find valid timepoints for calculating distances to the future.
    timepoints = tips["timepoint"].drop_duplicates()
    last_timepoint = timepoints.max() - pd.DateOffset(months=args.delta_months)

    # Calculate weighted distance to the present and future for each sample at a
    # given timepoint.
    weighted_distances = []
    for timepoint in timepoints:
        future_timepoint = timepoint + pd.DateOffset(months=args.delta_months)
        timepoint_tips = tips[tips["timepoint"] == timepoint]
        future_timepoint_tips = tips[tips["timepoint"] == future_timepoint]

        for current_tip, current_tip_frequency in timepoint_tips.loc[:, ["strain", "frequency"]].values:
            # Calculate the distance to the present for all timepoints.
            weighted_distance_to_present = 0.0
            for other_current_tip, other_current_tip_frequency in timepoint_tips.loc[:, ["strain", "frequency"]].values:
                weighted_distance_to_present += other_current_tip_frequency * distances_by_sample_names[current_tip][other_current_tip]

            # Calculate the distance to the future only for valid timepoints (those with future information).
            if timepoint <= last_timepoint:
                weighted_distance_to_future = 0.0
                for future_tip, future_tip_frequency in future_timepoint_tips.loc[:, ["strain", "frequency"]].values:
                    weighted_distance_to_future += future_tip_frequency * distances_by_sample_names[current_tip][future_tip]
            else:
                weighted_distance_to_future = np.nan

            weighted_distances.append({
                "timepoint": timepoint,
                "strain": current_tip,
                "weighted_distance_to_present": weighted_distance_to_present,
                "weighted_distance_to_future": weighted_distance_to_future
            })

    weighted_distances = pd.DataFrame(weighted_distances)

    # Calculate the magnitude of the difference between future and present
    # distances for each sample.
    weighted_distances["log2_distance_effect"] = np.log2(
        weighted_distances["weighted_distance_to_future"] /
        weighted_distances["weighted_distance_to_present"]
    )
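    # Negative values indicate strains that are closer to the future
    # population than to the current population; positive values indicate
    # the opposite.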

    # Annotate samples with weighted distances.
    annotated_tips = tips.merge(
        weighted_distances,
        how="left",
        on=["strain", "timepoint"]
    )

    # Save the new data frame.
    annotated_tips.to_csv(args.output, sep="\t", index=False, na_rep="N/A")
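
The core quantity in this script is a frequency-weighted average distance: for each current strain, the weighted distance to the future is the sum, over strains at the future timepoint, of each future strain's frequency times its pairwise distance to the current strain (and analogously for the present). A minimal toy illustration with made-up strain names, frequencies, and distances:

# Hypothetical future strains with frequencies summing to 1, and their
# pairwise distances to a single current strain.
future_frequencies = {"A/future/1": 0.75, "A/future/2": 0.25}
distances_to_future = {"A/future/1": 4, "A/future/2": 10}

# Frequency-weighted average distance: 0.75 * 4 + 0.25 * 10 = 5.5.
weighted_distance_to_future = sum(
    future_frequencies[strain] * distances_to_future[strain]
    for strain in future_frequencies
)
print(weighted_distance_to_future)  # 5.5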