Probe: Population Genomic Analysis Workflow for Diploid Organisms


Probe

Probe is a Snakemake workflow that applies a range of population genomic analyses to whole-genome sequencing data from any diploid organism. Data can be supplied in VCF or Zarr format, or An. gambiae s.l. data can be accessed directly in the cloud through the malariagen_data API. Called SNPs should be of high confidence, although a site filter mask can additionally be applied.
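
For example, when cloud access is enabled the workflow's scripts fetch sample metadata (and genotype/haplotype arrays) directly from MalariaGEN. A minimal sketch of that access pattern, using an Ag1000G sample set purely as a placeholder:

import malariagen_data

# connect to the Ag3 data resource (pre-release data included)
ag3 = malariagen_data.Ag3(pre=True)
# load sample metadata for one or more sample sets
metadata = ag3.sample_metadata(sample_sets=["AG1000G-ML-B"])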

The workflow is a WIP; most modules are functional, but nothing can be guaranteed :P

Code Snippets

shell: 
    """
    python -m ipykernel install --user --name probe 2> log
    """
script:
    "../scripts/pca.py"
script:
    "../scripts/f2VariantLocate.py"
script:
    "../scripts/f2HaplotypeLength.py"
shell:
    """
    {params.basedir}/scripts/NgsRelate/ngsRelate -h {input.vcf} -O {output} -c 1 -T {params.tag} -p {threads} 2> {log}
    """
script:
    "../scripts/VariantsOfInterest.py"
shell:
    """
    papermill {input.nb} {output.nb} -k probe -p cloud {params.cloud} -p ag3_sample_sets {params.ag3_sample_sets} \
    -p contig {wildcards.contig} -p stat G12 -p windowSize {params.windowSize} -p windowStep {params.windowStep} \
    -p cutHeight {params.CutHeight} -p metaColumns {params.columns} -p minPopSize {params.minPopSize}

    python -m nbconvert {output.nb} --to html --stdout --no-input \
         --ExecutePreprocessor.kernel_name=probe > {output.html}
    """
script:
    "../scripts/GarudsStatistics.py"
script:
    "../scripts/PopulationBranchStatistic.py"
wrapper:
    "v0.69.0/bio/samtools/faidx"
script:
    "../scripts/ZarrToVCF.py"
script:
    "../scripts/ZarrToVCF_haplotypes.py"
shell:
    """
    bgzip {input.calls} 2> {log}
    """
shell:
    """
    bcftools index {input.calls} 2> {log}
    """
shell:
    """
    tabix {input.calls} 2> {log}
    """
shell:
    """
    bcftools concat -o {output.cattedVCF} -O z --threads {threads} {input.calls} 2> {log}
    """
shell:
    """
    bcftools index {input.calls} 2> {log}
    """
shell:
    """
    tabix {input.calls} 2> {log}
    """
import sys
sys.stderr = open(snakemake.log[0], "w")

import probetools as probe
from pathlib import Path
import numpy as np
import pandas as pd
import allel
import dask.array as da
import seaborn as sns
import matplotlib.pyplot as plt
from numba import njit

cloud = snakemake.params['cloud']
ag3_sample_sets = snakemake.params['ag3_sample_sets']
contig = snakemake.wildcards['contig']
genotypePath = snakemake.params['genotypes'] 
positionsPath = snakemake.params['positions']


# Load metadata 
if cloud:
    import malariagen_data
    ag3 = malariagen_data.Ag3(pre=True)
    metadata = ag3.sample_metadata(sample_sets=ag3_sample_sets)
else:
    metadata = pd.read_csv(snakemake.params['metadata'], sep="\t")


@njit()
def scanRight(geno1, geno2, upperBreakpoint, maxIdx):
    # start at the doubleton position
    gn1 = geno1[upperBreakpoint]
    gn2 = geno2[upperBreakpoint]

    # Scan right along genome, as long as two inds are not both homozygous but different
    while not (gn1[0] == gn1[1]) & (gn2[0] == gn2[1]) & ((gn1 != gn2).all()) & (-1 not in gn1 and -1 not in gn2):

        upperBreakpoint += 1
        if upperBreakpoint == maxIdx-1: # limit the upper breakpoint at end of the contig
            return(upperBreakpoint)

        gn1 = geno1[upperBreakpoint]
        gn2 = geno2[upperBreakpoint]

    return(upperBreakpoint)

@njit()
def scanLeft(geno1, geno2, lowerBreakpoint): 

    # subset genotypes to this position to initialise the while loop
    gn1 = geno1[lowerBreakpoint]
    gn2 = geno2[lowerBreakpoint]

    # Scan left along genome
    while not (gn1[0] == gn1[1]) & (gn2[0] == gn2[1]) & ((gn1 != gn2).all()) & (-1 not in gn1 and -1 not in gn2):

        lowerBreakpoint -= 1
        if lowerBreakpoint == 0: # limit lower breakpoint at zero, start of contig
            return(lowerBreakpoint)

        gn1 = geno1[lowerBreakpoint]
        gn2 = geno2[lowerBreakpoint]

    return(lowerBreakpoint)


@njit()
def f2scans(dblton_arr, snps, pos):

    starts = []# np.empty((len(dblton_arr)),dtype='uint8')
    ends = []# np.empty((len(dblton_arr)),dtype='uint8')
    dbltonpos = []

    for idx in range(0, len(dblton_arr)):

        geno1 = snps[:, dblton_arr[idx][0]]
        geno2 = snps[:, dblton_arr[idx][1]]
        # get boolean of dblton idx
        #dblton_idx = bisect.bisect_left(pos, dblton_arr[idx][2])
        dblton_idx = np.searchsorted(pos, dblton_arr[idx][2])

        upperBreakpoint = scanRight(geno1, geno2, dblton_idx, len(pos))
        # Scan left along genome
        lowerBreakpoint = scanLeft(geno1, geno2, dblton_idx)

        starts.append(pos[lowerBreakpoint])
        ends.append(pos[upperBreakpoint])
        dbltonpos.append(dblton_arr[idx][2])

    return(np.array(starts), np.array(ends), np.array(dbltonpos))



########## main #################


snps = {}
pos = {}

# Load Arrays
snps, pos = probe.loadZarrArrays(genotypePath=genotypePath, 
                                            positionsPath=positionsPath,
                                            siteFilterPath=None,
                                            cloud=cloud)
ac = snps.count_alleles()
seg = ac.is_segregating()
snps = snps.compress(seg, axis=0).compute(num_workers=12)
pos = pos[seg]

### Load doubletons
dblton = pd.read_csv(snakemake.input['f2variantPairs'], sep="\t")
dblton = dblton.query("contig == @contig")
dblton_arr = dblton[['idx1', 'idx2', 'pos']].to_numpy()

# extract np array 
snps = snps.values
# Run F2 hap length scans
starts, ends, dbltonpos = f2scans(dblton_arr, snps, pos)

f2hapdf = pd.DataFrame({'start':starts, 'end':ends, 'dbltonpos':dbltonpos})
f2hapdf.to_csv(f"results/f2HapLengths_{contig}.tsv", sep="\t")
import sys
sys.stderr = open(snakemake.log[0], "w")

from tools import loadZarrArrays, log
from pathlib import Path
import numpy as np
import pandas as pd
import allel
import dask.array as da
import seaborn as sns
import matplotlib.pyplot as plt

cloud = snakemake.params['cloud']
ag3_sample_sets = snakemake.params['ag3_sample_sets']
genotypePath = snakemake.params['genotypes'] 
positionsPath = snakemake.params['positions']
siteFiltersPath = snakemake.params['sitefilters']
contigs = snakemake.params['contigs']


# Load metadata 
if cloud:
    import malariagen_data
    ag3 = malariagen_data.Ag3(pre=True)
    metadata = ag3.sample_metadata(sample_sets=ag3_sample_sets)
else:
    metadata = pd.read_csv(snakemake.params['metadata'], sep="\t")


def is_variant01(gn, allele):
    return((gn == allele).any())

def checkSampleID(x, metadata=metadata):
    name = metadata.loc[x,'partner_sample_id']
    return(name)


snps = {}
pos = {}

for contig in contigs:
    # Load Arrays
    snps[contig], pos[contig] = loadZarrArrays(genotypePath=genotypePath, 
                                             positionsPath=positionsPath,
                                             siteFilterPath=siteFiltersPath)


    ac = snps[contig].count_alleles()
    seg = ac.is_segregating()
    snps[contig] = snps[contig].compress(seg, axis=0)
    pos[contig] = pos[contig][seg]

pos_dbltons = {}
inds_dbltons = {}

for contig in contigs:

    ac = snps[contig].count_alleles()
    doubletons_bool = ac.is_doubleton(allele=1).compute() # Need to do for each allele
    geno = snps[contig].compress(doubletons_bool, axis=0)
    log("Recorded dblton loc")
    pos_dbltons[contig] = pos[contig][doubletons_bool]

    n_doubletons = doubletons_bool.sum()
    log(f"There are {n_doubletons.shape} on {contig}")
    log("locating dblton sharers")

    # get 1 hets and 1 homs to each ind of the genotype data
    res = geno.is_het(1).compute()
    res2 = geno.is_hom(1).compute()
    dbhets = res.sum()
    dbhoms = res2.sum()
    totdbinds = dbhets + (2*dbhoms)
    assert (n_doubletons*2) == totdbinds, "Unequal individual samples v n_dbltons!!!"

    res = np.logical_or(res, res2)
    # and use np where to get indices. Pandas apply is fast.
    pairs = pd.DataFrame(res).apply(np.where, axis=1).apply(np.asarray).apply(lambda x: x.flatten())
    hom_filter = pairs.apply(len) == 2
    pos_dbltons[contig] = pos_dbltons[contig][hom_filter]
    pairs = pairs[hom_filter]

    #make 1d array into two column pd df
    log("organising arrays")
    idxs = pd.DataFrame(np.vstack(pairs.values))
    dblton = pd.DataFrame(np.vstack(pairs.values), columns=['partner_sample_id','partner_sample_id2'])
    # shouldnt be any but remove hom/homs
    dblton = dblton.query("partner_sample_id != partner_sample_id2").reset_index(drop=True)
    dblton = dblton.applymap(checkSampleID) #store WA-XXXX ID
    inds_dbltons[contig] = pd.concat([idxs, dblton], axis=1)
    inds_dbltons[contig]['pos'] = pos_dbltons[contig]

dblton = pd.concat(inds_dbltons).reset_index().drop(columns=['level_1']).rename(columns={'level_0':'contig', 0:'idx1', 1:'idx2'})

dblton = dblton.merge(metadata[['partner_sample_id', 'latitude', 'longitude']])
dblton = dblton.merge(metadata.rename(columns={'partner_sample_id':'partner_sample_id2', 
                                      'latitude': 'latitude2', 
                                      'longitude':  'longitude2'})[['partner_sample_id2', 'latitude2', 'longitude2']])
dblton = dblton.sort_values(by=['contig', 'pos']).reset_index(drop=True)

dblton.to_csv("results/f2variantPairs.tsv", sep="\t", index=None)
import sys
sys.stderr = open(snakemake.log[0], "w")

import probetools as probe
import numpy as np
import pandas as pd
import allel
import dask.array as da
import scipy
import seaborn as sns
import matplotlib.pyplot as plt


# Garuds Selection Scans # 
cloud = snakemake.params['cloud']
ag3_sample_sets = snakemake.params['ag3_sample_sets']
contig = snakemake.wildcards['contig']
stat = snakemake.params['GarudsStat']
windowSize = snakemake.params['windowSize']
windowStep = snakemake.params['windowStep']
cutHeight = snakemake.params['cutHeight'] if stat in ['G12', 'G123'] else []

if not cloud:
    genotypePath = snakemake.input['genotypes'] if stat in ['G12', 'G123'] else []
    haplotypePath = snakemake.input['haplotypes'] if stat in ['H1', 'H12', 'H2/1'] else []
    positionsPath = snakemake.input['positions']
    siteFilterPath = snakemake.input['siteFilters'] if stat in ['H1', 'H12', 'H2/1'] else []
else:
    genotypePath = []
    haplotypePath = []
    positionsPath = []
    siteFilterPath = []

# Load metadata 
if cloud:
    import malariagen_data
    ag3 = malariagen_data.Ag3(pre=True)
    metadata = ag3.sample_metadata(sample_sets=ag3_sample_sets)
else:
    metadata = pd.read_csv(snakemake.params['metadata'], sep="\t")

# Load arrays 
if stat in ['H1', 'H12', 'H2/1']:
    haps, pos = probe.loadZarrArrays(haplotypePath, positionsPath, siteFilterPath=None, haplotypes=True, cloud=cloud, contig=contig)
elif stat in ['G12', 'G123']:
    snps, pos = probe.loadZarrArrays(genotypePath, positionsPath, siteFilterPath=siteFilterPath, haplotypes=False, cloud=cloud, contig=contig)
else:
    raise AssertionError("The statistic selected is not G12, G123, or H12")



# Define functions
def clusterMultiLocusGenotypes(gnalt, cut_height=0.1, metric='euclidean', g=2):
    """
    Hierarchically clusters genotypes and calculates G12 statistic. 
    """
    # cluster the genotypes in the window
    dist = scipy.spatial.distance.pdist(gnalt.T, metric=metric)
    if metric in {'hamming', 'jaccard'}:
        # convert distance to number of SNPs, easier to interpret
        dist *= gnalt.shape[0]

    Z = scipy.cluster.hierarchy.linkage(dist, method='single')
    cut = scipy.cluster.hierarchy.cut_tree(Z, height=cut_height)[:, 0]
    cluster_sizes = np.bincount(cut)
    #clusters = [np.nonzero(cut == i)[0] for i in range(cut.max() + 1)] #returns indices of individuals in each cluster

    # get freq of clusters and sort by largest freq
    freqs = cluster_sizes/gnalt.shape[1]
    freqs = np.sort(freqs)[::-1]

    # calculate garuds statistic
    gStat = np.sum(freqs[:g])**2 + np.sum(freqs[g:]**2)

    return(gStat)


def garudsStat(stat, geno, pos, cut_height=None, metric='euclidean', window_size=1200, step_size=600):

    """
    Calculates G12/G123/H12
    """

    # Do we want to cluster the Multi-locus genotypes (MLGs), or just group MLGs if they are identical
    if stat == "G12":
        garudsStat = allel.moving_statistic(geno, clusterMultiLocusGenotypes, size=window_size, step=step_size, metric=metric, cut_height=cut_height, g=2)
    elif stat == "G123":
        garudsStat = allel.moving_statistic(geno, clusterMultiLocusGenotypes, size=window_size, step=step_size, metric=metric, cut_height=cut_height, g=3)
    elif stat == "H12":
        # moving_garud_h returns (h1, h12, h123, h2_h1); take the H12 values
        _, garudsStat, _, _ = allel.moving_garud_h(geno, size=window_size, step=step_size)
    else:
        raise ValueError("Statistic is not G12/G123/H12")

    midpoint = allel.moving_statistic(pos, np.median, size=window_size, step=step_size)

    return(garudsStat, midpoint)



#### Load cohort data and their indices in genotype data
### run garudStat for that query. already loaded contigs 

cohorts = probe.getCohorts(metadata=metadata, 
                    columns=snakemake.params.columns, 
                    minPopSize=snakemake.params.minPopSize, excludepath="resources/sib_group_table.csv")


# Loop through each cohort, manipulate genotype arrays and calculate chosen Garuds Statistic
for idx, cohort in cohorts.iterrows():

    if stat in ['H1', 'H12', 'H123']:
        # get indices for haplotype Array and filter
        hapInds = np.sort(np.concatenate([np.array(cohort['indices'])*2, np.array(cohort['indices'])*2 + 1]))
        gt_cohort = haps.take(hapInds, axis=1)
    elif stat in ['G12', 'G123']:
        # filter to correct loc, year, species individuals
        gt_cohort = snps.take(cohort['indices'], axis=1)
    else:
        raise ValueError("Statistic is not G12/G123/H1/H12")

    probe.log(f"--------- Running {stat} on {cohort['cohortText']} | Chromosome {contig} ----------")
    probe.log("filter to biallelic segregating sites")

    ac_cohort = gt_cohort.count_alleles(max_allele=3).compute()
    # N.B., if going to use to_n_alt later, need to make sure sites are 
    # biallelic and one of the alleles is the reference allele
    ref_ac = ac_cohort[:, 0]
    loc_sites = ac_cohort.is_biallelic() & (ref_ac > 0)
    gt_seg = da.compress(loc_sites, gt_cohort, axis=0)
    pos_seg = da.compress(loc_sites, pos, axis=0)

    probe.log(f"compute input data for {stat}")
    pos_seg = pos_seg.compute()

    if stat in ['G12', 'G123']:
        gt_seg = allel.GenotypeDaskArray(gt_seg).to_n_alt().compute()

    # calculate G12/G123/H12 and plot figs 
    gStat, midpoint = garudsStat(stat=stat,
                                geno=gt_seg, 
                                pos=pos_seg, 
                                cut_height=cutHeight,
                                metric='euclidean',
                                window_size=windowSize,
                                step_size=windowStep)

    probe.windowedPlot(statName=stat, 
                cohortText = cohort['cohortText'],
                cohortNoSpaceText= cohort['cohortNoSpaceText'],
                values=gStat, 
                midpoints=midpoint,
                prefix=f"results/selection/{stat}", 
                contig=contig,
                colour=cohort['colour'],
                ymin=0,
                ymax=0.5)
import sys
sys.stderr = open(snakemake.log[0], "w")

import probetools as probe
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.express as px
import dask.array as da
import seaborn as sns
import matplotlib.pyplot as plt


# Garuds Selection Scans # 
cloud = snakemake.params['cloud']
ag3_sample_sets = snakemake.params['ag3_sample_sets']
pcaColumn = snakemake.params['cohortColumn']
contig = snakemake.wildcards['contig']
dataset = snakemake.params['dataset']
genotypePath = snakemake.input['genotypes'] if not cloud else []
positionsPath = snakemake.input['positions'] if not cloud else []
siteFilterPath = snakemake.input['siteFilters'] if not cloud else []

results_dir = snakemake.params['data']

# Load metadata 
if cloud:
    import malariagen_data
    ag3 = malariagen_data.Ag3(pre=True)
    metadata = ag3.sample_metadata(sample_sets=ag3_sample_sets)
else:
    metadata = pd.read_csv(snakemake.params['metadata'], sep="\t")

# Load Arrays
snps, pos = probe.loadZarrArrays(genotypePath, positionsPath, siteFilterPath=siteFilterPath, cloud=cloud, haplotypes=False, contig=contig)

# Determine cohorts
cohorts = probe.getCohorts(metadata, columns=pcaColumn)


# choose colours for species
species_palette = px.colors.qualitative.Plotly
species_color_map = {
    'gambiae': species_palette[0],
    'coluzzii': species_palette[1],
    'arabiensis': species_palette[2],
    'intermediate_gambiae_coluzzii': species_palette[3],
    'intermediate_arabiensis_gambiae': species_palette[4],
}


# Run PCA on whole dataset together
data, evr = probe.run_pca(contig=contig, gt=snps, pos=pos, df_samples=metadata,
    sample_sets=dataset, results_dir=results_dir
)
evr = evr.astype("float").round(4) # round decimals for variance explained % 

probe.plot_coords(data, evr, title=f" PCA | {dataset} | {contig}", filename=f"results/PCA/{dataset}.{contig}.html")

fig = plt.figure(figsize=(10, 10))
ax = sns.scatterplot(x='PC1', y='PC2', data=data, hue="species_gambiae_coluzzii")
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.title(f"PCA | {dataset} | {contig}", fontsize=14)
plt.xlabel(f"PC1 ({evr[0]*100} % variance explained)", fontdict={"fontsize":14})
plt.ylabel(f"PC2 ({evr[1]*100} % variance explained)", fontdict={"fontsize":14})
plt.savefig(f"results/PCA/{dataset}.{contig}.png")



# Loop through each cohort, manipulate genotype arrays and calculate chosen Garuds Statistic
for idx, cohort in cohorts.iterrows():

    # filter to correct loc, year, species individuals
    gt_cohort = snps.take(cohort['indices'], axis=1)
    meta = metadata.take(cohort['indices'])

    data, evr = probe.run_pca(contig=contig, gt=gt_cohort, pos=pos, df_samples=meta,
        sample_sets=cohort['cohortNoSpaceText'], results_dir=results_dir
    )
    evr = evr.astype("float").round(4)

    probe.plot_coords(data, evr, title=f" PCA | {cohort['cohortText']} | {contig}", filename=f"results/PCA/{cohort['cohortNoSpaceText']}.{contig}.html")

    fig = plt.figure(figsize=(10, 10))
    ax = sns.scatterplot(x='PC1', y='PC2', data=data, hue='location')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.title(f"PCA | {cohort['cohortText']} | {contig}", fontsize=14)
    plt.xlabel(f"PC1 ({evr[0]*100} % variance explained)", fontdict={"fontsize":14})
    plt.ylabel(f"PC2 ({evr[1]*100} % variance explained)", fontdict={"fontsize":14})
    plt.savefig(f"results/PCA/{cohort['cohortNoSpaceText']}.{contig}.png")
import sys
sys.stderr = open(snakemake.log[0], "w")

import probetools as probe
import numpy as np
import pandas as pd
import allel
import dask.array as da
import seaborn as sns
import matplotlib.pyplot as plt


# PBS Selection Scans # 
contig = snakemake.wildcards['contig']
stat = "PBS"
windowSize = snakemake.params['windowSize']
windowStep = snakemake.params['windowStep']
genotypePath = snakemake.input['genotypes']
positionsPath = snakemake.input['positions']
siteFilterPath = snakemake.input['siteFilters']

# Outgroup data
outgroupPath = snakemake.input['outgroupPath']
outgroupMetaPath = snakemake.input['outgroupMetaPath']
Mali2004Meta = pd.read_csv(outgroupMetaPath)
species = pd.read_csv("resources/AG1000G-ML-B/samples.species_aim.csv")
Mali2004Meta = Mali2004Meta.merge(species)

# Read metadata 
metadata = pd.read_csv(snakemake.params['metadata'], sep="\t")

# Load arrays
snps, pos = probe.loadZarrArrays(genotypePath, positionsPath, siteFilterPath=siteFilterPath)

### Load outgroup Arrays and subset to each species, storing
snpsOutgroup, pos = probe.loadZarrArrays(outgroupPath, positionsPath, siteFilterPath=siteFilterPath)
snpsOutgroupDict = {}

for sp in ['gambiae', 'coluzzii']:
    sp_bool = Mali2004Meta['species_gambiae_coluzzii'] == sp
    snpsOutgroupDict[sp] =  snpsOutgroup.compress(sp_bool, axis=1)

#### Load cohort data and their indices in genotype data
### run garudStat for that query. already loaded contigs
cohorts = probe.getCohorts(metadata=metadata,
                    columns=snakemake.params.columns,
                    comparatorColumn=snakemake.params.comparatorColumn,
                    minPopSize=snakemake.params.minPopSize,
                    excludepath="resources/sib_group_table.csv")
cohorts = cohorts.dropna()

# Get name for phenotype of interest
pheno1, pheno2 = cohorts['indices'].columns.to_list()

# Loop through each cohort, manipulate genotype arrays and calculate chosen Garuds Statistic
for idx, cohort in cohorts.iterrows():

    probe.log(f"--------- Running {stat} on {cohort['cohortText'].to_list()} | Chromosome {contig} ----------")
    probe.log("filter to biallelic segregating sites")    
    species = cohort['species'].to_list()[0]
    if len(cohort['indices'][pheno1]) < snakemake.params.minPopSize:
        continue
    elif len(cohort['indices'][pheno2]) < snakemake.params.minPopSize:
        continue
    elif cohort['indices'][pheno1] == 'NaN':
        continue
    elif cohort['indices'][pheno2] == 'NaN':
        continue

    ac_cohort = snps.count_alleles(max_allele=3).compute()
    # N.B., if going to use to_n_alt later, need to make sure sites are 
    # biallelic and one of the alleles is the reference allele
    ref_ac = ac_cohort[:, 0]
    loc_sites = ac_cohort.is_biallelic() & (ref_ac > 0)
    gt_seg = da.compress(loc_sites, snps, axis=0)
    pos_seg = da.compress(loc_sites, pos, axis=0)

    probe.log(f"compute input data for {stat}")
    pos_seg = pos_seg.compute()

    ac_out = allel.GenotypeArray(da.compress(loc_sites, snpsOutgroupDict[species], axis=0)).count_alleles()
    ac_pheno1 = allel.GenotypeArray(gt_seg).take(cohort['indices'][pheno1], axis=1).count_alleles()
    ac_pheno2 = allel.GenotypeArray(gt_seg).take(cohort['indices'][pheno2], axis=1).count_alleles()

    assert ac_out.shape[0] == pos_seg.shape[0], "Array Outgroup/POS are the wrong length"
    assert ac_pheno1.shape[0] == pos_seg.shape[0], "Array phenotype1/POS are the wrong length"
    assert ac_pheno2.shape[0] == pos_seg.shape[0], "Arrays phenotype2/POS are the wrong length"

    probe.log("calculate PBS and plot figs")
    # calculate PBS and plot figs 
    pbsArray = allel.pbs(ac_pheno1, ac_pheno2, ac_out, 
                window_size=windowSize, window_step=windowStep, normed=True)
    midpoint = allel.moving_statistic(pos_seg, np.mean, size=windowSize, step=windowStep)

    probe.windowedPlot(statName=stat, 
                cohortText = cohort['cohortText'].to_numpy()[0],
                cohortNoSpaceText= cohort['cohortNoSpaceText'].to_numpy()[0],
                values=pbsArray, 
                midpoints=midpoint,
                prefix=f"results/selection/{stat}", 
                contig=contig,
                colour=cohort['colour'].to_numpy()[0],
                ymin=-0.3,
                ymax=0.3,
                save=True)
import sys
sys.stderr = open(snakemake.log[0], "w")

import probetools as probe
import numpy as np
import pandas as pd
import allel
import dask.array as da
import seaborn as sns


#Selection Scans #
cloud = snakemake.params['cloud']
ag3_sample_sets = snakemake.params['ag3_sample_sets']
contigs = ['2L', '2R', '3R', '3L', 'X']
genotypePath = snakemake.params['genotypePath'] if not cloud else "placeholder_{contig}"
positionsPath = snakemake.params['positionPath'] if not cloud else "placeholder2_{contig}"
dataset = snakemake.params['dataset']

## Read VOI data
vois = pd.read_csv(snakemake.input['variants'], sep="\t")

## separate chrom and pos data and sort 
vois['contig'] = vois['Location'].str.split(":").str.get(0)
vois['pos'] = vois['Location'].str.split(":").str.get(1).str.split("-").str.get(0).astype(int)
vois = vois.sort_values(['contig', 'pos'])


# Load metadata 
if cloud:
    import malariagen_data
    ag3 = malariagen_data.Ag3(pre=True)
    metadata = ag3.sample_metadata(sample_sets=ag3_sample_sets)
else:
    metadata = pd.read_csv(snakemake.params['metadata'], sep="\t")

# Load cohorts
cohorts = probe.getCohorts(metadata=metadata, 
                    columns=snakemake.params.columns, 
                    minPopSize=snakemake.params.minPopSize)



snps = {}
pos = {}
for contig in contigs:

    probe.log(f"Loading arrays for {contig}")
    # Load Arrays
    snps[contig], pos[contig] = probe.loadZarrArrays(genotypePath=genotypePath.format(contig = contig), 
                                             positionsPath=positionsPath.format(contig = contig),
                                             siteFilterPath=None, 
                                             sample_sets=ag3_sample_sets,
                                             cloud=cloud,
                                             contig=contig,
                                             haplotypes=False)


Dict = {}
allCohorts = {}

for idx, cohort in cohorts.iterrows():

    for i, row in vois.iterrows():
        name = row['Name']
        contig = row['contig']
        voiPos = int(row['pos'])
        longName = contig + ":"+ str(voiPos) + "  " + row['Gene'] + " | " + row['Name']

        VariantsOfInterest = pd.DataFrame([{'contig':contig, 'pos':voiPos, 'variant':name, 'name':longName}])

        bool_ = pos[contig][:] == voiPos

        geno = snps[contig].compress(bool_, axis=0).take(cohort['indices'], axis=1)
        ac = geno.count_alleles().compute()
        # if there are no ALTs lets add the zeros for the ALTs otherwise only REF count returned 
        aclength = ac.shape[1]
        acneeded = 4-aclength
        ac = np.append(ac, np.repeat(0, acneeded))
        #get frequency and round
        freqs = pd.DataFrame([(ac/ac.sum()).round(2)])
        df2 = freqs.apply(pd.Series).rename(columns={0:'REF', 1:'ALT1', 2:'ALT2', 3:'ALT3'})
        VariantsOfInterest[cohort['cohortText']] = df2.drop(columns=['REF']).sum(axis=1).round(2)

        Dict[name] = VariantsOfInterest

    allCohorts[idx] = pd.concat(Dict)

# Concatenated the table and write table to TSV
VariantsOfInterest = pd.concat(allCohorts, axis=1).T.drop_duplicates().T.droplevel(level=0, axis=1)
VariantsOfInterest.to_csv(f"results/variantsOfInterest/VOI.{dataset}.frequencies.tsv", sep="\t")

#Drop unnecessary columns for plotting as heatmap
VariantsOfInterest = VariantsOfInterest.drop(columns=['contig', 'pos', 'variant']).set_index('name').astype("float64").round(2)
probe.plotRectangular(VariantsOfInterest, path=f"results/variantsOfInterest/VOI.{dataset}.heatmap.png", figsize=[14,14], xlab='cohort')
import sys
#sys.stderr = open(snakemake.log[0], "w")

import numpy as np
import zarr
import pandas as pd
import allel
import dask.array as da
from datetime import date
from pathlib import Path

def ZarrToPandasToHaplotypeVCF(vcf_file, metadata, sample_sets, contig, sample_query=None, analysis='gamb_colu', nchunks=50, sampleNameColumn = 'partner_sample_id'):

    """
    Converts genotype and POS arrays to vcf, using pd dataframes in chunks. 
    Segregating sites only. Needs REF and ALT arrays.
    """

    #if file exists ignore and skip
    myfile = Path(f"{vcf_file}.gz")
    if myfile.is_file():
        print(f"File {vcf_file}.gz Exists...")
        return

    print(f"Loading array for {contig}...")

    ds_haps = ag3.haplotypes(contig, sample_sets=sample_sets, sample_query=sample_query, analysis=analysis)
    sample_ids = ds_haps['sample_id'].values
    metadata = metadata.set_index('sample_id').loc[sample_ids, :].reset_index()
    positions = ds_haps['variant_position']
    geno = allel.GenotypeDaskArray(ds_haps['call_genotype'])

    refs = ds_haps['variant_allele'][:,0].compute().values.astype(str)
    alts = ds_haps['variant_allele'][:,1].compute().values.astype(str)

    print("calculating chunks sizes...")
    chunks = np.round(np.arange(0, geno.shape[0], geno.shape[0]/nchunks)).astype(int).tolist()
    chunks.append(geno.shape[0])

    for idx, chunk in enumerate(chunks[:-1]):

        gn = geno[chunks[idx]:chunks[idx+1]].compute()
        pos = positions[chunks[idx]:chunks[idx+1]]
        ref = refs[chunks[idx]:chunks[idx+1]]
        alt = alts[chunks[idx]:chunks[idx+1]]

        # Contruct SNP info DF
        vcf_df = pd.DataFrame({'#CHROM': contig,
                 'POS': pos,
                 'ID': '.',
                 'REF': ref,
                 'ALT': alt,
                 'QUAL': '.',
                 'FILTER': '.',
                 'INFO':'.',
                'FORMAT': 'GT'})

        print(f"Pandas SNP info DataFrame constructed...{idx}")

        # Geno to VCF
        vcf = pd.DataFrame(np.char.replace(gn.to_gt().astype(str), "/", "|"), columns=metadata[sampleNameColumn])
        print("Concatenating info and genotype dataframes...")
        vcf = pd.concat([vcf_df, vcf], axis=1)

        print(f"Pandas Genotype data constructed...{idx}")

        if idx == 0:
            with open(f"{vcf_file}", 'w') as vcfheader:
                    write_vcf_header(vcfheader, contig)

        print("Writing to .vcf")

        vcf.to_csv(vcf_file, 
                   sep="\t", 
                   index=False,
                   mode='a',
                  header=(idx==0), 
                  lineterminator="\n")

def write_vcf_header(vcf_file, contig):
    """
    Writes a VCF header.
    """

    print('##fileformat=VCFv4.1', file=vcf_file)
    # write today's date
    today = date.today().strftime('%Y%m%d')
    print('##fileDate=%s' % today, file=vcf_file)
    # write source
    print('##source=scikit-allel-%s + ZarrToVCF.py' % allel.__version__, file=vcf_file)
    #write refs and contigs 
    print('##reference=resources/reference/Anopheles-gambiae-PEST_CHROMOSOMES_AgamP4.fa', file=vcf_file)
    print('##contig=<ID=2R,length=61545105>', file=vcf_file) if contig == '2R' else None
    print('##contig=<ID=3R,length=53200684>', file=vcf_file) if contig == '3R' else None 
    print('##contig=<ID=2L,length=49364325>', file=vcf_file) if contig == '2L' else None
    print('##contig=<ID=3L,length=41963435>', file=vcf_file) if contig == '3L' else None
    print('##contig=<ID=X,length=24393108>', file=vcf_file) if contig == 'X' else None
    print('##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">', file=vcf_file)

import sys


# Zarr to VCF # 
ag3_sample_sets = ["1288-VO-UG-DONNELLY-VMF00168","1288-VO-UG-DONNELLY-VMF00219"] #'1244-VO-GH-YAWSON-VMF00149'  #snakemake.params['ag3_sample_sets'] if cloud else []
contig = sys.argv[1] #'2L' #snakemake.wildcards['contig']
dataset = 'llineup' #snakemake.params['dataset']
sampleNameColumn = 'partner_sample_id'
sample_query = 'taxon == "gambiae"'

import malariagen_data
ag3 = malariagen_data.Ag3(pre=True)
metadata = ag3.sample_metadata(sample_sets=ag3_sample_sets, sample_query=sample_query)

print(f"Running for {contig}...")

### MAIN ####
ZarrToPandasToHaplotypeVCF(
     f"resources/vcfs/{dataset}_{contig}.haplotypes.vcf", 
     metadata=metadata,
     sample_query=sample_query,
     contig=contig, 
     nchunks=20, 
     sample_sets=ag3_sample_sets
    )
import sys
sys.stderr = open(snakemake.log[0], "w")

import probetools as probe
import numpy as np
import zarr
import pandas as pd
import allel
import dask.array as da
from datetime import date
from pathlib import Path


# Zarr to VCF # 
cloud = snakemake.params['cloud']
ag3_sample_sets = snakemake.params['ag3_sample_sets'] if cloud else []
contig = snakemake.wildcards['contig']
dataset = snakemake.params['dataset']
genotypePath = snakemake.input['genotypes'] if not cloud else []
positionsPath = snakemake.input['positions'] if not cloud else []
siteFilterPath = snakemake.input['siteFilters'] if not cloud else []
refPath = snakemake.input['refPath']
altPath = snakemake.input['altPath']

sampleNameColumn = 'partner_sample_id'

# Load metadata 
if cloud:
    import malariagen_data
    ag3 = malariagen_data.Ag3(pre=True)
    metadata = ag3.sample_metadata(sample_sets=ag3_sample_sets)
else:
    metadata = pd.read_csv(snakemake.params['metadata'], sep="\t")


def write_vcf_header(vcf_file, contig):
    """
    Writes a VCF header.
    """

    print('##fileformat=VCFv4.1', file=vcf_file)
    # write today's date
    today = date.today().strftime('%Y%m%d')
    print('##fileDate=%s' % today, file=vcf_file)
    # write source
    print('##source=scikit-allel-%s + ZarrToVCF.py' % allel.__version__, file=vcf_file)
    #write refs and contigs 
    print('##reference=resources/reference/Anopheles-gambiae-PEST_CHROMOSOMES_AgamP4.fa', file=vcf_file)
    print('##contig=<ID=2R,length=61545105>', file=vcf_file) if contig == '2R' else None
    print('##contig=<ID=3R,length=53200684>', file=vcf_file) if contig == '3R' else None 
    print('##contig=<ID=2L,length=49364325>', file=vcf_file) if contig == '2L' else None
    print('##contig=<ID=3L,length=41963435>', file=vcf_file) if contig == '3L' else None
    print('##contig=<ID=X,length=24393108>', file=vcf_file) if contig == 'X' else None
    print('##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">', file=vcf_file)

def ZarrToPandasToVCF(vcf_file, genotypePath, positionsPath, siteFilterPath, contig, nchunks=50, snpfilter = "segregating"):

    """
    Converts genotype and POS arrays to vcf, using pd dataframes in chunks. 
    Segregating sites only. Needs REF and ALT arrays.
    """

    #if file exists ignore and skip
    myfile = Path(f"{vcf_file}.gz")
    if myfile.is_file():
        print(f"File {vcf_file}.gz Exists...")
        return

    probe.log(f"Loading array for {contig}...")

    geno, pos = probe.loadZarrArrays(genotypePath, positionsPath, siteFilterPath=siteFilterPath, cloud=cloud, contig=contig, sample_sets=ag3_sample_sets, haplotypes=False)
    allpos = allel.SortedIndex(zarr.open_array(positionsPath)[:])
    ref_alt_filter = allpos.locate_intersection(pos)[0]

    refs = zarr.open_array(refPath.format(contig=contig))[:][ref_alt_filter]
    alts = zarr.open_array(altPath.format(contig=contig))[:][ref_alt_filter]

    if snpfilter == "segregating":
        probe.log("Find segregating sites...")
        flt = geno.count_alleles().is_segregating()
        geno = geno.compress(flt, axis=0)
        positions = pos[flt]
        refs = refs[flt].astype(str)
        alts = [a +"," + b + "," + c for a,b,c in alts[flt].astype(str)]
    elif snpfilter == 'biallelic':
        probe.log("Finding biallelic sites and recoding to 0 and 1...")
        ac = geno.count_alleles()
        flt = ac.is_biallelic()
        geno = geno.compress(flt, axis=0)
        ac = ac.compress(flt, axis=0).compute()
        ref0 = ac[:,0] > 0                      # Make sure one of bialleles is 0
        geno = geno.compress(ref0, axis=0)
        ac = ac.compress(ref0, axis=0)

        alt_idx = np.where(ac[:,1:] > 0)[1]     # Get alt idx (is it 0,1,2)
        mapping = np.tile(np.array([0,1,1,1]), reps=geno.shape[0]).reshape(geno.shape[0], 4) # create mapping to recode bialleles to 1 
        geno = geno.map_alleles(mapping)
        positions = pos[flt][ref0]
        refs = refs[flt].astype(str)[ref0]
        alts = np.take_along_axis(alts[flt][ref0].astype(str), alt_idx[:, None], axis=-1).flatten() # select correct ALT allele
    else:
        assert np.isin(snpfilter, ['segregating', 'biallelic']).any(), "incorrect snpfilter value"

    probe.log("calculating chunks sizes...")
    chunks = np.round(np.arange(0, geno.shape[0], geno.shape[0]/nchunks)).astype(int).tolist()
    chunks.append(geno.shape[0])

    for idx, chunk in enumerate(chunks[:-1]):

        gn = geno[chunks[idx]:chunks[idx+1]].compute()
        pos = positions[chunks[idx]:chunks[idx+1]]
        ref = refs[chunks[idx]:chunks[idx+1]]
        alt = alts[chunks[idx]:chunks[idx+1]]

        # Contruct SNP info DF
        vcf_df = pd.DataFrame({'#CHROM': contig,
                 'POS': pos,
                 'ID': '.',
                 'REF': ref,
                 'ALT': alt,
                 'QUAL': '.',
                 'FILTER': '.',
                 'INFO':'.',
                'FORMAT': 'GT'})

        probe.log(f"Pandas SNP info DataFrame constructed...{idx}")

        # Geno to VCF
        vcf = pd.DataFrame(gn.to_gt().astype(str), columns=metadata[sampleNameColumn])
        probe.log("Concatenating info and genotype dataframes...")
        vcf = pd.concat([vcf_df, vcf], axis=1)

        probe.log(f"Pandas Genotype data constructed...{idx}")

        if idx == 0:
            with open(f"{vcf_file}", 'w') as vcfheader:
                    write_vcf_header(vcfheader, contig)

        probe.log("Writing to .vcf")

        vcf.to_csv(vcf_file, 
                   sep="\t", 
                   index=False,
                   mode='a',
                  header=(idx==0), 
                  line_terminator="\n")



### MAIN ####

#ZarrToPandasToVCF(f"../resources/vcfs/ag3_gaardian_{contig}.multiallelic.vcf", genotypePath, positionsPath, siteFilterPath, contig, snpfilter="segregating")
ZarrToPandasToVCF(f"resources/vcfs/{dataset}_{contig}.biallelic.vcf", genotypePath, positionsPath, siteFilterPath, contig, snpfilter="biallelic")
__author__ = "Michael Chambers"
__copyright__ = "Copyright 2019, Michael Chambers"
__email__ = "[email protected]"
__license__ = "MIT"


from snakemake.shell import shell


shell("samtools faidx {snakemake.params} {snakemake.input[0]} > {snakemake.output[0]}")


Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/sanjaynagi/probe
Name: probe
Version: 1

Downloaded: 0
Copyright: Public Domain
License: None
