
NealeLab 2018 UK Biobank GWAS Reproduction Pipeline

This pipeline is a work in progress (WIP) that attempts to reproduce this GWAS (with associated code at UK_Biobank_GWAS) using sgkit.

Overview

To run this snakemake pipeline, the following infrastructure will be utilized at one point or another:

  1. A development GCE VM

    • It is possible for this workstation to exist outside of GCP, but that is not recommended: the clusters configured below are not externally addressable on any port other than ssh, so you would have to add firewall rules and/or modify the cluster installations (a hedged example firewall rule is sketched at the end of this overview)
  2. GKE clusters

    • These are created for tasks that run arbitrary snakemake jobs but do not need a Dask cluster
  3. Dask clusters

The development VM should be used to issue snakemake commands and will run some parts of the pipeline locally. This means that the development VM should have ~24G RAM and ~100G disk space. It is possible to move these steps onto external GKE clusters, but script execution is faster and easier to debug on a local machine.
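If you do run the workstation outside of GCP, something like the firewall rule sketched below would be needed to reach a Dask scheduler and its dashboard directly. This is a minimal sketch only: the rule name and source range are illustrative assumptions, while ports 8786/8787 are the scheduler and dashboard ports used later in this README.

# Illustrative only -- restrict --source-ranges to your workstation's IP
gcloud compute firewall-rules create allow-dask-from-workstation \
 --allow=tcp:8786,tcp:8787 \
 --source-ranges=<your-workstation-ip>/32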

Setup

  • Create an n1-standard-8 GCE instance w/ Debian 10 (buster) OS (an example gcloud command is sketched after this list)

  • Install NTP (so time is correct after pausing VM):

sudo apt-get install -y ntp
  • Install conda

  • Initialize the snakemake environment, which will provide the CLI from which most other commands will be run:

conda env create -f envs/snakemake.yaml 
conda activate snakemake
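As a reference point, the development VM itself can be created with a command along these lines. This is a sketch only: the machine type and OS come from the bullet above, while the instance name, zone, and disk size are illustrative values that should match your own .env settings described below.

gcloud compute instances create ukb-dev \
 --zone us-east1-c \
 --machine-type n1-standard-8 \
 --image-family debian-10 \
 --image-project debian-cloud \
 --boot-disk-size 100GB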

Notes:

  • All gcloud commands should be issued from this environment (particularly for Kubernetes) since they are version-sensitive and will often fail if run against a cluster with a different gcloud version (i.e. from a different environment).

  • This will be mentioned frequently in the steps that follow, but unless stated otherwise it is assumed that all commands are run from the root of this repo and that both the .env and env.sh files have been sourced.

  • Commands will often activate a conda environment first; unless stated otherwise, these environments can be created from the definitions in envs .

The .env file contains more sensitive variable settings and a prototype for this file is shown here:

export GCP_PROJECT=uk-biobank-XXXXX
export GCP_REGION=us-east1
export GCP_ZONE=us-east1-c
export GCS_BUCKET=my-ukb-bucket-name # A single bucket is required for all operations
export GCP_USER_EMAIL=me@company.com # GCP user to be used in ACLs
export UKB_APP_ID=XXXXX # UK Biobank application id
export GCE_WORK_HOST=ukb-dev # Hostname given to development VM

You will have to create this file and populate the variable contents yourself.
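The env.sh file referenced alongside .env holds the non-sensitive settings. Its exact contents are not shown here, but based on the variables used later in this README it defines at least the GKE cluster parameters, roughly as in the sketch below (values are illustrative assumptions; 8 vCPU / 32 GiB / 200G match the example cluster comments that follow):

# Illustrative sketch of env.sh -- see the file in this repo for the actual values
export GKE_IO_NAME=ukb-io-1     # GKE cluster name
export GKE_IO_NCPU=8            # vCPUs per node
export GKE_IO_MEM_MB=32768      # RAM per node in MiB (must be a multiple of 256)
export GKE_IO_DISK_GB=200       # Disk per node in GB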

Cluster Management

This pipeline involves steps with very different resource profiles. Because of this, certain phases of the pipeline require an appropriately configured GKE or Dask VM cluster. These clusters should be created/modified/deleted as needed since they can be expensive; the commands below suggest how to create a cluster, but it is ultimately up to the user to decide when one is no longer necessary. This is not automated in the code because debugging becomes far more difficult without long-running, user-managed clusters.

Kubernetes

Create Cluster

To create a GKE cluster that snakemake can execute rules on, follow these steps noting that the parameters used here are illustrative and may need to be altered based on the part of the pipeline being run:

source env.sh; source .env
gcloud init
gcloud components install kubectl
gcloud config set project "$GCP_PROJECT"
# Create cluster with 8 vCPUs/32GiB RAM/200G disk per node
# Memory must be multiple of 256 MiB (argument is MiB)
# Note: increase `--num-nodes` for greater throughput
gcloud container clusters create \
 --machine-type custom-${GKE_IO_NCPU}-${GKE_IO_MEM_MB} \
 --disk-type pd-standard \
 --disk-size ${GKE_IO_DISK_GB}G \
 --num-nodes 1 \
 --zone $GCP_ZONE \
 --node-locations $GCP_ZONE \
 --cluster-version latest \
 --scopes storage-rw \
 $GKE_IO_NAME
# Grant admin permissions on cluster
gcloud container clusters get-credentials $GKE_IO_NAME --zone $GCP_ZONE
kubectl create clusterrolebinding cluster-admin-binding \
 --clusterrole=cluster-admin \
 --user=$GCP_USER_EMAIL
# Note: If you see this, add IAM policy as below
# Error from server (Forbidden): clusterrolebindings.rbac.authorization.k8s.io is forbidden: 
# User "XXXXX" cannot create resource "clusterrolebindings" in API group "rbac.authorization.k8s.io" 
# at the cluster scope: requires one of ["container.clusterRoleBindings.create"] permission(s).
gcloud projects add-iam-policy-binding $GCP_PROJECT \
 --member=user:$GCP_USER_EMAIL \
 --role=roles/container.admin
# Login for GS Read/Write in pipeline rules
gcloud auth application-default login
# Run snakemake commands

Modify Cluster

source env.sh; source .env
## Resize
gcloud container clusters resize $GKE_IO_NAME --node-pool default-pool --num-nodes 2 --zone $GCP_ZONE
## Get status
kubectl get node # Find node name
gcloud compute ssh gke-ukb-io-default-pool-XXXXX
## Remove the cluster
gcloud container clusters delete $GKE_IO_NAME --zone $GCP_ZONE
## Remove node from cluster
kubectl get nodes
# Find node to delete: gke-ukb-io-1-default-pool-276513bc-48k5
kubectl drain gke-ukb-io-1-default-pool-276513bc-48k5 --force --ignore-daemonsets
gcloud container clusters describe ukb-io-1 --zone us-east1-c 
# Find instance group name: gke-ukb-io-1-default-pool-276513bc-grp
gcloud compute instance-groups managed delete-instances gke-ukb-io-1-default-pool-276513bc-grp --instances=gke-ukb-io-1-default-pool-276513bc-48k5 --zone $GCP_ZONE

Dask Cloud Provider

These commands show how to create a Dask cluster either for experimentation or for running steps in this pipeline:

conda env create -f envs/cloudprovider.yaml 
conda activate cloudprovider
source env.sh; source .env
source config/dask/cloudprovider.sh
python scripts/cluster/cloudprovider.py -- --interactive
>>> create(n_workers=1)
Launching cluster with the following configuration:
 Source Image: projects/ubuntu-os-cloud/global/images/ubuntu-minimal-1804-bionic-v20201014
 Docker Image: daskdev/dask:latest
 Machine Type: n1-standard-8
 Filesytsem Size: 50
 N-GPU Type:
 Zone: us-east1-c
Creating scheduler instance
dask-8a0571b8-scheduler
	Internal IP: 10.142.0.46
	External IP: 35.229.60.113
Waiting for scheduler to run
>>> scale(3)
Creating worker instance
Creating worker instance
dask-9347b93f-worker-60a26daf
	Internal IP: 10.142.0.52
	External IP: 35.229.60.113
dask-9347b93f-worker-4cc3cb6e
	Internal IP: 10.142.0.53
	External IP: 35.231.82.163
>>> adapt(0, 5, interval="60s", wait_count=3)
distributed.deploy.adaptive - INFO - Adaptive scaling started: minimum=0 maximum=5
>>> export_scheduler_info()
Scheduler info exported to /tmp/scheduler-info.txt
>>> shutdown()
Closing Instance: dask-9347b93f-scheduler
Cluster shutdown

To see the Dask UI for this cluster, run this on any workstation (outside of GCP):

gcloud beta compute ssh --zone "us-east1-c" "dask-9347b93f-scheduler" --ssh-flag="-L 8799:localhost:8787"

The UI is then available at http://localhost:8799 .

Create Image

A custom image is created in this project as instructed in Creating custom OS images with Packer .

The definition of this image is generated automatically based on other environments used in this project, so a new image can be generated by using the following process.

  1. Determine package versions to be used by clients and cluster machines.

These can be found by running a command like this: docker run daskdev/dask:v2.30.0 conda env export --from-history

Alternatively, code with these references is here:

- https://hub.docker.com/layers/daskdev/dask/2.30.0/images/sha256-fb5d6b4eef7954448c244d0aa7b2405a507f9dad62ae29d9f869e284f0193c53?context=explore
- https://github.com/dask/dask-docker/blob/99fa808d4dac47b274b5063a23b5f3bbf0d3f105/base/Dockerfile

Ensure that the same versions are in docker/Dockerfile as well as envs/gwas.yaml .
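A quick, illustrative way to eyeball this (not part of the pipeline) is to grep both files for the pinned packages, e.g.:

grep -E "dask|distributed|xarray|zarr" docker/Dockerfile envs/gwas.yaml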

  2. Create and deploy a new docker image (only necessary if Dask version has changed or new package dependencies were added).
DOCKER_USER=<user>
DOCKER_PWD=<password>
DOCKER_TAG="v2020.12.0" # Dask version
cd docker
docker build -t eczech/ukb-gwas-pipeline-nealelab:$DOCKER_TAG .
echo $DOCKER_PWD | docker login --username $DOCKER_USER --password-stdin
docker push eczech/ukb-gwas-pipeline-nealelab:$DOCKER_TAG

Important : Update the desired docker image tag in config/dask/cloudprovider.sh .

  3. Build the Packer image
source .env; source env.sh
# From repo root, create the following configuration files:
conda activate cloudprovider
# See https://github.com/dask/dask-cloudprovider/issues/213 for more details (https://gist.github.com/jacobtomlinson/15404d5b032a9f91c9473d1a91e94c0a)
python scripts/cluster/packer.py create_cloud_init_config > config/dask/cloud-init-config.yaml
python scripts/cluster/packer.py create_packer_config > config/dask/packer-config.json
# Run the build
packer build config/dask/packer-config.json
googlecompute: output will be in this color.
==> googlecompute: Checking image does not exist...
==> googlecompute: Creating temporary rsa SSH key for instance...
==> googlecompute: Using image: ubuntu-minimal-1804-bionic-v20201014
==> googlecompute: Creating instance...
 googlecompute: Loading zone: us-east1-c
 googlecompute: Loading machine type: n1-standard-8
 googlecompute: Requesting instance creation...
 googlecompute: Waiting for creation operation to complete...
 googlecompute: Instance has been created!
==> googlecompute: Waiting for the instance to become running...
 googlecompute: IP: 35.196.0.219
==> googlecompute: Using ssh communicator to connect: 35.196.0.219
==> googlecompute: Waiting for SSH to become available...
==> googlecompute: Connected to SSH!
==> googlecompute: Provisioning with shell script: /tmp/packer-shell423808119
 googlecompute: Waiting for cloud-init
 googlecompute: Done
==> googlecompute: Deleting instance...
 googlecompute: Instance has been deleted!
==> googlecompute: Creating image...
==> googlecompute: Deleting disk...
 googlecompute: Disk has been deleted!
Build 'googlecompute' finished after 1 minute 46 seconds.
==> Wait completed after 1 minute 46 seconds
==> Builds finished. The artifacts of successful builds are:
--> googlecompute: A disk image was created: ukb-gwas-pipeline-nealelab-dask-1608465809
  4. Test the new image.

You can launch a VM instance from the new image like this:

gcloud compute instances create test-image \
 --project $GCP_PROJECT \
 --zone $GCP_ZONE \
 --image-project $GCP_PROJECT \
 --image ukb-gwas-pipeline-nealelab-dask-1608465809

This is particularly useful for checking that the GCP monitoring agent was installed correctly.

Then, you can create a Dask cluster to test with like this:

source env.sh; source .env; source config/dask/cloudprovider.sh
python scripts/cluster/cloudprovider.py -- --interactive
create(1, machine_type='n1-highmem-2', source_image="ukb-gwas-pipeline-nealelab-dask-1608465809", bootstrap=False)
adapt(0, 5)
export_scheduler_info()
# Compare this to an invocation like this, which would load package dependencies from a file containing
# the environment variables "EXTRA_CONDA_PACKAGES" and "EXTRA_PIP_PACKAGES"
create(1, machine_type='n1-highmem-8', bootstrap=True, env_var_file='config/dask/env_vars.json')

Note that a valid env_var_file would contain:

{
 "EXTRA_CONDA_PACKAGES": "\"numba==0.51.2 xarray==0.16.1 gcsfs==0.7.1 dask-ml==1.7.0 zarr==2.4.0 pyarrow==2.0.0 -c conda-forge\"",
 "EXTRA_PIP_PACKAGES": "\"git+https://github.com/pystatgen/sgkit.git@c5548821653fa2759421668092716d2036834ffe#egg=sgkit\""
}

Generally you want to bake these dependencies into the Docker + GCP VM image, but they can also be introduced via environment variables like this to aid development and testing, since the image building process is slow.

Execution

All of the following should be run from the root directory of this repo.

Note that you can preview the effects of any snakemake command below by adding -np to the end (an example follows the setup commands below). This will show the inputs/outputs of a command as well as any shell code that would be run for it.

# Run this first before any of the steps below
conda activate snakemake
source env.sh; source .env
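For example, to preview what the first main-dataset conversion step would do without executing it (this is the same command shown later with -np appended):

snakemake --use-conda --cores=1 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep-data/main/ukb.ckpt -np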

To get static HTML performance reports, which are suitable for sharing, run:

mkdir -p logs/reports
export GENERATE_PERFORMANCE_REPORT=True

The reports can then be found in logs/reports .

Main UKB dataset integration

# Convert main dataset to parquet
# Takes ~45 mins on 4 cores, 12g heap
snakemake --use-conda --cores=1 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep-data/main/ukb.ckpt
# Extract sample QC from main dataset (as zarr)
snakemake --use-conda --cores=1 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep-data/main/ukb_sample_qc.ckpt
# Download data dictionary
snakemake --use-conda --cores=1 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/pipe-data/external/ukb_meta/data_dictionary_showcase.csv

Zarr Integration

# Create cluster with enough disk to hold two copies of each bgen file
gcloud container clusters create \
 --machine-type custom-${GKE_IO_NCPU}-${GKE_IO_MEM_MB} \
 --disk-type pd-standard \
 --disk-size ${GKE_IO_DISK_GB}G \
 --num-nodes 1 \
 --enable-autoscaling --min-nodes 1 --max-nodes 9 \
 --zone $GCP_ZONE \
 --node-locations $GCP_ZONE \
 --cluster-version latest \
 --scopes storage-rw \
 $GKE_IO_NAME
# Run all jobs
# This takes a couple minutes for snakemake to even dry-run, so specifying
# targets yourself is generally faster and more flexible (as shown in the next commands)
snakemake --kubernetes --use-conda --cores=23 --local-cores=1 --restart-times 3 \
--default-remote-provider GS --default-remote-prefix rs-ukb \
--allowed-rules bgen_to_zarr 
# Generate single zarr archive from bgen
# * Set local cores to 1 so that only one rule runs at a time on cluster hosts
snakemake --kubernetes --use-conda --local-cores=1 --restart-times 3 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/gt-imputation/ukb_chrXY.ckpt
# Expected running time (8 vCPUs): ~30 minutes
# Scale up to larger files
snakemake --kubernetes --use-conda --cores=2 --local-cores=1 --restart-times 3 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/gt-imputation/ukb_chr{21,22}.ckpt
# Expected running time (8 vCPUs): 12-14 hours
# Takes ~12 hours for chr 1 on 64 vCPU / 262 GiB RAM / 1TB disk instances.
# Common reasons for failures:
# - https://github.com/dask/gcsfs/issues/315
# - https://github.com/related-sciences/ukb-gwas-pipeline-nealelab/issues/20
gcloud container clusters resize $GKE_IO_NAME --node-pool default-pool --num-nodes 5 --zone $GCP_ZONE
snakemake --kubernetes --use-conda --cores=5 --local-cores=1 --restart-times 3 \
--default-remote-provider GS --default-remote-prefix rs-ukb \
rs-ukb/prep/gt-imputation/ukb_chr{1,2,3,4,5,6,8,9,10}.ckpt
# Run on all chromosomes
snakemake --kubernetes --use-conda --cores=5 --local-cores=1 --restart-times 3 \
--default-remote-provider GS --default-remote-prefix rs-ukb --allowed-rules bgen_to_zarr -np
# Note: With autoscaling, you will often see one job fail and then get restarted with an error like this:
# "Unknown pod snakejob-9174e1f0-c94c-5c76-a3d2-d15af6dd49cb. Has the pod been deleted manually?"
# Delete the cluster
gcloud container clusters delete $GKE_IO_NAME --zone $GCP_ZONE

GWAS QC

# TODO: note somewhere that default quotas of 1000 CPUs and 70 in-use IP addresses make a 62-node n1-highmem-16 cluster the largest possible
# Create the cluster
screen -S cluster
conda activate cloudprovider
source env.sh; source .env; source config/dask/cloudprovider.sh
python scripts/cluster/cloudprovider.py -- --interactive
create(1, machine_type='n1-highmem-16', source_image="ukb-gwas-pipeline-nealelab-dask-1608465809", bootstrap=False)
adapt(0, 50, interval="60s"); export_scheduler_info(); # Set interval to how long nodes should live between uses
# Run the workflows
screen -S snakemake
conda activate snakemake
source env.sh; source .env 
export DASK_SCHEDULER_IP=`cat /tmp/scheduler-info.txt | grep internal_ip | cut -d'=' -f 2`
export DASK_SCHEDULER_HOST=`cat /tmp/scheduler-info.txt | grep hostname | cut -d'=' -f 2`
export DASK_SCHEDULER_ADDRESS=tcp://$DASK_SCHEDULER_IP:8786
echo $DASK_SCHEDULER_HOST $DASK_SCHEDULER_ADDRESS
# For the UI, open a tunnel by running this command on your local
# workstation before visiting localhost:8799 :
echo "gcloud beta compute ssh --zone us-east1-c $DASK_SCHEDULER_HOST --ssh-flag=\"-L 8799:localhost:8787\""
# e.g. gcloud beta compute ssh --zone us-east1-c dask-6ebe0412-scheduler --ssh-flag="-L 8799:localhost:8787"
# Takes ~25 mins for either chr 21/22 on 20 n1-standard-8 nodes.
# Takes ~58 mins for chr 2 on 60 n1-highmem-16 nodes.
snakemake --use-conda --cores=1 --allowed-rules qc_filter_stage_1 --restart-times 3 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/gt-imputation-qc/ukb_chr{1,2,3,4,5,6,7,8,9,10,13,16}.ckpt
# Takes ~25-30 mins for chr 21/22 on 20 n1-standard-8 nodes
# Takes ~52 mins for chr 6 on 60 n1-highmem-16 nodes
snakemake --use-conda --cores=1 --allowed-rules qc_filter_stage_2 --restart-times 3 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/pipe/nealelab-gwas-uni-ancestry-v3/input/gt-imputation/ukb_chr{XY,21,22}.ckpt
snakemake --use-conda --cores=1 --allowed-rules qc_filter_stage_1 --restart-times 3 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/gt-imputation-qc/ukb_chr{11,12,13,14,15,16,17,18,19,20}.ckpt

Phenotype Prep

These steps can be run locally, but the local machine must be resized to have at least 200G RAM. They can alternatively be run on a GKE cluster by adding --kubernetes to the commands below.
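For example, the first step below could be run on a GKE cluster instead of locally like this (identical to the command further down, with --kubernetes added):

snakemake --kubernetes --use-conda --cores=1 --allowed-rules main_csv_phesant_field_prep \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep-data/main/ukb_phesant_prep.csv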

conda activate snakemake
source env.sh; source .env;
# Create the input PHESANT phenotype CSV (takes ~15 mins)
snakemake --use-conda --cores=1 --allowed-rules main_csv_phesant_field_prep \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep-data/main/ukb_phesant_prep.csv
# Extract sample ids from genetic data QC (~1 minute)
snakemake --use-conda --cores=1 --allowed-rules extract_gwas_qc_sample_ids \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/pipe/nealelab-gwas-uni-ancestry-v3/input/sample_ids.csv
# Clone PHESANT repository to download normalization script and metadata files
snakemake --use-conda --cores=1 --allowed-rules phesant_clone \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/temp/repos/PHESANT
# Use genetic data QC sample ids as filter on samples used in phenotype preparation (takes ~40 mins, uses 120G RAM)
snakemake --use-conda --cores=1 --allowed-rules filter_phesant_csv \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/main/ukb_phesant_filtered.csv
# Generate the normalized phenotype data (took 8 hrs and 24 minutes on 8 vCPU / 300 GB RAM)
snakemake --use-conda --cores=1 --allowed-rules main_csv_phesant_phenotypes \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/main/ukb_phesant_phenotypes.csv > phesant.log 2>&1
# This isn't strictly necessary, but these logs should be preserved for future debugging
gsutil cp /tmp/phesant/phenotypes.1.log gs://rs-ukb/prep/main/log/phesant/phenotypes.1.log
gsutil cp phesant.log gs://rs-ukb/prep/main/log/phesant/phesant.log
# Dump the resulting field ids into a separate csv for debugging
snakemake --use-conda --cores=1 --allowed-rules main_csv_phesant_phenotypes_field_id_export \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/main/ukb_phesant_phenotypes.field_ids.csv
# Convert the phenotype data to parquet (~45 mins)
snakemake --use-conda --cores=1 --allowed-rules convert_phesant_csv_to_parquet \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/main/ukb_phesant_phenotypes.parquet.ckpt
# Convert the phenotype data to zarr (~30 mins)
snakemake --use-conda --cores=1 --allowed-rules convert_phesant_parquet_to_zarr \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/prep/main/ukb_phesant_phenotypes.zarr.ckpt
# Sort the zarr according to the sample ids in imputed genotyping data (~45 mins)
snakemake --use-conda --cores=1 --allowed-rules sort_phesant_parquet_zarr \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/pipe/nealelab-gwas-uni-ancestry-v3/input/main/ukb_phesant_phenotypes.ckpt

GWAS

# Notes: 
# - Client machine for these steps can be minimal (4 vCPU, 16 GB RAM)
# - A dask cluster should be created first as it was in the GWAS QC steps
# Copy Neale Lab sumstats from Open Targets
snakemake --use-conda --cores=1 --allowed-rules import_ot_nealelab_sumstats \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/external/ot_nealelab_sumstats/copy.ckpt
# Generate list of traits for GWAS based on intersection of 
# PHESANT results and OT sumstats
snakemake --use-conda --cores=1 --allowed-rules trait_group_ids \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/pipe/nealelab-gwas-uni-ancestry-v3/input/trait_group_ids.csv
# Generate sumstats using sgkit 
# See https://github.com/pystatgen/sgkit/issues/390 for timing information on this step.
snakemake --use-conda --cores=1 --allowed-rules gwas --restart-times 3 \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/pipe/nealelab-gwas-uni-ancestry-v3/output/gt-imputation/ukb_chr{21,22}.ckpt
# To clear: gsutil -m rm -rf gs://rs-ukb/pipe/nealelab-gwas-uni-ancestry-v3/output/gt-imputation/{sumstats,variables,*.ckpt}
# Takes ~10 mins on local host
snakemake --use-conda --cores=1 --allowed-rules sumstats \
 --default-remote-provider GS --default-remote-prefix rs-ukb \
 rs-ukb/pipe/nealelab-gwas-uni-ancestry-v3/output/sumstats-1990-20095.parquet

Misc

  • Never let fsspec overwrite Zarr archives! This technically works, but it is incredibly slow compared to running "gsutil -m rm -rf" yourself. In other words, if you expect a pipeline step to overwrite an existing Zarr archive, delete the archive manually first.

  • To run the snakemake container manually, e.g. if you want to debug a GKE job, run docker run --rm -it -v $HOME/repos/ukb-gwas-pipeline-nealelab:/tmp/ukb-gwas-pipeline-nealelab snakemake/snakemake:v5.30.1 /bin/bash

    • This image version should match the snakemake version used in the snakemake.yaml environment
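A quick way to confirm the versions line up (a suggestion, not part of the pipeline) is to print the version from the image and compare it with the local environment; this assumes the image accepts an arbitrary command, as the docker run example above does:

docker run --rm snakemake/snakemake:v5.30.1 snakemake --version
conda activate snakemake && snakemake --version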

Debug

# Generate DAG
gcloud auth application-default login
snakemake --dag data/prep-data/gt-imputation/ukb_chrXY.zarr | dot -Tsvg > dag.svg

Traits

This is a list of UKB traits that can be useful for testing or better understanding data coding schemes and PHESANT phenotype generation (or that are just entertaining):

  • 5610 - Which eye(s) affected by presbyopia (categorical)

  • 50 - Standing height (continuous)

  • 5183 - Current eye infection (binary)

  • 20101 - Thickness of butter/margarine spread on bread rolls (categorical)

  • 23098 - Weight (continuous)

  • 2395 - Hair/balding pattern (categorical)

  • 1990 - Tense / 'highly strung' (binary)

  • 20095 - Size of white wine glass drunk (categorical)

  • 845 - Age completed full time education (continuous)

  • 1160 - Sleep duration (continuous)

  • 4041 - Gestational diabetes only (binary)

  • 738 - Average total household income before tax (categorical)

  • 1100 - Drive faster than motorway speed limit (categorical)

  • 1100 - Usual side of head for mobile phone use (categorical)

Development Setup

For local development on this pipeline, run:

pip install -r requirements-dev.txt
pre-commit install

Code Snippets

shell:
    "python scripts/gwas.py run_qc_1 "
    "--input-path={params.input_path} "
    "--output-path={params.output_path} "
    "&& touch {output}"
shell:
    "python scripts/gwas.py run_qc_2 "
    "--sample-qc-path={params.sample_qc_path} "
    "--input-path={params.input_path} "
    "--output-path={params.output_path} "
    "&& touch {output}"
shell:
    "python scripts/extract_trait_group_ids.py run "
    "--phenotypes-path={params.phenotypes_path} "
    "--sumstats-path={params.sumstats_path} "
    "--output-path={output} "
shell:
    "python scripts/gwas.py run_gwas "
    "--genotypes-path={params.genotypes_path} "
    "--phenotypes-path={params.phenotypes_path} "
    "--sumstats-path={params.sumstats_path} "
    "--variables-path={params.variables_path} "
    "--trait-group-ids={params.trait_group_ids} "
    "&& touch {output}"
shell:
    "python scripts/merge_sumstats.py run "
    "--gwas-sumstats-path={params.gwas_sumstats_path} "
    "--ot-sumstats-path={params.ot_sumstats_path} "
    "--output-path={output} "
    "--contigs={params.contigs} "
    "--trait_group_ids={params.trait_group_ids} "
shell:
    "Rscript scripts/external/phesant_phenotype_prep.R {input} {output}"
shell:
    "python scripts/extract_sample_ids.py run {params.input_path} {output}"
shell:
    "python scripts/extract_main_data.py phesant_qc_csv "
    "--input-path={input.input_path} "
    "--sample-id-path={input.sample_id_path} "
    "--output-path={output} "
run:
    field_ids = []
    with open(input[0], "r") as f:
        headers = f.readline().split("\t")
        headers = sorted(set(map(lambda v: v.split('_')[0], headers)))
        for field_id in headers:
            field_id = field_id.replace('"', '').strip()
            try:
                field_id = int(field_id)
            except ValueError:
                continue
            field_ids.append(field_id)
    with open(output[0], "w") as f:
        f.write("ukb_field_id\n")
        for field_id in field_ids:
            f.write(str(field_id) + "\n")
shell:
    "export SPARK_DRIVER_MEMORY=12g && "
    "python scripts/convert_phesant_data.py to_parquet "
    "--input-path={input} "
    "--output-path={params.output_path} && "
    "gsutil -m -q rsync -d -r {params.output_path} gs://{params.output_path} && "
    "touch {output}"
shell:
    "python scripts/convert_phesant_data.py to_zarr "
    "--input-path={params.input_path} "
    "--dictionary-path={params.dictionary_path} "
    "--output-path={params.output_path} && "
    "touch {output}"
shell:
    "python scripts/convert_phesant_data.py sort_zarr "
    "--input-path={params.input_path} "
    "--genotypes-path={params.genotypes_path} "
    "--output-path={params.output_path} && "
    "touch {output}"
shell:
    "export SPARK_DRIVER_MEMORY=12g && "
    "python scripts/convert_main_data.py csv_to_parquet "
    "--input-path={input} "
    "--output-path={params.parquet_path} && "
    "gsutil -m -q rsync -d -r {params.parquet_path} gs://{params.parquet_path} && "
    "touch {output}"
shell: "mv {input} {output}"
shell:
    "export SPARK_DRIVER_MEMORY=10g && "
    "python scripts/extract_main_data.py sample_qc_csv "
    "--input-path={params.input_path} "
    "--output-path={output} "
shell:
    "python scripts/extract_main_data.py sample_qc_zarr "
    "--input-path={input} "
    "--output-path={params.output_path} "
    "--remote=True && "
    "touch {output}"
shell:
    "gsutil -u {gcp_project} -m cp "
    "gs://genetics-portal-raw/uk_biobank_sumstats/neale_v2/output/neale_v2_sumstats/*.neale2.gwas.imputed_v3.both_sexes.tsv.gz "
    "gs://{params.output_dir}/ && "
    "touch {output}"
shell:
    "python scripts/convert_genetic_data.py plink_to_zarr "
    "--input-path-bed={input.bed_path} "
    "--input-path-bim={input.bim_path} "
    "--input-path-fam={input.fam_path} "
    "--output-path={params.zarr_path} "
    "--contig-name={wildcards.plink_contig} "
    "--contig-index={params.contig_index} "
    "--remote=True && "
    "touch {output}"
shell:
    "python scripts/convert_genetic_data.py bgen_to_zarr "
    "--input-path-bgen={input.bgen_path} "
    "--input-path-variants={input.variants_path} "
    "--input-path-samples={input.samples_path} "
    "--output-path={params.zarr_path} "
    "--contig-name={wildcards.bgen_contig} "
    "--contig-index={params.contig_index} "
    "--remote=True && "
    "touch {output}"
import logging
import logging.config
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import dask
import fire
import gcsfs
import numpy as np
import pandas as pd
import xarray as xr
import zarr
from dask.diagnostics import ProgressBar
from sgkit.io.bgen import read_bgen, rechunk_bgen
from sgkit.io.plink import read_plink
from xarray import Dataset

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)


@dataclass
class BGENPaths:
    bgen_path: str
    variants_path: str
    samples_path: str


@dataclass
class PLINKPaths:
    bed_path: str
    bim_path: str
    fam_path: str


@dataclass
class Contig:
    name: str
    index: int


def transform_contig(ds: Dataset, contig: Contig) -> Dataset:
    # Preserve the original contig index/name field
    # in case there are multiple (e.g. PAR1, PAR2 within XY)
    ds["variant_contig_name"] = xr.DataArray(
        np.array(ds.attrs["contigs"])[ds["variant_contig"].values].astype("S"),
        dims="variants",
    )
    # Overwrite contig index with single value matching index
    # for contig name in file name
    ds["variant_contig"].data = np.full(
        ds["variant_contig"].shape, contig.index, dtype=ds["variant_contig"].dtype
    )
    # Add attributes for convenience
    ds.attrs["contig_name"] = contig.name
    ds.attrs["contig_index"] = contig.index
    return ds


def load_plink(paths: PLINKPaths, contig: Contig) -> Dataset:
    logger.info(f"Loading PLINK dataset for contig {contig} from {paths.bed_path}")
    with dask.config.set(scheduler="threads"):
        ds = read_plink(
            bed_path=paths.bed_path,
            bim_path=paths.bim_path,
            fam_path=paths.fam_path,
            bim_int_contig=False,
            count_a1=False,
        )
    ds["sample_id"] = ds["sample_id"].astype("int32")
    # All useful sample metadata will come from the
    # main UKB dataset instead
    ds = ds.drop_vars(
        [
            "sample_family_id",
            "sample_paternal_id",
            "sample_maternal_id",
            "sample_phenotype",
        ]
    )
    # Update contig index/names
    ds = transform_contig(ds, contig)
    return ds


def load_bgen_variants(path: str):
    # See: https://github.com/Nealelab/UK_Biobank_GWAS/blob/8f8ee456fdd044ce6809bb7e7492dc98fd2df42f/0.1/09.load_mfi_vds.py
    cols = [
        ("id", str),
        ("rsid", str),
        ("position", "int32"),
        ("allele1_ref", str),
        ("allele2_alt", str),
        ("maf", "float32"),
        ("minor_allele", str),
        ("info", "float32"),
    ]
    df = pd.read_csv(path, sep="\t", names=[c[0] for c in cols], dtype=dict(cols))
    ds = df.rename_axis("variants", axis="rows").to_xarray().drop("variants")
    ds["allele"] = xr.concat([ds["allele1_ref"], ds["allele2_alt"]], dim="alleles").T
    ds = ds.drop_vars(["allele1_ref", "allele2_alt"])
    for c in cols + [("allele", str)]:
        if c[0] in ds and c[1] == str:
            ds[c[0]] = ds[c[0]].compute().astype("S")
    ds = ds.rename({v: "variant_" + v for v in ds})
    ds = ds.chunk(chunks="auto")
    return ds


def load_bgen_samples(path: str) -> Dataset:
    cols = [("id1", "int32"), ("id2", "int32"), ("missing", str), ("sex", "uint8")]
    # Example .sample file:
    # head ~/data/rs-ukb/raw-data/gt-imputation/ukb59384_imp_chr4_v3_s487296.sample
    # ID_1 ID_2 missing sex
    # 0 0 0 D
    # 123123 123123 0 1  # Actual ids replaced with fake numbers
    df = pd.read_csv(
        path,
        sep=" ",
        dtype=dict(cols),
        names=[c[0] for c in cols],
        header=0,
        skiprows=1,  # Skip the first non-header row
    )
    # id1 always equals id2 and missing is always 0
    df = df[["id1", "sex"]].rename(columns={"id1": "id"})
    ds = df.rename_axis("samples", axis="rows").to_xarray().drop("samples")
    ds = ds.rename({v: "sample_" + v for v in ds})
    return ds


def load_bgen_probabilities(
    path: str, contig: Contig, chunks: Optional[Union[str, int, tuple]] = None
) -> Dataset:
    ds = read_bgen(path, chunks=chunks, gp_dtype="float16")

    # Update contig index/names
    ds = transform_contig(ds, contig)

    # Drop most variables since the external tables are more useful
    ds = ds[
        [
            "variant_contig",
            "variant_contig_name",
            "call_genotype_probability",
            "call_genotype_probability_mask",
        ]
    ]
    return ds


def load_bgen(
    paths: BGENPaths,
    contig: Contig,
    region: Optional[Tuple[int, int]] = None,
    variant_info_threshold: Optional[float] = None,
    chunks: Tuple[int, int] = (250, -1),
):
    logger.info(
        f"Loading BGEN dataset for contig {contig} from "
        f"{paths.bgen_path} (chunks = {chunks})"
    )
    # Load and merge primary + axis datasets
    dsp = load_bgen_probabilities(paths.bgen_path, contig, chunks=chunks + (-1,))
    dsv = load_bgen_variants(paths.variants_path)
    dss = load_bgen_samples(paths.samples_path)
    ds = xr.merge([dsv, dss, dsp], combine_attrs="no_conflicts")

    # Apply variant slice if provided
    if region is not None:
        logger.info(f"Applying filter to region {region}")
        n_variant = ds.dims["variants"]
        ds = ds.isel(variants=slice(region[0], region[1]))
        logger.info(f"Filtered to {ds.dims['variants']} variants of {n_variant}")

    # Apply variant info threshold if provided (this is applied
    # early because it is not particularly controversial and
    # eliminates ~80% of variants when near .8)
    if variant_info_threshold is not None:
        logger.info(f"Applying filter to variant info > {variant_info_threshold}")
        n_variant = ds.dims["variants"]
        ds = ds.isel(variants=ds.variant_info > variant_info_threshold)
        logger.info(f"Filtered to {ds.dims['variants']} variants of {n_variant}")
        # Make sure to rechunk after non-uniform filter
        for v in ds:
            if "variants" in ds[v].dims and "samples" in ds[v].dims:
                ds[v] = ds[v].chunk(chunks=dict(variants=chunks[0]))
            elif "variants" in ds[v].dims:
                ds[v] = ds[v].chunk(chunks=dict(variants="auto"))

    return ds


def rechunk_dataset(
    ds: Dataset,
    output: str,
    contig: Contig,
    fn: Callable,
    chunks: Tuple[int, int],
    max_mem: str,
    progress_update_seconds: int = 60,
    remote: bool = True,
    **kwargs,
) -> Dataset:
    logger.info(
        f"Rechunking dataset for contig {contig} "
        f"to {output} (chunks = {chunks}):\n{ds}"
    )

    if remote:
        gcs = gcsfs.GCSFileSystem()
        output = gcsfs.GCSMap(output, gcs=gcs, check=False, create=False)

    # Save to local zarr store with desired sample chunking
    with ProgressBar(dt=progress_update_seconds):
        res = fn(
            ds,
            output=output,
            chunk_length=chunks[0],
            chunk_width=chunks[1],
            max_mem=max_mem,
            **kwargs,
        )

    logger.info(f"Rechunked dataset:\n{res}")
    return res


def save_dataset(
    output_path: str,
    ds: Dataset,
    contig: Contig,
    scheduler: str = "threads",
    remote: bool = True,
    progress_update_seconds: int = 60,
):
    store = output_path
    if remote:
        gcs = gcsfs.GCSFileSystem()
        store = gcsfs.GCSMap(output_path, gcs=gcs, check=False, create=False)
    logger.info(
        f"Dataset to save for contig {contig}:\n{ds}\n"
        f"Writing dataset for contig {contig} to {output_path} "
        f"(scheduler={scheduler}, remote={remote})"
    )
    with dask.config.set(scheduler=scheduler), dask.config.set(
        {"optimization.fuse.ave-width": 50}
    ), ProgressBar(dt=progress_update_seconds):
        ds.to_zarr(store=store, mode="w", consolidated=True)


def plink_to_zarr(
    input_path_bed: str,
    input_path_bim: str,
    input_path_fam: str,
    output_path: str,
    contig_name: str,
    contig_index: int,
    remote: bool = True,
):
    """Convert UKB PLINK to Zarr"""
    paths = PLINKPaths(
        bed_path=input_path_bed, bim_path=input_path_bim, fam_path=input_path_fam
    )
    contig = Contig(name=contig_name, index=contig_index)
    ds = load_plink(paths, contig)
    # TODO: Switch to rechunker method
    save_dataset(output_path, ds, contig, scheduler="processes", remote=remote)
    logger.info("Done")


def bgen_to_zarr(
    input_path_bgen: str,
    input_path_variants: str,
    input_path_samples: str,
    output_path: str,
    contig_name: str,
    contig_index: int,
    max_mem: str = "500MB",  # per-worker
    remote: bool = True,
    region: Optional[Tuple[int, int]] = None,
):
    """Convert UKB BGEN to Zarr"""
    paths = BGENPaths(
        bgen_path=input_path_bgen,
        variants_path=input_path_variants,
        samples_path=input_path_samples,
    )
    contig = Contig(name=contig_name, index=contig_index)
    ds = load_bgen(paths, contig, region=region)

    # Chosen with expected shape across all chroms (~128MB chunks):
    # normalize_chunks('auto', shape=(97059328, 487409), dtype='float32')
    chunks = (5216, 5792)
    ds = rechunk_dataset(
        ds,
        output=output_path,
        contig=contig,
        fn=rechunk_bgen,
        chunks=chunks,
        max_mem=max_mem,
        remote=remote,
        compressor=zarr.Blosc(cname="zstd", clevel=7, shuffle=2, blocksize=0),
        probability_dtype="uint8",
        pack=True,
    )
    logger.info("Done")


if __name__ == "__main__":
    fire.Fire()
import logging
import logging.config
from pathlib import Path

import fire
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)


def get_schema(path: str, sample_id_col: str = "eid", **kwargs):
    # Fetch header only
    cols = pd.read_csv(path, nrows=1, **kwargs).columns.tolist()
    assert (
        cols[0] == sample_id_col
    ), f'Expecting "{sample_id_col}" as first field, found "{cols[0]}"'
    # Convert field names for spark compatibility from
    # `{field_id}-{instance_index}.{array_index}` to `x{field_id}_{instance_index}_{array_index}`
    # See: https://github.com/related-sciences/data-team/issues/22#issuecomment-613048099
    cols = [
        c if c == sample_id_col else "x" + c.replace("-", "_").replace(".", "_")
        for c in cols
    ]

    # Generate generic schema with string types (except sample id)
    schema = [StructField(cols[0], IntegerType())]
    schema += [StructField(c, StringType()) for c in cols[1:]]
    assert len(cols) == len(schema)
    schema = StructType(schema)
    return schema


def csv_to_parquet(input_path: str, output_path: str):
    """Convert primary UKB dataset CSV to Parquet"""
    logger.info(f"Converting csv at {input_path} to {output_path}")
    spark = SparkSession.builder.getOrCreate()

    schema = get_schema(input_path, sample_id_col="eid", sep=",", encoding="cp1252")

    # Read csv with no header
    df = spark.read.csv(
        input_path, sep=",", encoding="cp1252", header=False, schema=schema
    )
    df = df.filter(F.col("eid").isNotNull())
    logger.info(f"Number of partitions in result: {df.rdd.getNumPartitions()}")
    df.write.mode("overwrite").parquet(output_path, compression="snappy")

    logger.info("Done")


if __name__ == "__main__":
    fire.Fire()
import logging
import logging.config
from pathlib import Path

import fire
import numpy as np
import pandas as pd

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)


def to_parquet(input_path: str, output_path: str):
    from pyspark.sql import SparkSession

    logger.info(f"Converting csv at {input_path} to {output_path}")
    spark = SparkSession.builder.getOrCreate()
    df = spark.read.csv(input_path, sep="\t", header=True, inferSchema=True)
    df.write.mode("overwrite").parquet(output_path, compression="snappy")
    logger.info("Done")


def to_zarr(input_path: str, output_path: str, dictionary_path: str):
    import dask.dataframe as dd
    import fsspec
    import xarray as xr
    from dask.diagnostics import ProgressBar

    logger.info(f"Converting parquet at {input_path} to {output_path}")
    df = dd.read_parquet(input_path)

    trait_columns = df.columns[df.columns.to_series().str.match(r"^\d+")]
    # 41210_Z942 -> 41210 (UKB field id)
    trait_group_ids = [c.split("_")[0] for c in trait_columns]
    # 41210_Z942 -> Z942 (Data coding value as one-hot encoding in phenotype, e.g.)
    trait_code_ids = ["_".join(c.split("_")[1:]) for c in trait_columns]
    trait_values = df[trait_columns].astype("float").to_dask_array()
    trait_values.compute_chunk_sizes()

    trait_id_to_name = (
        pd.read_csv(
            dictionary_path,
            sep=",",
            usecols=["FieldID", "Field"],
            dtype={"FieldID": str, "Field": str},
        )
        .set_index("FieldID")["Field"]
        .to_dict()
    )
    trait_name = [trait_id_to_name.get(v) for v in trait_group_ids]

    ds = xr.Dataset(
        dict(
            id=("samples", np.asarray(df["userId"], dtype=int)),
            trait=(("samples", "traits"), trait_values),
            trait_id=("traits", np.asarray(trait_columns.values, dtype=str)),
            trait_group_id=("traits", np.array(trait_group_ids, dtype=int)),
            trait_code_id=("traits", np.array(trait_code_ids, dtype=str)),
            trait_name=("traits", np.array(trait_name, dtype=str)),
        )
    )
    # Keep chunks small in trait dimension for faster per-trait processing
    ds["trait"] = ds["trait"].chunk(dict(samples="auto", traits=100))
    ds = ds.rename_vars({v: f"sample_{v}" for v in ds})

    logger.info(f"Saving dataset to {output_path}:\n{ds}")
    with ProgressBar():
        ds.to_zarr(fsspec.get_mapper(output_path), consolidated=True, mode="w")
    logger.info("Done")


def sort_zarr(input_path: str, genotypes_path: str, output_path: str):
    import fsspec
    import xarray as xr
    from dask.diagnostics import ProgressBar

    ds_tr = xr.open_zarr(fsspec.get_mapper(input_path), consolidated=True)
    ds_gt = xr.open_zarr(fsspec.get_mapper(genotypes_path), consolidated=True)

    # Sort trait data using genomic data sample ids;
    # Note that this will typically produce a warning like:
    # "PerformanceWarning: Slicing with an out-of-order index is generating 69909 times more chunks"
    # which is OK since the purpose of this step is to incur this cost once
    # instead of many times in a repetitive GWAS workflow
    ds = ds_tr.set_index(samples="sample_id").sel(samples=ds_gt.sample_id)
    ds = ds.rename_vars({"samples": "sample_id"}).reset_coords("sample_id")

    # Restore chunkings; reordered traits array will have many chunks
    # of size 1 (or other small sizes) in samples dim without this
    for v in ds_tr:
        ds[v].encoding.pop("chunks", None)
        ds[v] = ds[v].chunk(ds_tr[v].data.chunksize)

    logger.info(f"Saving dataset to {output_path}:\n{ds}")
    with ProgressBar():
        ds.to_zarr(fsspec.get_mapper(output_path), consolidated=True, mode="w")


if __name__ == "__main__":
    fire.Fire()
library("data.table")

args = commandArgs(trailingOnly=TRUE)
input_path <- args[1]
output_path <- args[2]

# Neale Lab application.
bd2 <- fread(input_path, header=TRUE, sep=',', na.strings=c("NA", ""))
bd2 <- as.data.frame(bd2)

# First get rid of things that are > 4.
colnames(bd2) <- gsub("-", "\\.", colnames(bd2))
visit_number <- as.integer(unlist(lapply(strsplit(names(bd2), "\\."), "[", 2)))[-1]
fields <- unlist(lapply(strsplit(names(bd2), "\\."), "[", 1))[-1]

to_include <- which(visit_number <= 4)
# Get rid of the fields that have up to 31 visits - small collection of Cancer variables,
# need to look at those separately.

# Cancer variables 
cancer_fields <- unique(fields[which(visit_number > 4)])
to_include <- setdiff(to_include, which(fields %in% cancer_fields))

bd2 <- bd2[,c(1,to_include+1)]

# Now, want to go through these column names and determine, if there's a 0, include it, if not, look to 1 etc. 
fields <- unlist(lapply(strsplit(names(bd2), "\\."), "[", 1))[-1]
visit_number <- as.integer(unlist(lapply(strsplit(names(bd2), "\\."), "[", 2))[-1])
unique_fields <- unique(fields)
fields_to_use <- c()

for (i in unique_fields) {
    matches <- which(fields == i)
    match_min <- which(visit_number[matches] == min(visit_number[matches]))
    match <- matches[match_min]
    fields_to_use <- c(fields_to_use, match)
}

bd2 <- bd2[,c(1,fields_to_use+1)]

c1 <- gsub("^", "x", colnames(bd2))
c2 <- gsub("\\.", "_", c1)
c3 <- c("userId", c2[2:length(c2)])
colnames(bd2) <- c3

# Hack to ensure that the age and sex are in as columns.
if (any(colnames(bd2) == "x21022_0_0") == FALSE) {
    bd2 <- cbind(rep(1, nrow(bd2)), bd2)
    colnames(bd2)[1] <- "x21022_0_0"
}

if (any(colnames(bd2) == "x31_0_0") == FALSE) {
    bd2 <- cbind(rep(1, nrow(bd2)), bd2)
    colnames(bd2)[1] <- "x31_0_0"
}

fwrite(bd2, file=output_path, row.names=FALSE, quote=FALSE, sep='\t')
import logging
import logging.config
from pathlib import Path

import fire
import pandas as pd

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)


# See here for a description of this resource and its associated fields:
# http://biobank.ctsu.ox.ac.uk/crystal/label.cgi?id=100313
# Also here for similar Neale Lab extraction:
# https://github.com/Nealelab/UK_Biobank_GWAS/blob/master/0.1/00.load_sample_qc_kt.py
SAMPLE_QC_COLS = {
    "eid": "eid",
    "x22000_0_0": "genotype_measurement_batch",
    "x22007_0_0": "genotype_measurement_plate",
    "x22008_0_0": "genotype_measurement_well",
    "x22001_0_0": "genetic_sex",
    "x22021_0_0": "genetic_kinship_to_other_participants",
    # This is boolean flag (as float) that "Indicates samples who self-identified as 'White British' according to Field 21000"
    # https://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=22006
    "x22006_0_0": "genetic_ethnic_grouping",
    "x22019_0_0": "sex_chromosome_aneuploidy",
    "x22027_0_0": "outliers_for_heterozygosity_or_missing_rate",
    "x22003_0_0": "heterozygosity",
    "x22004_0_0": "heterozygosity_pca_corrected",
    "x22005_0_0": "missingness",
    "x22020_0_0": "used_in_genetic_principal_components",
    "x22022_0_0": "sex_inference_x_probe_intensity",
    "x22023_0_0": "sex_inference_y_probe_intensity",
    "x22025_0_0": "affymetrix_quality_control_metric_cluster_cr",
    "x22026_0_0": "affymetrix_quality_control_metric_dqc",
    "x22024_0_0": "dna_concentration",
    "x22028_0_0": "use_in_phasing_chromosomes_1_22",
    "x22029_0_0": "use_in_phasing_chromosome_x",
    "x22030_0_0": "use_in_phasing_chromosome_xy",
    # -----------------------------------------------------
    # Additional fields beyond resource but relevant for QC
    # see: https://github.com/atgu/ukbb_pan_ancestry/blob/master/reengineering_phenofile_neale_lab2.r
    "x21022_0_0": "age_at_recruitment",
    "x31_0_0": "sex",
    "x21000_0_0": "ethnic_background",
    # -----------------------------------------------------
    # PCs
    "x22009_0_1": "genetic_principal_component_01",
    "x22009_0_2": "genetic_principal_component_02",
    "x22009_0_3": "genetic_principal_component_03",
    "x22009_0_4": "genetic_principal_component_04",
    "x22009_0_5": "genetic_principal_component_05",
    "x22009_0_6": "genetic_principal_component_06",
    "x22009_0_7": "genetic_principal_component_07",
    "x22009_0_8": "genetic_principal_component_08",
    "x22009_0_9": "genetic_principal_component_09",
    "x22009_0_10": "genetic_principal_component_10",
    "x22009_0_11": "genetic_principal_component_11",
    "x22009_0_12": "genetic_principal_component_12",
    "x22009_0_13": "genetic_principal_component_13",
    "x22009_0_14": "genetic_principal_component_14",
    "x22009_0_15": "genetic_principal_component_15",
    "x22009_0_16": "genetic_principal_component_16",
    "x22009_0_17": "genetic_principal_component_17",
    "x22009_0_18": "genetic_principal_component_18",
    "x22009_0_19": "genetic_principal_component_19",
    "x22009_0_20": "genetic_principal_component_20",
    "x22009_0_21": "genetic_principal_component_21",
    "x22009_0_22": "genetic_principal_component_22",
    "x22009_0_23": "genetic_principal_component_23",
    "x22009_0_24": "genetic_principal_component_24",
    "x22009_0_25": "genetic_principal_component_25",
    "x22009_0_26": "genetic_principal_component_26",
    "x22009_0_27": "genetic_principal_component_27",
    "x22009_0_28": "genetic_principal_component_28",
    "x22009_0_29": "genetic_principal_component_29",
    "x22009_0_30": "genetic_principal_component_30",
    "x22009_0_31": "genetic_principal_component_31",
    "x22009_0_32": "genetic_principal_component_32",
    "x22009_0_33": "genetic_principal_component_33",
    "x22009_0_34": "genetic_principal_component_34",
    "x22009_0_35": "genetic_principal_component_35",
    "x22009_0_36": "genetic_principal_component_36",
    "x22009_0_37": "genetic_principal_component_37",
    "x22009_0_38": "genetic_principal_component_38",
    "x22009_0_39": "genetic_principal_component_39",
    "x22009_0_40": "genetic_principal_component_40",
}


def sample_qc_csv(input_path: str, output_path: str):
    """Extract sample QC data from main dataset as csv"""
    from pyspark.sql import SparkSession

    logger.info(f"Extracting sample qc from {input_path} into {output_path}")
    spark = SparkSession.builder.getOrCreate()
    df = spark.read.parquet(input_path)
    pdf = df[list(SAMPLE_QC_COLS.keys())].toPandas()
    pdf = pdf.rename(columns=SAMPLE_QC_COLS)
    logger.info("Sample QC info:")
    pdf.info()
    logger.info(f"Saving csv at {output_path}")
    pdf.to_csv(output_path, sep="\t", index=False)


def sample_qc_zarr(input_path: str, output_path: str, remote: bool):
    """Convert sample QC csv to zarr"""
    import gcsfs
    import pandas as pd

    logger.info("Converting to Xarray")
    df = pd.read_csv(input_path, sep="\t")
    pc_vars = df.filter(regex="^genetic_principal_component").columns.tolist()
    ds = (
        df[[c for c in df if c not in pc_vars]]
        .rename_axis("samples", axis="rows")
        .to_xarray()
        .drop_vars("samples")
    )
    pcs = (
        df[pc_vars]
        .rename_axis("samples", axis="rows")
        .to_xarray()
        .drop_vars("samples")
        .to_array(dim="principal_components")
        .T
    )
    ds = ds.assign(
        genotype_measurement_plate=ds.genotype_measurement_plate.astype("S"),
        genotype_measurement_well=ds.genotype_measurement_well.astype("S"),
        principal_component=pcs.drop_vars("principal_components"),
    )
    # Rechunk to enforce stricter dtypes as well as ease
    # downstream loading/processing of PC array
    ds = ds.chunk("auto")

    store = output_path
    if remote:
        gcs = gcsfs.GCSFileSystem()
        store = gcsfs.GCSMap(output_path, gcs=gcs, check=False, create=True)

    logger.info(f"Sample QC dataset:\n{ds}")
    logger.info(f"Saving zarr archive at {output_path}")
    ds.to_zarr(store, mode="w", consolidated=True)


def phesant_qc_csv(input_path: str, sample_id_path: str, output_path: str):
    logger.info(
        f"Filtering csv at {input_path} to {output_path} (using samples at {sample_id_path})"
    )
    sample_ids = pd.read_csv(sample_id_path, sep="\t")
    sample_ids = [int(v) for v in set(sample_ids.sample_id.values)]
    logger.info(f"Loaded {len(sample_ids)} sample ids to filter to")

    df = pd.read_csv(
        input_path,
        sep="\t",
        dtype=str,
        keep_default_na=False,
        na_values=None,
        na_filter=False,
    )
    df["userId"] = df["userId"].astype(int)
    logger.info(f"Number of records before filter = {len(df)}")
    df = df[df["userId"].isin(sample_ids)]
    logger.info(f"Number of records after filter = {len(df)}")
    df.to_csv(output_path, index=False, sep="\t")


if __name__ == "__main__":
    fire.Fire()
import logging
import logging.config
from pathlib import Path

import fire
import fsspec
import pandas as pd
import xarray as xr

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)


def get_sample_ids(path):
    ds = xr.open_zarr(fsspec.get_mapper(path))
    return ds.sample_id.to_series().to_list()


def run(input_path, output_path):
    logger.info(f"Extracting sample ids from {input_path} into {output_path}")
    ids = get_sample_ids(input_path)
    ids = pd.DataFrame(dict(sample_id=list(set(ids))))
    ids.to_csv(output_path, sep="\t", index=False)
    logger.info(f"Number of sample ids saved: {len(ids)}")
    logger.info("Done")


if __name__ == "__main__":
    fire.Fire()
import logging
import logging.config
from pathlib import Path

import fire
import fsspec
import numpy as np
import pandas as pd
import xarray as xr

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)


def get_ot_trait_groups(path):
    store = fsspec.get_mapper(path)
    ids = []
    for file in list(store):
        if not file.endswith(".tsv.gz"):
            continue
        ids.append(int(file.split(".")[0].split("_")[0]))
    return np.unique(ids)


def get_gwas_trait_groups(path):
    ds = xr.open_zarr(fsspec.get_mapper(path))
    return ds["sample_trait_group_id"].to_series().astype(int).unique()


def run(phenotypes_path: str, sumstats_path: str, output_path: str):
    logger.info(
        f"Extracting intersecting traits from {phenotypes_path} and {sumstats_path} to {output_path}"
    )
    ids1 = get_gwas_trait_groups(phenotypes_path)
    ids2 = get_ot_trait_groups(sumstats_path)
    ids = pd.DataFrame(dict(trait_group_id=np.intersect1d(ids1, ids2)))
    ids.to_csv(output_path, sep="\t", index=False)
    logger.info(f"Number of trait group ids saved: {len(ids)}")
    logger.info("Done")


if __name__ == "__main__":
    fire.Fire()
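
The trait-group intersection script is invoked the same way. A sketch with hypothetical paths; the phenotypes zarr is assumed to contain `sample_trait_group_id` and the sumstats path the `*.tsv.gz` files parsed above:

# Sketch only: paths below are hypothetical placeholders.
run(
    phenotypes_path="gs://my-ukb-bucket-name/prep/main/ukb_phesant_phenotypes.zarr",
    sumstats_path="gs://my-ukb-bucket-name/external/ot_nealelab_sumstats",
    output_path="pipe/trait_group_ids.tsv",
)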
import logging
import logging.config
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union
from urllib.parse import urlparse

import dask
import dask.array as da
import fire
import fsspec
import numpy as np
import pandas as pd
import sgkit as sg
import xarray as xr
from dask.diagnostics import ProgressBar
from dask.distributed import Client, get_task_stream, performance_report
from retrying import retry
from sgkit.io.bgen.bgen_reader import unpack_variables
from xarray import DataArray, Dataset

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)

fs = fsspec.filesystem("gs")


def init():
    # Set this globally to avoid constant warnings like:
    # PerformanceWarning: Slicing is producing a large chunk. To accept the large chunk and silence this warning, set the option
    # >>> with dask.config.set(**{'array.slicing.split_large_chunks': False})
    dask.config.set(**{"array.slicing.split_large_chunks": False})
    ProgressBar().register()
    if "DASK_SCHEDULER_ADDRESS" in os.environ:
        client = Client()
        logger.info(f"Initialized script with dask client:\n{client}")
    else:
        logger.info(
            "Skipping initialization of distributed scheduler "
            "(no `DASK_SCHEDULER_ADDRESS` found in env)"
        )


def add_protocol(url, protocol="gs"):
    if not urlparse(str(url)).scheme:
        return protocol + "://" + url
    return url


def get_chunks(ds: Dataset, var: str = "call_genotype_probability") -> Dict[str, int]:
    chunks = dict(zip(ds[var].dims, ds[var].data.chunksize))
    return {d: chunks[d] if d in {"variants", "samples"} else -1 for d in ds.dims}


def load_dataset(
    path: str, unpack: bool = False, consolidated: bool = False
) -> Dataset:
    store = fsspec.get_mapper(path, check=False, create=False)
    ds = xr.open_zarr(store, concat_characters=False, consolidated=consolidated)
    if unpack:
        ds = unpack_variables(ds, dtype="float16")
    for v in ds:
        # Workaround for https://github.com/pydata/xarray/issues/4386
        if v.endswith("_mask"):
            ds[v] = ds[v].astype(bool)
    return ds


def save_dataset(ds: Dataset, path: str, retries: int = 3):
    store = fsspec.get_mapper(path, check=False, create=False)
    for v in ds:
        ds[v].encoding.pop("chunks", None)
    task = ds.to_zarr(store, mode="w", consolidated=True, compute=False)
    task.compute(retries=retries)


def load_sample_qc(sample_qc_path: str) -> Dataset:
    store = fsspec.get_mapper(sample_qc_path, check=False, create=False)
    ds = xr.open_zarr(store, consolidated=True)
    ds = ds.rename_vars(dict(eid="id"))
    ds = ds.rename_vars({v: f"sample_{v}" for v in ds})
    if "sample_sex" in ds:
        # Rename to avoid conflict with bgen field
        ds = ds.rename_vars({"sample_sex": "sample_qc_sex"})
    return ds


def variant_genotype_counts(ds: Dataset) -> DataArray:
    gti = ds["call_genotype_probability"].argmax(dim="genotypes")
    gti = gti.astype("uint8").expand_dims("genotypes", axis=-1)
    gti = gti == da.arange(ds.dims["genotypes"], dtype="uint8")
    return gti.sum(dim="samples", dtype="int32")
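# Note: the argmax -> expand_dims -> compare-with-arange pattern above turns
# per-sample genotype probabilities into hard calls and then counts them per
# variant. For one variant with probabilities
# [[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]] over three samples,
# argmax yields calls [0, 1, 2], the comparison with arange(3) one-hot encodes
# them, and the sum over samples gives counts [1, 1, 1], ordered like the
# `genotypes` dimension (hom_ref, het, hom_alt for this diploid, biallelic
# BGEN data), which is why variant_qc_2 reorders the columns before the HWE test.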


def apply_filters(ds: Dataset, filters: Dict[str, Any], dim: str) -> Dataset:
    logger.info("Filter summary (True = kept, False = removed):")
    mask = []
    for k, v in filters.items():
        v = v.compute()
        logger.info(f"\t{k}: {v.to_series().value_counts().to_dict()}")
        mask.append(v.values)
    mask = np.stack(mask, axis=1)
    mask = np.all(mask, axis=1)
    assert len(mask) == ds.dims[dim]
    if len(filters) > 1:
        logger.info(f"\toverall: {pd.Series(mask).value_counts().to_dict()}")
    return ds.isel(**{dim: mask})


def add_traits(ds: Dataset, phenotypes_path: str) -> Dataset:
    ds_tr = load_dataset(phenotypes_path, consolidated=True)
    ds = ds.assign_coords(samples=lambda ds: ds.sample_id).merge(
        ds_tr.assign_coords(samples=lambda ds: ds.sample_id),
        join="left",
        compat="override",
    )
    return ds.reset_index("samples").reset_coords(drop=True)


def add_covariates(ds: Dataset, npc: int = 20) -> Dataset:
    # See https://github.com/Nealelab/UK_Biobank_GWAS/blob/67289386a851a213f7bb470a3f0f6af95933b041/0.1/22.run_regressions.py#L71
    ds = (
        ds.assign(
            sample_age_at_recruitment_2=lambda ds: ds["sample_age_at_recruitment"] ** 2
        )
        .assign(
            sample_sex_x_age=lambda ds: ds["sample_genetic_sex"]
            * ds["sample_age_at_recruitment"]
        )
        .assign(
            sample_sex_x_age_2=lambda ds: ds["sample_genetic_sex"]
            * ds["sample_age_at_recruitment_2"]
        )
    )
    covariates = np.column_stack(
        [
            ds["sample_age_at_recruitment"].values,
            ds["sample_age_at_recruitment_2"].values,
            ds["sample_genetic_sex"].values,
            ds["sample_sex_x_age"].values,
            ds["sample_sex_x_age_2"].values,
            ds["sample_principal_component"].values[:, :npc],
        ]
    )
    assert np.all(np.isfinite(covariates))
    ds["sample_covariate"] = xr.DataArray(covariates, dims=("samples", "covariates"))
    ds["sample_covariate"] = ds.sample_covariate.pipe(
        lambda x: (x - x.mean(dim="samples")) / x.std(dim="samples")
    )
    assert np.all(np.isfinite(ds.sample_covariate))
    return ds


SAMPLE_QC_COLS = [
    "sample_id",
    "sample_qc_sex",
    "sample_genetic_sex",
    "sample_age_at_recruitment",
    "sample_principal_component",
    "sample_ethnic_background",
    "sample_genotype_measurement_batch",
    "sample_genotype_measurement_plate",
    "sample_genotype_measurement_well",
]


def apply_sample_qc_1(ds: Dataset, sample_qc_path: str) -> Dataset:
    ds_sqc = load_sample_qc(sample_qc_path)
    ds_sqc = sample_qc_1(ds_sqc)
    ds_sqc = ds_sqc[SAMPLE_QC_COLS]
    ds = ds.assign_coords(samples=lambda ds: ds.sample_id).merge(
        ds_sqc.assign_coords(samples=lambda ds: ds.sample_id).compute(),
        join="inner",
        compat="override",
    )
    return ds.reset_index("samples").reset_coords(drop=True)


def sample_qc_1(ds: Dataset) -> Dataset:
    # See:
    # - https://github.com/Nealelab/UK_Biobank_GWAS#imputed-v3-sample-qc
    # - https://github.com/Nealelab/UK_Biobank_GWAS/blob/master/0.1/04.subset_samples.py
    filters = {
        "no_aneuploidy": ds.sample_sex_chromosome_aneuploidy.isnull(),
        "has_age": ds.sample_age_at_recruitment.notnull(),
        "in_phasing_chromosome_x": ds.sample_use_in_phasing_chromosome_x == 1,
        "in_in_phasing_chromosomes_1_22": ds.sample_use_in_phasing_chromosomes_1_22
        == 1,
        "in_pca": ds.sample_used_in_genetic_principal_components == 1,
        # 1001 = White/British, 1002 = Mixed/Irish
        "in_ethnic_groups": ds.sample_ethnic_background.isin([1001, 1002]),
    }
    return apply_filters(ds, filters, dim="samples")


def variant_qc_1(ds: Dataset) -> Dataset:
    # See: https://github.com/Nealelab/UK_Biobank_GWAS#imputed-v3-variant-qc
    ds = apply_filters(ds, {"high_info": ds.variant_info > 0.8}, dim="variants")
    return ds


def variant_qc_2(ds: Dataset) -> Dataset:
    # See: https://github.com/Nealelab/UK_Biobank_GWAS#imputed-v3-variant-qc
    ds["variant_genotype_counts"] = variant_genotype_counts(ds)[
        :, [1, 0, 2]
    ]  # Order: het, hom_ref, hom_alt
    ds = sg.hardy_weinberg_test(ds, genotype_counts="variant_genotype_counts", ploidy=2)
    ds = apply_filters(ds, {"high_maf": ds.variant_maf > 0.001}, dim="variants")
    ds = apply_filters(ds, {"in_hwe": ds.variant_hwe_p_value > 1e-10}, dim="variants")

    return ds


def run_qc_1(input_path: str, output_path: str):
    init()
    logger.info(
        f"Running stage 1 QC (input_path={input_path}, output_path={output_path})"
    )
    ds = load_dataset(input_path, unpack=False, consolidated=False)

    logger.info(f"Loaded dataset:\n{ds}")
    chunks = get_chunks(ds)

    logger.info("Applying QC filters")
    ds = variant_qc_1(ds)

    ds = ds.chunk(chunks=chunks)
    logger.info(f"Saving dataset to {output_path}:\n{ds}")
    save_dataset(ds, output_path)
    logger.info("Done")


def run_qc_2(input_path: str, sample_qc_path: str, output_path: str):
    init()
    logger.info(
        f"Running stage 1 QC (input_path={input_path}, output_path={output_path})"
    )
    ds = load_dataset(input_path, unpack=True, consolidated=True)

    logger.info(f"Loaded dataset:\n{ds}")
    chunks = get_chunks(ds)

    logger.info("Applying variant QC filters")
    ds = variant_qc_2(ds)

    # Drop probability since it is very large and was only necessary
    # for computing QC-specific quantities
    ds = ds.drop_vars(["call_genotype_probability", "call_genotype_probability_mask"])

    logger.info(f"Applying sample QC filters (sample_qc_path={sample_qc_path})")
    ds = apply_sample_qc_1(ds, sample_qc_path=sample_qc_path)

    ds = ds.chunk(chunks=chunks)
    logger.info(f"Saving dataset to {output_path}:\n{ds}")
    save_dataset(ds, output_path)
    logger.info("Done")


def load_gwas_ds(genotypes_path: str, phenotypes_path: str) -> Dataset:
    ds = load_dataset(genotypes_path, consolidated=True)
    ds = add_covariates(ds)
    ds = add_traits(ds, phenotypes_path)
    ds = ds[[v for v in sorted(ds)]]
    return ds


def wait_fn(attempts, delay):
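    # Exponential backoff: 2 s, 4 s, 8 s, ... capped at 300 s (5 minutes).
    # The `delay` value passed in by `retrying` is ignored and recomputed here.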
    delay = min(2 ** attempts * 1000, 300000)
    logger.info(f"Attempt {attempts}, retrying in {delay} ms")
    return delay


def exception_fn(e):
    logger.error(f"A retriable error occurred: {e}")
    logger.error("Traceback:\n")
    logger.error("\n" + "".join(traceback.format_tb(e.__traceback__)))
    return True


@retry(retry_on_exception=exception_fn, wait_func=wait_fn)
def run_trait_gwas(
    ds: Dataset,
    trait_group_id: int,
    trait_name: str,
    batch_index: int,
    min_samples: int,
    retries: int = 3,
) -> pd.DataFrame:
    assert ds["sample_trait_group_id"].to_series().nunique() == 1
    assert ds["sample_trait_name"].to_series().nunique() == 1

    # Filter to complete cases
    start = time.perf_counter()
    n = ds.dims["samples"]
    ds = ds.isel(samples=ds["sample_trait"].notnull().all(dim="traits").values)
    stop = time.perf_counter()
    sample_size = ds.dims["samples"]
    logger.info(
        f"Found {sample_size} complete cases of {n} for '{trait_name}' (id={trait_group_id}) in {stop - start:.1f} seconds"
    )

    # Bypass if sample size too small
    if sample_size < min_samples:
        logger.warning(
            f"Sample size ({sample_size}) too small (<{min_samples}) for trait '{trait_name}' (id={trait_group_id})"
        )
        return None

    logger.info(
        f"Running GWAS for '{trait_name}' (id={trait_group_id}) with {sample_size} samples, {ds.dims['traits']} traits"
    )

    start = time.perf_counter()
    logger.debug(
        f"Input dataset for trait '{trait_name}' (id={trait_group_id}) GWAS:\n{ds}"
    )

    ds = sg.gwas_linear_regression(
        ds,
        dosage="call_dosage",
        covariates="sample_covariate",
        traits="sample_trait",
        add_intercept=True,
        merge=True,
    )

    # Project and convert to data frame for convenience
    # in downstream analysis/comparisons
    ds = ds[
        [
            "sample_trait_id",
            "sample_trait_name",
            "sample_trait_group_id",
            "sample_trait_code_id",
            "variant_id",
            "variant_contig",
            "variant_contig_name",
            "variant_position",
            "variant_p_value",
            "variant_beta",
        ]
    ]

    if os.getenv("GENERATE_PERFORMANCE_REPORT", "").lower() == "true":
        with performance_report(
            f"logs/reports/pr_{trait_group_id}_{batch_index}.html"
        ), get_task_stream(
            plot="save", filename=f"logs/reports/ts_{trait_group_id}_{batch_index}.html"
        ):
            ds = ds.compute(retries=retries)
    else:
        ds = ds.compute(retries=retries)
    df = (
        ds.to_dataframe()
        .reset_index()
        .assign(sample_size=sample_size)
        .rename(columns={"traits": "trait_index", "variants": "variant_index"})
    )
    stop = time.perf_counter()
    logger.info(
        f"GWAS for '{trait_name}' (id={trait_group_id}) complete in {stop - start:.1f} seconds"
    )
    return df


@retry(retry_on_exception=exception_fn, wait_func=wait_fn)
def save_gwas_results(df: pd.DataFrame, path: str):
    start = time.perf_counter()
    df.to_parquet(path)
    stop = time.perf_counter()
    logger.info(f"Save to {path} complete in {stop - start:.1f} seconds")


@retry(retry_on_exception=exception_fn, wait_func=wait_fn)
def run_batch_gwas(
    ds: xr.Dataset,
    trait_group_id: str,
    trait_names: np.ndarray,
    batch_size: int,
    min_samples: int,
    sumstats_path: str,
):
    # Determine which individual traits should be regressed as part of this group
    mask = (ds["sample_trait_group_id"] == trait_group_id).values
    index = np.sort(np.argwhere(mask).ravel())
    if len(index) == 0:
        logger.warning(f"Trait group id {trait_group_id} not found in data (skipping)")
        return
    trait_name = trait_names[index][0]

    # Break the traits for this group into batches of some maximum size
    batches = np.array_split(index, np.ceil(len(index) / batch_size))
    if len(batches) > 1:
        logger.info(
            f"Broke {len(index)} traits for '{trait_name}' (id={trait_group_id}) into {len(batches)} batches"
        )
    for batch_index, batch in enumerate(batches):
        path = f"{sumstats_path}_{batch_index:03d}_{trait_group_id}.parquet"
        if fs.exists(path):
            logger.info(
                f"Results for trait '{trait_name}' (id={trait_group_id}) at path {path} already exist (skipping)"
            )
            continue
        dsg = ds.isel(traits=batch)
        df = run_trait_gwas(
            dsg, trait_group_id, trait_name, batch_index, min_samples=min_samples
        )
        if df is None:
            continue
        # Write results for all traits in the batch together so that partitions
        # will have a maximum possible size determined by trait batch size
        # and number of variants in the current contig
        df = df.assign(batch_index=batch_index, batch_size=len(batch))
        logger.info(
            f"Saving results for trait '{trait_name}' id={trait_group_id}, batch={batch_index} to path {path}"
        )
        save_gwas_results(df, path)


def run_gwas(
    genotypes_path: str,
    phenotypes_path: str,
    sumstats_path: str,
    variables_path: str,
    batch_size: int = 100,
    trait_group_ids: Optional[Union[Sequence[Union[str, int]], str]] = None,
    min_samples: int = 100,
):
    init()

    logger.info(
        f"Running GWAS (genotypes_path={genotypes_path}, phenotypes_path={phenotypes_path}, "
        f"sumstats_path={sumstats_path}, variables_path={variables_path})"
    )

    ds = load_gwas_ds(genotypes_path, phenotypes_path)

    # Promote to f4 to avoid:
    # TypeError: array type float16 is unsupported in linalg
    ds["call_dosage"] = ds["call_dosage"].astype("float32")

    # Rechunk dosage (from 5216 x 5792 at the time of writing) down to something smaller in the
    # variants dimension since variant_chunk x n_sample arrays need to
    # fit in memory for linear regression (652 * 365941 * 4 = 954MB)
    # See: https://github.com/pystatgen/sgkit/issues/390
    ds["call_dosage"] = ds["call_dosage"].chunk(chunks=(652, 5792))

    logger.info(f"Loaded dataset:\n{ds}")

    # Determine the UKB field ids corresponding to all phenotypes to be used
    # * a `trait_group_id` is equivalent to a UKB field id
    if trait_group_ids is None:
        # Use all known traits
        trait_group_ids = list(map(int, np.unique(ds["sample_trait_group_id"].values)))
    elif isinstance(trait_group_ids, str):
        # Load from file
        trait_group_ids = [
            int(v) for v in pd.read_csv(trait_group_ids, sep="\t")["trait_group_id"]
        ]
    else:
        # Assume a sequence was provided
        trait_group_ids = [int(v) for v in trait_group_ids]
    logger.info(
        f"Using {len(trait_group_ids)} trait groups; first 10: {trait_group_ids[:10]}"
    )

    # Loop through the trait groups and run separate regressions for each.
    # Note that the majority of groups (89%) have only one phenotype/trait
    # associated, some (10%) have between 1 and 10 phenotypes and ~1% have
    # large numbers of phenotypes (into the hundreds or thousands).
    trait_names = ds["sample_trait_name"].values
    for trait_group_id in trait_group_ids:
        run_batch_gwas(
            ds=ds,
            trait_group_id=trait_group_id,
            trait_names=trait_names,
            batch_size=batch_size,
            min_samples=min_samples,
            sumstats_path=sumstats_path,
        )
    logger.info("Sumstat generation complete")

    ds = ds[
        [
            "variant_contig",
            "variant_contig_name",
            "variant_id",
            "variant_rsid",
            "variant_position",
            "variant_allele",
            "variant_minor_allele",
            "variant_hwe_p_value",
            "variant_maf",
            "variant_info",
            "sample_id",
            "sample_principal_component",
            "sample_covariate",
            "sample_genetic_sex",
            "sample_age_at_recruitment",
            "sample_ethnic_background",
            "sample_trait",
            "sample_trait_id",
            "sample_trait_group_id",
            "sample_trait_code_id",
            "sample_trait_name",
        ]
    ]
    ds = ds.chunk("auto")
    path = variables_path + "_variables.zarr"
    logger.info(f"Saving GWAS variables to {path}:\n{ds}")
    save_dataset(ds, path)

    logger.info("Done")


if __name__ == "__main__":
    fire.Fire()
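
The `run_qc_1`, `run_qc_2`, and `run_gwas` entry points are also exposed through `fire.Fire()`. A minimal direct-call sketch for `run_gwas` with hypothetical paths and an illustrative, explicit set of trait groups:

# Sketch only: paths and trait group ids below are hypothetical placeholders.
run_gwas(
    genotypes_path="gs://my-ukb-bucket-name/pipe/gt-imputation/ukb_chr21_qc2.zarr",
    phenotypes_path="gs://my-ukb-bucket-name/prep/main/ukb_phesant_phenotypes.zarr",
    sumstats_path="gs://my-ukb-bucket-name/pipe/sumstats/ukb_chr21",
    variables_path="gs://my-ukb-bucket-name/pipe/sumstats/ukb_chr21",
    trait_group_ids=[50, 23098],  # illustrative UKB field ids; omit to run all traits
    batch_size=100,
    min_samples=100,
)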
import logging
import logging.config
import re
from pathlib import Path
from typing import Optional, Sequence, Union

import fire
import fsspec
import numpy as np
import pandas as pd

logging.config.fileConfig(Path(__file__).resolve().parents[1] / "log.ini")
logger = logging.getLogger(__name__)

###########################
# Sgkit sumstat functions #


def get_gwas_sumstat_manifest(path: str) -> pd.DataFrame:
    store = fsspec.get_mapper(path)
    df = []
    for f in list(store):
        fn = f.split("/")[-1]
        parts = re.findall(r"ukb_chr(\d+)_(\d+)_(.*).parquet", fn)
        if not parts:
            continue
        parts = parts[0]
        df.append(
            dict(
                contig=parts[0],
                batch=int(parts[1]),
                trait_id=parts[2],
                trait_group_id=parts[2].split("_")[0],
                trait_code_id="_".join(parts[2].split("_")[1:]),
                file=f,
            )
        )
    return pd.DataFrame(df)


def load_gwas_sumstats(path: str) -> pd.DataFrame:
    logger.info(f"Loading GWAS sumstats from {path}")
    return (
        pd.read_parquet(path)
        .rename(columns=lambda c: c.replace("sample_", ""))
        .rename(columns={"variant_contig_name": "contig"})
        # b'21' -> '21'
        .assign(contig=lambda df: df["contig"].str.decode("utf-8"))
        # b'21:9660864_G_A' -> '21:9660864:G:A'
        .assign(
            variant_id=lambda df: df["variant_id"]
            .str.decode("utf-8")
            .str.replace("_", ":")
        )
        .drop(["variant_index", "trait_index", "variant_contig"], axis=1)
        .rename(columns={"variant_p_value": "p_value"})
        .set_index(["trait_id", "contig", "variant_id"])
        .add_prefix("gwas_")
    )


########################
# OT sumstat functions #


def get_ot_sumstat_manifest(path: str) -> pd.DataFrame:
    # See https://github.com/related-sciences/ukb-gwas-pipeline-nealelab/issues/31 for example paths
    store = fsspec.get_mapper(path)
    files = list(store)
    df = []
    for f in files:
        if not f.endswith(".tsv.gz"):
            continue
        trait_id = f.split("/")[-1].split(".")[0]
        if trait_id.endswith("_irnt"):
            # Ignore rank normalized continuous outcomes in favor of "raw" outcomes
            continue
        if trait_id.endswith("_raw"):
            trait_id = trait_id.replace("_raw", "")
        df.append(
            dict(
                trait_id=trait_id,
                trait_group_id=trait_id.split("_")[0],
                trait_code_id="_".join(trait_id.split("_")[1:]),
                file=f,
            )
        )
    return pd.DataFrame(df)


OT_COLS = [
    "chromosome",
    "variant",
    "variant_id",
    "minor_allele",
    "n_complete_samples",
    "beta",
    "p-value",
    "tstat",
]


def load_ot_trait_sumstats(path: str, row: pd.Series) -> pd.DataFrame:
    path = path + "/" + row["file"]
    logger.info(f"Loading OT sumstats from {path}")
    return (
        pd.read_csv(path, sep="\t", usecols=OT_COLS, dtype={"chromosome": str})
        .rename(columns={"variant_id": "variant_rsid"})
        .rename(
            columns={
                "chromosome": "contig",
                "p-value": "p_value",
                "variant": "variant_id",
            }
        )
        .assign(
            trait_id=row["trait_id"],
            trait_group_id=row["trait_group_id"],
            trait_code_id=row["trait_code_id"],
        )
        .set_index(["trait_id", "contig", "variant_id"])
        .add_prefix("ot_")
    )


def load_ot_sumstats(
    path: str, df: pd.DataFrame, contigs: Sequence[str]
) -> pd.DataFrame:
    df = pd.concat([load_ot_trait_sumstats(path, row) for _, row in df.iterrows()])
    df = df.loc[df.index.get_level_values("contig").isin(contigs)]
    return df


#########
# Merge #


def run(
    gwas_sumstats_path: str,
    ot_sumstats_path: str,
    output_path: str,
    contigs: Optional[Union[str, Sequence[str]]] = None,
    trait_group_ids: Optional[Union[str, Sequence[str]]] = None,
):
    df_sg = get_gwas_sumstat_manifest(gwas_sumstats_path)
    df_ot = get_ot_sumstat_manifest(ot_sumstats_path)

    def prep_filter(v, default_values):
        if isinstance(v, str):
            v = v.split(",")
        if v is None:
            v = default_values
        return [str(e) for e in v]

    contigs = prep_filter(contigs, df_sg["contig"].unique())
    trait_group_ids = prep_filter(trait_group_ids, df_sg["trait_group_id"].unique())
    logger.info(f"Using {len(contigs)} contigs (first 10: {contigs[:10]})")
    logger.info(
        f"Using {len(trait_group_ids)} trait_group_ids (first 10: {trait_group_ids[:10]})"
    )

    def apply_trait_filter(df):
        return df[df["trait_group_id"].isin(trait_group_ids)]

    def apply_contig_filter(df):
        return df[df["contig"].isin(contigs)]

    df_sg, df_ot = df_sg.pipe(apply_trait_filter), df_ot.pipe(apply_trait_filter)
    df_sg = df_sg.pipe(apply_contig_filter)

    # Only load OT sumstats for traits present in GWAS comparison data
    df_ot = df_ot[df_ot["trait_group_id"].isin(df_sg["trait_group_id"].unique())]

    logger.info(f"Loading GWAS sumstats ({len(df_sg)} partitions)")
    df_sg = pd.concat(
        [load_gwas_sumstats(gwas_sumstats_path + "/" + f) for f in df_sg["file"]]
    )
    df_sg.info()

    logger.info(f"Loading OT sumstats ({len(df_ot)} partitions)")
    df_ot = load_ot_sumstats(ot_sumstats_path, df_ot, contigs)
    df_ot.info()

    logger.info("Merging")
    assert df_sg.index.is_unique
    assert df_ot.index.is_unique
    df = pd.concat([df_sg, df_ot], axis=1, join="outer")
    df["gwas_log_p_value"] = -np.log10(df["gwas_p_value"])
    df["ot_log_p_value"] = -np.log10(df["ot_p_value"])
    df = df[sorted(df)]

    logger.info(f"Saving result to {output_path}:\n")
    df.info()
    df.to_parquet(output_path)
    logger.info("Done")


if __name__ == "__main__":
    fire.Fire()
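
Finally, the merge `run` function accepts either sequences or comma-separated strings for its optional contig and trait-group filters. A sketch restricted to chromosome 21 and a single trait group, again with hypothetical paths:

# Sketch only: paths below are hypothetical placeholders.
run(
    gwas_sumstats_path="gs://my-ukb-bucket-name/pipe/sumstats",
    ot_sumstats_path="gs://my-ukb-bucket-name/external/ot_nealelab_sumstats",
    output_path="gs://my-ukb-bucket-name/pipe/sumstats_merged.parquet",
    contigs="21",
    trait_group_ids="50",
)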