Simulation Workflow for Human Allele Frequency Change Study

public public 1yr ago Version: v1.0.2 0 bookmarks

Simulation workflow associated with the paper "The contribution of admixture, selection, and genetic drift to four thousand years of human allele frequency change" by Alexis Simon and Graham Coop.

Depends on:

  • snakemake

Code Snippets

19
20
script:
	'../scripts/analyse_simple_scenarios.py'
40
41
script:
	'../scripts/analyse_simple_scenarios.py'
53
54
script:
	'../scripts/plot_simple_scenarios.py'
67
68
script:
	'../scripts/plot_simple_scenarios.py'
92
93
script:
	'../scripts/analyse_europe_uk.py'
107
108
script:
	'../scripts/plot_europe_uk.py'
14
15
script:
	'../scripts/sim_msprime_simple_scenarios.py'
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
shell:
	'''
	slim \
	-d 'JSON_FILE="{input.demes_file}"' \
	-d 'TREES_FILE="{output.trees_file}"' \
	-d 'PHENO_FILE="{output.pheno_file}"' \
	-d 'backward_sampling={params.sampling_times}' \
	-d 'N_sample={params.n_sample}' \
	-d 'census_time={params.census_time}' \
	-d 'shift_type="{wildcards.type}"' \
	-d 'shift_size={wildcards.ssize}' \
	-d 'shift_delay={params.shift_delay}' \
	workflow/scripts/sim_slim_sel_simple_scenarios.slim \
	> {log}
	'''
64
65
script:
	'../scripts/sim_slim_postprocessing.py'
79
80
script:
	'../scripts/sim_msprime_europe_uk.py'
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import admixcov as ac
import tskit
import demes
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle

#%%
# Inputs and parameters are read from the `snakemake` object, which is not
# defined in this script (presumably injected by Snakemake's `script:` directive).
files = snakemake.input['files']
demes_file = snakemake.input['demes_file']
# NOTE(review): census_time is read here but not referenced again in the
# visible part of this script (reference times are hard-coded in `refs` below).
census_time = snakemake.params['census_time']

# drop_times = 2 if 'slim' in files[0] else 1

# NOTE(review): with `times` hard-coded below, this tree sequence is not used
# before `ts` is rebound by the replicate loop - confirm the load is still needed.
ts = tskit.load(files[0]) # extract info common to all trees
# times = np.flip(ac.ts.get_times(ts))[drop_times:]
# Sampling times, hard-coded instead of being derived from the tree sequence
# (see the commented-out line above).
times = [150, 130, 110, 90, 70, 50, 0]
graph = demes.load(demes_file)

# One sample size is expected per entry of `times`.
n_samples = snakemake.params['n_samples']
assert len(n_samples) == len(times)

ref_n_samples = snakemake.params['ref_n_samples']

# Reference populations in the order WHG, ANA, YAM: each entry pairs a
# population id (5, 4, 7), a sample size and a sampling time (200, 200, 150).
# NOTE(review): the ids are presumably population indices in the tree
# sequences - verify against the demes file.
refs = [
    {'pop': p, 'time': c, 'n': n}
    for (p, n, c) in zip([5, 4, 7], ref_n_samples, [200, 200, 150])
]
# Boolean mask with 6 rows and one column per reference (WHG, ANA, YAM).
# Presumably one row per interval between consecutive `times`, flagging the
# admixture source for that interval - confirm against admixcov.
alpha_mask = np.array([ # WHG, ANA, YAM
    [0, 0, 1],
    [0, 1, 0],
    [0, 1, 0],
    [0, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
], dtype=bool)
# No seed is given, so the generator is seeded from fresh OS entropy on every run.
rng = np.random.default_rng()

#%%
def ts_reps(files: list):
    """Yield one loaded tree sequence per path in ``files``.

    This is a generator: each file is only read with ``tskit.load`` when the
    consumer asks for the next item.
    """
    for path in files:
        tree_sequence = tskit.load(path)
        yield tree_sequence

# Run the admixcov analysis on every replicate tree sequence.
focal_pop = 8  # focal pop
results = []
for ts in ts_reps(files):
    replicate_stats = ac.ts.analyze_trees(
        ts, times, n_samples, focal_pop, refs, alpha_mask, rng,
    )
    results.append(replicate_stats)

#%% transform results
# Gather the per-replicate statistics; every list receives one entry per replicate.
totvar, G, G_nc, Ap, G_nde = [], [], [], [], []
Q, covmat_nc, covmat = [], [], []
for rep in results:
    raw_cov = rep['covmat']
    admix_cov = rep['admix_cov']
    drift_err = rep['drift_err']
    rep_totvar, rep_G_nc, rep_G, rep_Ap, rep_G_nde = ac.stats_from_matrices(
        raw_cov, admix_cov, drift_err,
    )
    # dividing by first time point hz
    totvar.append(np.array(rep_totvar) / rep['hz'][0])
    G_nc.append(rep_G_nc)
    G.append(rep_G)
    Ap.append(rep_Ap)
    G_nde.append(rep_G_nde)
    Q.append(rep['Q'])
    covmat_nc.append(raw_cov)
    # covariance matrix with the admixture and drift-error terms subtracted
    covmat.append(raw_cov - admix_cov - drift_err)

# Turn the lists into arrays (np.stack puts replicates on a new leading axis).
totvar, G_nc, G, G_nde, Ap = map(np.array, (totvar, G_nc, G, G_nde, Ap))
Q, covmat_nc, covmat = map(np.stack, (Q, covmat_nc, covmat))

# convert to CIs
totvar_CI, G_nc_CI, G_CI, Ap_CI, G_nde_CI = map(
    ac.get_ci, (totvar, G_nc, G, Ap, G_nde)
)
covmat_nc_CI, covmat_CI = map(ac.get_ci, (covmat_nc, covmat))

# One CI per slice along the last axis of Q.
Q_CIs = [ac.get_ci(Q[:, :, pop_idx]) for pop_idx in range(Q.shape[-1])]

# Phenotype tables are only read when the first input path contains 'slim'.
if 'slim' in files[0]:
    # Each '<name>.trees' input has a sibling '<name>_pheno.tsv' table.
    # Read them all and concatenate once: calling pd.concat inside a loop
    # re-copies the accumulated frame on every iteration.
    ztb = pd.concat(
        [
            pd.read_csv(f.replace('.trees', '_pheno.tsv'), sep='\t')
            for f in files
        ]
    )
    # 'bgen' counts backwards from the largest value of 'gen' in the table.
    ztb['bgen'] = ztb.gen.max() - ztb.gen
else:
    ztb = None

# Serialise the summary as one tuple; readers unpack it positionally, so the
# order of the fields below must not change.
with open(snakemake.output['pickle'], 'wb') as fw:
    payload = (
        times,
        totvar_CI,
        G_nc_CI,
        G_CI,
        Ap_CI,
        G_nde_CI,
        covmat_nc_CI,
        covmat_CI,
        Q_CIs,
        ztb,
    )
    pickle.dump(payload, fw)
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import admixcov as ac
import tskit
import demes
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle

#%%
# Inputs and parameters are read from the `snakemake` object, which is not
# defined in this script (presumably injected by Snakemake's `script:` directive).
files = snakemake.input['files']
demes_file = snakemake.input['demes_file']
census_time = snakemake.params['census_time']

# Number of leading entries removed from the reversed time list: 2 when the
# first input path contains 'slim', otherwise 1.
# NOTE(review): presumably these are census/extra sampling times - confirm.
drop_times = 2 if 'slim' in files[0] else 1

ts = tskit.load(files[0]) # extract info common to all trees
times = np.flip(ac.ts.get_times(ts))[drop_times:]
graph = demes.load(demes_file)
# All demes except one are used as references below, and this value is also
# passed as the focal population to the analysis (presumably the admixed
# deme is the last one in the graph - confirm against the demes file).
N_admix_pop = len(graph.demes) - 1

# Same sample size at every time point; the assert below holds by construction.
n_samples = [snakemake.params['n_sample']]*len(times)
assert len(n_samples) == len(times)

ref_n_sample = snakemake.params['ref_n_sample']

# not using ts.num_populations here as pyslim adds an additional one in ts
# One reference per population 0..N_admix_pop-1, all at the census time.
refs = [
    {'pop': i, 'time': census_time, 'n': ref_n_sample}
    for i in range(N_admix_pop)
]
# True wherever a pulse in the graph has a non-zero proportion.
alpha_mask = np.array(
    [p.proportions for p in graph.pulses]
) > 0 # create alphas from graph
# No seed is given, so the generator is seeded from fresh OS entropy on every run.
rng = np.random.default_rng()

#%%
def ts_reps(files: list):
    """Yield one loaded tree sequence per path in ``files``.

    This is a generator: each file is only read with ``tskit.load`` when the
    consumer asks for the next item.
    """
    for path in files:
        tree_sequence = tskit.load(path)
        yield tree_sequence

# Run the admixcov analysis on every replicate tree sequence; N_admix_pop is
# passed as the focal population.
results = []
for ts in ts_reps(files):
    replicate_stats = ac.ts.analyze_trees(
        ts, times, n_samples, N_admix_pop, refs, alpha_mask, rng,
    )
    results.append(replicate_stats)

#%% transform results
# Gather the per-replicate statistics; every list receives one entry per replicate.
totvar, G, G_nc, G_nde, Ap = [], [], [], [], []
Q, covmat_nc, covmat = [], [], []
for rep in results:
    raw_cov = rep['covmat']
    admix_cov = rep['admix_cov']
    drift_err = rep['drift_err']
    rep_totvar, rep_G_nc, rep_G, rep_Ap, rep_G_nde = ac.stats_from_matrices(
        raw_cov, admix_cov, drift_err,
    )
    # dividing by first time point hz
    totvar.append(np.array(rep_totvar) / rep['hz'][0])
    G_nc.append(rep_G_nc)
    G.append(rep_G)
    Ap.append(rep_Ap)
    G_nde.append(rep_G_nde)
    Q.append(rep['Q'])
    covmat_nc.append(raw_cov)
    # covariance matrix with the admixture and drift-error terms subtracted
    covmat.append(raw_cov - admix_cov - drift_err)

# Turn the lists into arrays (np.stack puts replicates on a new leading axis).
totvar, G_nc, G, Ap, G_nde = map(np.array, (totvar, G_nc, G, Ap, G_nde))
Q, covmat_nc, covmat = map(np.stack, (Q, covmat_nc, covmat))

# convert to CIs
totvar_CI, G_nc_CI, G_CI, Ap_CI, G_nde_CI = map(
    ac.get_ci, (totvar, G_nc, G, Ap, G_nde)
)
covmat_nc_CI, covmat_CI = map(ac.get_ci, (covmat_nc, covmat))

# One CI per slice along the last axis of Q.
Q_CIs = [ac.get_ci(Q[:, :, pop_idx]) for pop_idx in range(Q.shape[-1])]

# Phenotype tables are only read when the first input path contains 'slim'.
if 'slim' in files[0]:
    # Each '<name>.trees' input has a sibling '<name>_pheno.tsv' table.
    # Read them all and concatenate once: calling pd.concat inside a loop
    # re-copies the accumulated frame on every iteration.
    ztb = pd.concat(
        [
            pd.read_csv(f.replace('.trees', '_pheno.tsv'), sep='\t')
            for f in files
        ]
    )
    # 'bgen' counts backwards from the largest value of 'gen' in the table.
    ztb['bgen'] = ztb.gen.max() - ztb.gen
else:
    ztb = None

# Serialise the summary as one tuple; readers unpack it positionally, so the
# order of the fields below must not change.
with open(snakemake.output['pickle'], 'wb') as fw:
    payload = (
        times,
        totvar_CI,
        G_nc_CI,
        G_CI,
        Ap_CI,
        G_nde_CI,
        covmat_nc_CI,
        covmat_CI,
        Q_CIs,
        ztb,
    )
    pickle.dump(payload, fw)
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import admixcov as ac
import tskit
import demes
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle

# 3x2 grid of axes; the panels are drawn one by one further down.
fig, axs = plt.subplots(3, 2, figsize=(12, 10), layout="tight")

# Load the summary tuple of the neutral simulations. It is unpacked
# positionally, so the order here must match the order used when pickling.
with open(snakemake.input['sim_neutral'], 'rb') as fr:
	(
		times,
        totvar_CI,
        G_nc_CI,
        G_CI,
        Ap_CI,
        G_nde_CI,
        covmat_nc_CI,
        covmat_CI,
        Q_CIs,
        ztb,
	) = pickle.load(fr)

# Margin added on both sides of the time axis, in the same units as `times`.
time_padding = 10

# Colour palette, indexed by population / line number in the panels.
colors_oi = [
    '#000000', # black
    '#D55E00', # vermillion
    '#0072B2', # blue
    '#009E73', # green
    '#E69F00', # orange
    '#56B4E9', # sky blue
    '#CC79A7', # pink
    '#F0E442', # yellow
]

times = np.array(times) # ensure it is an array
# One mathtext label per time point except the last, subscripted with the time.
delta_list = [f"$\\Delta p_{{{int(t)}}}$" for t in times[:-1]]

# sci notation formatter
# Power limits of (0, 0) make the formatter use scientific notation for all values.
import matplotlib.ticker as tkr
formatter = tkr.ScalarFormatter(useMathText=True)
formatter.set_scientific(True)
formatter.set_powerlimits((0, 0))

# Panel A: one ancestry-proportion line (with CI) per entry of Q_CIs.
k, l = (0, 0)
for i in range(len(Q_CIs)):
    ac.plot_ci_line(x=times, CI=Q_CIs[i], ax=axs[k, l], color=colors_oi[i], label=f"Pop{i}", marker='o')
# Delta-p labels, each placed 10 time units below its sampling time
# (mid-interval if the sampling times are 20 apart - TODO confirm).
for x, txt in zip([t - 10 for t in times[:-1]], delta_list):
	_ = axs[k, l].text(x, 1, txt, ha='center')
# Grey band from x to x + 20 for every other sampling time, starting with the second.
for x in times[1::2]:
    _ = axs[k, l].axvspan(x, x + 20, facecolor='grey', alpha=0.10)
# Upward arrows at hard-coded times.
# NOTE(review): presumably the admixture pulse times of the plotted scenario - confirm.
for x in [150, 130, 110, 50, 30, 10]:
	_ = axs[k, l].annotate("", xy=(x, 0.1), xytext=(x, 0), arrowprops=dict(arrowstyle="->"))
# The x axis runs from the first time point to the last one, padded on both sides.
axs[k, l].set_xlim(times[0] + time_padding, times[-1] - time_padding)
axs[k, l].set_ylabel("Mean ancestry proportion")
axs[k, l].set_xlabel("Time (gen. BP)")
axs[k, l].legend(loc="center left")
axs[k, l].set_title("A", loc='left', fontdict={'fontweight': 'bold'})

# Panel B: covariance matrix drawn from the combined CIs of the corrected
# and uncorrected covariances.
k, l = (0, 1)
combined_ci = ac.combine_covmat_CIs(covmat_CI, covmat_nc_CI)
# Colour-scale limit: largest absolute value of element 1 of the combined CI
# (presumably the point estimate - confirm) once its diagonal is zeroed.
offdiag = combined_ci[1] - np.diag(np.diag(combined_ci[1]))
scale_max = np.max(np.abs([np.nanmin(offdiag), np.nanmax(offdiag)]))
# The colorbar gets its own formatter instance. A formatter is bound to the
# last axis it was attached to and derives its exponent from that axis, so
# reusing the module-level `formatter` (attached to other y axes later on)
# could label the colorbar with the wrong order of magnitude.
cbar_formatter = tkr.ScalarFormatter(useMathText=True)
cbar_formatter.set_scientific(True)
cbar_formatter.set_powerlimits((0, 0))
ac.plot_covmat_ci(
    combined_ci,
    axs[k, l],
    scale_max,
    delta_labels=delta_list,
    cbar_kws={
        'label': "covariance, Cov($\\Delta p_i$, $\\Delta p_j$)",
        "format": cbar_formatter,
    },
)
# Two titles on the same axes: the panel letter on the left, the caption centred.
axs[k, l].set_title("B", loc='left', fontdict={'fontweight': 'bold'})
axs[k, l].set_title('Neutral covariances')

# Panels C and D: covariance line plots before (C) and after (D) the
# admixture correction, drawn with the same x limits.
x_shift = 2
k, l = (1, 0)
ac.cov_lineplot(times, covmat_nc_CI, axs[k, l], colors=colors_oi, d=2)
axs[k, l].set_ylabel("Cov($\\Delta p_i$, $\\Delta p_t$)")
axs[k, l].set_xlabel('t')
axs[k, l].set_title('Neutral, before admixture correction')
axs[k, l].set_title("C", loc='left', fontdict={'fontweight': 'bold'})
# x range from just above the second time point to 4 * x_shift below the
# second-to-last one; the dotted zero line spans the same range.
axs[k, l].set_xlim(times[1] + x_shift, times[-2] - 4 * x_shift)
axs[k, l].hlines(y=0, xmin=times[1] + x_shift, xmax=times[-2] - 4 * x_shift, linestyles='dotted', colors='grey')
axs[k, l].yaxis.set_major_formatter(formatter)
k, l = (1, 1)
# The y limits of panel C are passed on so that both panels are comparable.
ac.cov_lineplot(times, covmat_CI, axs[k, l], colors=colors_oi, d=2, ylim=axs[k, l - 1].get_ylim())
axs[k, l].set_ylabel("Cov($\\Delta p_i$, $\\Delta p_t$)")
axs[k, l].set_xlabel('t')
axs[k, l].set_title('Neutral, after admixture correction')
# Legend placed outside the axes, to the right of panel D.
axs[k, l].legend(loc='center left', bbox_to_anchor=(1, 0.5), title="$\\Delta p_i$")
axs[k, l].set_title("D", loc='left', fontdict={'fontweight': 'bold'})
axs[k, l].set_xlim(times[1] + x_shift, times[-2] - 4 * x_shift)
axs[k, l].hlines(y=0, xmin=times[1] + x_shift, xmax=times[-2] - 4 * x_shift, linestyles='dotted', colors='grey')
axs[k, l].yaxis.set_major_formatter(formatter)


# Panel E: G_nc, G and A with CIs for the neutral simulations. The three
# series are offset horizontally by x_shift so their markers do not overlap.
k, l = (2, 0)
x_shift = 2
ac.plot_ci_line(times[1:] + x_shift, G_nc_CI, ax=axs[k, l], marker='o', linestyle='dashed', label='$G_{nc}$')
# ac.plot_ci_line(times[1:] + 2 * x_shift, G_nde_CI, ax=axs[k, l], marker='^', linestyle='dashdot', label='$G_{nde}$')
ac.plot_ci_line(times[1:], G_CI, ax=axs[k, l], marker='o', label='$G$')
ac.plot_ci_line(times[1:] - x_shift, Ap_CI, ax=axs[k, l], marker='s', color='blue', label='$A$')
axs[k, l].set_xlim(times[1] + time_padding, times[-1] - time_padding)
axs[k, l].set_ylim(ymax=1.1)
axs[k, l].hlines(y=0, xmin=times[-1] - time_padding, xmax=times[1] + time_padding, linestyles='dotted', colors='grey')
axs[k, l].set_xlabel('t')
# NOTE(review): the subscript 160 is hard-coded in the label - confirm it
# matches the first sampling time of the plotted data.
axs[k, l].set_ylabel("Proportion of variance ($p_t - p_{160}$)")
axs[k, l].set_title('Neutral, Var. decomposition')
axs[k, l].set_title("E", loc='left', fontdict={'fontweight': 'bold'})
# Star the time points where elements 0 and 2 of G_CI have the same sign
# (presumably the CI bounds, i.e. the interval excludes zero - confirm).
for i, t in enumerate(times[1:]):
    if G_CI[0][i]*G_CI[2][i] > 0:
        axs[k, l].annotate("*", xy=(t, 0.1))

# ==================
# Load the selection simulations; this rebinds every variable unpacked from
# the neutral pickle, so everything below refers to the selection results.
with open(snakemake.input['sim_sel'], 'rb') as fr:
	(
		times,
        totvar_CI,
        G_nc_CI,
        G_CI,
        Ap_CI,
        G_nde_CI,
        covmat_nc_CI,
        covmat_CI,
        Q_CIs,
        ztb,
	) = pickle.load(fr)


# Panel F: same layout as panel E for the selection simulations. It differs
# in the zero-line colour (black), the star height (0.3) and the legend.
k, l = (2, 1)
x_shift = 2
ac.plot_ci_line(times[1:] + x_shift, G_nc_CI, ax=axs[k, l], marker='o', linestyle='dashed', label='$G_{nc}$')
# ac.plot_ci_line(times[1:] + 2 * x_shift, G_nde_CI, ax=axs[k, l], marker='^', linestyle='dotted', label='$G_{nde}$')
ac.plot_ci_line(times[1:], G_CI, ax=axs[k, l], marker='o', label='$G$')
ac.plot_ci_line(times[1:] - x_shift, Ap_CI, ax=axs[k, l], marker='s', color='blue', label='$A$')
axs[k, l].set_xlim(times[1] + time_padding, times[-1] - time_padding)
axs[k, l].set_ylim(ymax=1.1)
axs[k, l].hlines(y=0, xmin=times[-1] - time_padding, xmax=times[1] + time_padding, linestyles='dotted', colors='black')
axs[k, l].set_xlabel('t')
axs[k, l].set_ylabel("Proportion of variance ($p_t - p_{160}$)")
axs[k, l].set_title('Gradual selection, Var. decomposition')
# Legend placed outside the axes, to the right of the panel.
axs[k, l].legend(loc='center left', bbox_to_anchor=(1, 0.5))
axs[k, l].set_title("F", loc='left', fontdict={'fontweight': 'bold'})
for i, t in enumerate(times[1:]):
    if G_CI[0][i]*G_CI[2][i] > 0:
        axs[k, l].annotate("*", xy=(t, 0.3))


# Write the assembled six-panel figure to the path given by Snakemake.
fig.savefig(snakemake.output['fig'])
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import admixcov as ac
import tskit
import demes
import demesdraw
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle

# Draw the demographic model from the demes file and save it as its own figure.
demes_file = snakemake.input['demes_file']
graph = demes.load(demes_file)
fig, ax = plt.subplots(figsize=(8, 8))
demesdraw.tubes(graph, log_time=True, ax=ax)
fig.savefig(snakemake.output['fig_demo'])

# Load the summary tuple. It is unpacked positionally, so the order here
# must match the order used when pickling.
with open(snakemake.input['pickle'], 'rb') as fr:
	(
		times,
        totvar_CI,
        G_nc_CI,
        G_CI,
        Ap_CI,
        G_nde_CI,
        covmat_nc_CI,
        covmat_CI,
        Q_CIs,
        ztb,	
	) = pickle.load(fr)

# Tick locator with one tick per integer; reused for several x axes below.
import matplotlib.ticker as plticker
loc = plticker.MultipleLocator(base=1.0)

# Margin added on both sides of the time axis, in the same units as `times`.
time_padding = 10

# Colour palette, indexed by population / line number in the panels.
colors_oi = [
    '#000000', # black
    '#D55E00', # vermillion
    '#0072B2', # blue
    '#009E73', # green
    '#E69F00', # orange
    '#56B4E9', # sky blue
    '#CC79A7', # pink
    '#F0E442', # yellow
]

times = np.array(times) # ensure it is an array
# Labels are subscripted with the interval index (0, 1, ...), not with the time.
delta_list = [f"$\\Delta p_{{{int(t)}}}$" for t in range(len(times) - 1)]

# sci notation in colorbar
# Power limits of (0, 0) make the formatter use scientific notation for all values.
import matplotlib.ticker as tkr
formatter = tkr.ScalarFormatter(useMathText=True)
formatter.set_scientific(True)
formatter.set_powerlimits((0, 0))

# 2x2 grid for the main figure; `fig` is rebound, the demography figure was saved above.
fig, axs = plt.subplots(2, 2, figsize=(10, 8), layout='constrained')

# Panel B (top right): ancestry proportion with CI for the three sources.
k, l = (0, 1)
fmts = ['-o', '-s', '-^']
labels = ['WHG', 'ANA', 'YAM']
# Q_CIs is indexed in the same order as `labels`.
for i, pop in enumerate(labels):
    ac.plot_ci_line(x=times, CI=Q_CIs[i], ax=axs[k, l], color=colors_oi[i], label=pop, fmt=fmts[i])
# Delta-p label at the midpoint of each pair of consecutive time points.
for x1, x2, txt in zip(times[:-1], times[1:], delta_list):
    _ = axs[k, l].text(x2+(x1 - x2)/2, 0.9, txt, ha='center')
# Index of each time point, written at the time point itself.
for i, t in enumerate(times):
    _ = axs[k, l].text(t, 0.8, str(i), ha='center')
# Grey band over every other interval, starting with the second one.
for x1, x2 in zip(times[1::2], times[2::2]):
    _ = axs[k, l].axvspan(x1, x2, facecolor='grey', alpha=0.10)
# The x axis runs from the first time point to the last one, padded on both sides.
axs[k, l].set_xlim(times[0] + time_padding, times[-1] - time_padding)
axs[k, l].set_ylim(top=1)
axs[k, l].set_ylabel("Mean ancestry proportion")
axs[k, l].set_xlabel("Time (years BP)")
# Legend placed below the axes, in three columns.
axs[k, l].legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3)
axs[k, l].set_title("B", loc='left', fontdict={'fontweight': 'bold'})

# Panels A and C: covariance line plots before (A) and after (C) the
# admixture correction, plotted against the index of each time point.
x_shift = 0.1
new_times = np.array(range(len(times)))


def _sci_formatter():
    """Return a new scientific-notation tick formatter.

    Tick helpers keep a reference to the axis they were last attached to,
    so every axis needs its own instance.
    """
    fmt = tkr.ScalarFormatter(useMathText=True)
    fmt.set_scientific(True)
    fmt.set_powerlimits((0, 0))
    return fmt


k, l = (0, 0)
ac.cov_lineplot(new_times, covmat_nc_CI, axs[k, l], colors=colors_oi, d=x_shift, labels=delta_list)
axs[k, l].set_xlim(new_times[1] - x_shift, new_times[-2] + 3 * x_shift)
axs[k, l].hlines(y=0, xmin=0, xmax=new_times[-1] + 3 * x_shift, linestyles='dotted', colors='grey')
axs[k, l].set_ylabel("Cov($\\Delta p_i$, $\\Delta p_t$)")
axs[k, l].set_xlabel("t")
axs[k, l].set_title('Before admix. correction')
axs[k, l].set_title("A", loc='left', fontdict={'fontweight': 'bold'})
# A fresh locator per axis: matplotlib locators must not be shared between
# axes, so the module-level `loc` is not reused here.
axs[k, l].xaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
axs[k, l].yaxis.set_major_formatter(_sci_formatter())

k, l = (1, 0)
# The y limits of panel A are passed on so that both panels are comparable.
ac.cov_lineplot(new_times, covmat_CI, axs[k, l], colors=colors_oi, d=x_shift, labels=delta_list, ylim=axs[0, 0].get_ylim())
axs[k, l].set_xlim(new_times[1] - x_shift, new_times[-2] + 3 * x_shift)
axs[k, l].hlines(y=0, xmin=0, xmax=new_times[-1] + 3 * x_shift, linestyles='dotted', colors='grey')
axs[k, l].set_ylabel("Cov($\\Delta p_i$, $\\Delta p_t$)")
axs[k, l].set_xlabel('t')
axs[k, l].set_title('After admix. correction')
axs[k, l].set_title("C", loc='left', fontdict={'fontweight': 'bold'})
# Legend placed below the axes, in three columns.
axs[k, l].legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), title="$\\Delta p_i$", ncol=3)
axs[k, l].xaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
axs[k, l].yaxis.set_major_formatter(_sci_formatter())

# Panel D: G_nc, G and A with CIs against the index of each time point. The
# three series are offset horizontally by x_shift so their markers do not overlap.
k, l = (1, 1)
ac.plot_ci_line(new_times[1:] + x_shift, G_nc_CI, ax=axs[k, l], linestyle='dashed', marker='o', label='$G_{nc}$')
ac.plot_ci_line(new_times[1:], G_CI, ax=axs[k, l], marker='o', label='$G$')
ac.plot_ci_line(new_times[1:] - x_shift, Ap_CI, ax=axs[k, l], color='blue', marker='s', label='$A$')
axs[k, l].set_xlim(new_times[1] - 2*x_shift, new_times[-1] + 2*x_shift)
axs[k, l].hlines(y=0, xmin=new_times[-1], xmax=new_times[1], colors='grey', linestyles='dotted')
axs[k, l].set_ylim(ymax=1)
axs[k, l].set_xlabel('t')
axs[k, l].set_ylabel("Proportion of variance ($p_t - p_{0}$)")
# Legend placed below the axes, in three columns.
axs[k, l].legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3)
axs[k, l].set_title("D", loc='left', fontdict={'fontweight': 'bold'})
# A fresh locator for this axis: matplotlib locators must not be shared
# between axes, so the module-level `loc` is not reused here.
axs[k, l].xaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
# Star the time points where elements 0 and 2 of G_CI have the same sign
# (presumably the CI bounds, i.e. the interval excludes zero - confirm).
for i, t in enumerate(new_times[1:]):
    if G_CI[0][i]*G_CI[2][i] > 0:
        axs[k, l].annotate("*", xy=(t, 0.1))

fig.savefig(
    snakemake.output['main_fig'],
)

#%%
# The phenotype figure is only drawn when the pickle path contains 'slim';
# otherwise the 'pheno_fig' output is not written by this script.
if 'slim' in snakemake.input['pickle']:
    fig, ax = plt.subplots(figsize=(6, 4))
    # One line per value of 'pop': mean of 'mean_z' per 'bgen' with a CI band.
    sns.lineplot(
        # ztb[ztb.bgen < times[0] + 20],
        ztb,
        x='bgen', y='mean_z', style='pop', hue='pop',
        estimator='mean', errorbar='ci', # 95% ci by default
        ax=ax,
    )
    # Reversed x axis, from 200 down to -5.
    ax.set_xlim(xmin=200, xmax=-5)
    ax.set_xlabel('generations')
    ax.set_ylabel('mean phenotype')
    fig.savefig(snakemake.output['pheno_fig'])
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import admixcov as ac
import tskit
import demes
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle

# Load the summary tuple. It is unpacked positionally, so the order here
# must match the order used when pickling.
with open(snakemake.input['pickle'], 'rb') as fr:
	(
		times,
        totvar_CI,
        G_nc_CI,
        G_CI,
        Ap_CI,
        G_nde_CI,
        covmat_nc_CI,
        covmat_CI,
        Q_CIs,
        ztb,	
	) = pickle.load(fr)

# %%
# Margin added on both sides of the time axis, in the same units as `times`.
time_padding = 10

# Colour palette, indexed by population / line number in the panels.
colors_oi = [
    '#000000', # black
    '#D55E00', # vermillion
    '#0072B2', # blue
    '#009E73', # green
    '#E69F00', # orange
    '#56B4E9', # sky blue
    '#CC79A7', # pink
    '#F0E442', # yellow
]

times = np.array(times) # ensure it is an array

# 3x2 grid of axes; the panels are drawn one by one further down.
fig, axs = plt.subplots(3, 2, figsize=(10, 8), layout="tight")

# sci notation in colorbar
# Power limits of (0, 0) make the formatter use scientific notation for all values.
import matplotlib.ticker as tkr
formatter = tkr.ScalarFormatter(useMathText=True)
formatter.set_scientific(True)
formatter.set_powerlimits((0, 0))

# Panel A: one ancestry-proportion line (with CI) per entry of Q_CIs.
k, l = (0, 0)
for i, q_ci in enumerate(Q_CIs):
    ac.plot_ci_line(x=times, CI=q_ci, ax=axs[k, l], color=colors_oi[i], label=f"Pop{i}", marker='o')
# The x axis runs from the first time point to the last one, padded on both sides.
axs[k, l].set_xlim(times[0] + time_padding, times[-1] - time_padding)
axs[k, l].set_ylim((0,1))
axs[k, l].set_ylabel("Mean ancestry proportion")
axs[k, l].set_xlabel("Time point")
axs[k, l].legend(loc="upper right")
# Plain "A" rather than mathtext '$A$': mathtext ignores the bold weight, and
# the other panels of this figure use plain letters.
axs[k, l].set_title("A", loc='left', fontdict={'fontweight': 'bold'})

# Panel B: covariance matrix drawn from the combined CIs of the corrected
# and uncorrected covariances.
k, l = (0, 1)
combined_ci = ac.combine_covmat_CIs(covmat_CI, covmat_nc_CI)
# Colour-scale limit: largest absolute value of element 1 of the combined CI
# (presumably the point estimate - confirm) once its diagonal is zeroed.
offdiag = combined_ci[1] - np.diag(np.diag(combined_ci[1]))
scale_max = np.max(np.abs([np.nanmin(offdiag), np.nanmax(offdiag)]))
# The colorbar gets its own formatter instance. A formatter is bound to the
# last axis it was attached to and derives its exponent from that axis, so
# reusing the module-level `formatter` (attached to other y axes later on)
# could label the colorbar with the wrong order of magnitude.
cbar_formatter = tkr.ScalarFormatter(useMathText=True)
cbar_formatter.set_scientific(True)
cbar_formatter.set_powerlimits((0, 0))
ac.plot_covmat_ci(
    combined_ci,
    axs[k, l],
    scale_max,
    cbar_kws={'label': 'covariance', "format": cbar_formatter},
)
axs[k, l].set_title("B", loc='left', fontdict={'fontweight': 'bold'})

# Panels C and D: covariance line plots before (C) and after (D) the
# admixture correction, drawn with the same x limits.
x_shift = 2
k, l = (1, 0)
ac.cov_lineplot(times, covmat_nc_CI, axs[k, l], colors=colors_oi, d=2)
axs[k, l].set_ylabel('Before correction')
axs[k, l].set_title("C", loc='left', fontdict={'fontweight': 'bold'})
# x range from just above the second time point to 4 * x_shift below the
# second-to-last one; the dotted zero line spans the same range.
axs[k, l].set_xlim(times[1] + x_shift, times[-2] - 4 * x_shift)
axs[k, l].hlines(y=0, xmin=times[1] + x_shift, xmax=times[-2] - 4 * x_shift, linestyles='dotted', colors='grey')
axs[k, l].yaxis.set_major_formatter(formatter)
k, l = (1, 1)
# The y limits of panel C are passed on so that both panels are comparable.
ac.cov_lineplot(times, covmat_CI, axs[k, l], colors=colors_oi, d=2, ylim=axs[1, 0].get_ylim())
axs[k, l].set_ylabel('After correction')
axs[k, l].set_title("D", loc='left', fontdict={'fontweight': 'bold'})
axs[k, l].set_xlim(times[1] + x_shift, times[-2] - 4 * x_shift)
axs[k, l].hlines(y=0, xmin=times[1] + x_shift, xmax=times[-2] - 4 * x_shift, linestyles='dotted', colors='grey')
axs[k, l].yaxis.set_major_formatter(formatter)

# Panel E: total variance with CI at every time point after the first.
k, l = (2, 0)
ac.plot_ci_line(x=times[1:], CI=totvar_CI, ax=axs[k, l], marker='o')
axs[k, l].set_xlim(times[1] + time_padding, times[-1] - time_padding)
# Only the lower y limit is fixed (at 0).
axs[k, l].set_ylim(0)
axs[k, l].set_ylabel('Total variance (t)')
axs[k, l].set_title("E", loc='left', fontdict={'fontweight': 'bold'})

# Panel F: G_nc, G_nde, G and A with CIs. The series are offset horizontally
# by multiples of x_shift so their markers do not overlap.
k, l = (2, 1)
x_shift = 2
# Lower y limit: 0.1 below the smallest value of element 1 of the three G CIs
# (presumably the point estimates - confirm).
ymin = np.min([G_CI[1], G_nc_CI[1], G_nde_CI[1]]) - 0.1
ac.plot_ci_line(times[1:] + x_shift, G_nc_CI, ax=axs[k, l], marker='o', linestyle='dashed', label='$G_{nc}$')
ac.plot_ci_line(times[1:] + 2 * x_shift, G_nde_CI, ax=axs[k, l], marker='^', linestyle='dashdot', label='$G_{nde}$')
ac.plot_ci_line(times[1:], G_CI, ax=axs[k, l], marker='o', label='$G$')
ac.plot_ci_line(times[1:] - x_shift, Ap_CI, ax=axs[k, l], marker='s', color='blue', label='$A$')
axs[k, l].set_xlim(times[1] + time_padding, times[-1] - time_padding)
axs[k, l].set_ylim(ymax=1.1, ymin=ymin)
axs[k, l].hlines(y=0, xmin=times[-1] - time_padding, xmax=times[1] + time_padding, linestyles='dotted', colors='grey')
axs[k, l].set_xlabel('t')
# NOTE(review): the subscript 160 is hard-coded in the label - confirm it
# matches the first sampling time of the plotted data.
axs[k, l].set_ylabel("Proportion of variance ($p_t - p_{160}$)")
# Legend placed outside the axes, to the right of the panel.
axs[k, l].legend(loc='center left', bbox_to_anchor=(1, 0.5))
axs[k, l].set_title("F", loc='left', fontdict={'fontweight': 'bold'})
# The y limits were fixed by set_ylim above, so the top is read once instead
# of on every iteration of the loop.
(_, ytop) = axs[k, l].get_ylim()
# Star the time points where elements 0 and 2 of G_CI have the same sign
# (presumably the CI bounds, i.e. the interval excludes zero - confirm).
for i, t in enumerate(times[1:]):
    if G_CI[0][i]*G_CI[2][i] > 0:
        axs[k, l].annotate("*", xy=(t, ytop))

fig.savefig(
    snakemake.output['main_fig'],
)

#%%
# The phenotype figure is only drawn when the pickle path contains 'slim';
# otherwise the 'pheno_fig' output is not written by this script.
if 'slim' in snakemake.input['pickle']:
    # One subplot per trait; the table is expected to hold columns
    # 'mean_z0' .. 'mean_z2'.
    n_traits = 3
    fig, axs = plt.subplots(1, n_traits, figsize=(n_traits*4, 3), layout='constrained')
    for i in range(n_traits):
        # One line per value of 'pop': mean of the trait per 'bgen' with a CI band.
        sns.lineplot(
            ztb,
            x='bgen', y=f'mean_z{i}', style='pop', hue='pop',
            estimator='mean', errorbar='ci', # 95% ci by default
            ax=axs[i],
        )
        # Reversed x axis, from 200 down to -5.
        axs[i].set_xlim(xmin=200, xmax=-5)
        axs[i].set_xlabel('generations')
        axs[i].set_ylabel(f'mean phenotype {i}')
    fig.savefig(snakemake.output['pheno_fig'])
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import demes
import os
import demesdraw

sc = snakemake.wildcards['sc']

class Scenario:
    """Admixture scenario: ``N_anc`` source populations plus one admixed population.

    The admixed population receives one pulse per entry of ``pulse_times``;
    row ``j`` of ``pulses`` holds the proportion contributed by each source
    population at ``pulse_times[j]``.

    Args:
        name: Scenario label, used as the graph description and in file names.
        N_anc: Number of source (ancestral) populations.
        pop_sizes: Sizes of the ``N_anc`` source populations followed by the
            size of the admixed population.
        pulses: One row per pulse time, each holding ``N_anc`` proportions.
            Values are converted to ``float``, so integer zeros are accepted.
        path: Directory that ``build`` writes its output files into.

    Raises:
        ValueError: If ``pop_sizes`` does not have ``N_anc + 1`` entries, if
            ``pulses`` does not have one row per pulse time, or if any row
            does not have ``N_anc`` entries.
    """

    # Times of the admixture pulses, in generations (the graph's time unit).
    pulse_times = [
        150,
        130,
        110,
        90,
        70,
        50,
        30,
        10,
    ]

    def __init__(
        self,
        name: str,
        N_anc: int,
        pop_sizes: list[int],
        pulses: list[list],
        path: str
    ):
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if len(pop_sizes) != N_anc + 1:
            raise ValueError(
                f"pop_sizes needs {N_anc + 1} entries, got {len(pop_sizes)}"
            )
        if len(pulses) != len(self.pulse_times):
            raise ValueError(
                f"pulses needs {len(self.pulse_times)} rows, got {len(pulses)}"
            )
        # Check every row, not only the first one.
        for row_idx, row in enumerate(pulses):
            if len(row) != N_anc:
                raise ValueError(
                    f"pulses[{row_idx}] needs {N_anc} entries, got {len(row)}"
                )
        self.name = name
        self.N_anc = N_anc
        self.pop_sizes = pop_sizes
        # Proportions are stored as floats so that integer zeros in the input
        # never reach the demes graph.
        self.pulses = [[float(prop) for prop in row] for row in pulses]
        self.file_prefix = path + '/scenario_' + name
        self.plot = f"{self.file_prefix}.svg"

    def build(self):
        """Resolve the demes graph and write it as YAML, JSON and an SVG drawing.

        The resolved graph is also stored on ``self.graph``.
        """
        b = demes.Builder(
            description=self.name,
            time_units="generations",
            generation_time=1,
        )
        # Pop0 is the root deme; every other deme descends from it.
        b.add_deme(
            "Pop0",
            description="Ancestral 1",
            epochs=[dict(end_time=0, start_size=self.pop_sizes[0])],
        )
        # Remaining source demes split from Pop0 at 1500, 1300, ... generations ago.
        start = 1500
        for i in range(1, self.N_anc):
            b.add_deme(
                f"Pop{i}",
                description=f"Ancestral {i + 1}",
                ancestors=["Pop0"],
                start_time=start,
                epochs=[dict(end_time=0, start_size=self.pop_sizes[i])],
            )
            start -= 200
        # The admixed deme is added last and splits from Pop0 200 generations ago.
        b.add_deme(
            f"Pop{self.N_anc}",
            description="Admixed",
            ancestors=["Pop0"],
            start_time=200,
            epochs=[dict(end_time=0, start_size=self.pop_sizes[self.N_anc])],
        )
        # One pulse per time point, from all source demes into the admixed deme.
        for t, p in zip(self.pulse_times, self.pulses):
            b.add_pulse(
                sources=[f"Pop{i}" for i in range(self.N_anc)],
                dest=f"Pop{self.N_anc}",
                proportions=p,
                time=t,
            )
        self.graph = b.resolve()
        demes.dump(self.graph, self.file_prefix + '.yaml')
        # The JSON copy is written in fully resolved (non-simplified) form.
        demes.dump(self.graph, self.file_prefix + '.json', format='json', simplified=False)
        ax = demesdraw.tubes(self.graph, log_time=True)
        ax.figure.savefig(self.plot)


# NOTE: zero proportions in `pulses` must be written as floats (.0), not as the integer 0.

sc_dict = dict()
# Scenario 2NGF (No Gene Flow)
sc_dict['2NGF'] = Scenario( 
    name="2NGF",
    N_anc=2,
    pop_sizes=[10_000, 10_000, 10_000],
    pulses=[
        [.0, .0],
        [.0, .0],
        [.0, .0],
        [.0, .0],
        [.0, .0],
        [.0, .0],
        [.0, .0],
        [.0, .0],
    ],
    path=snakemake.params['outdir'],
)

# Scenario 2A
sc_dict['2A'] = Scenario(
    name="2A",
    N_anc=2,
    pop_sizes=[10_000, 10_000, 10_000],
    pulses=[
        [.0, 0.2],
        [.0, 0.2],
        [.0, 0.2],
        [.0, .0],
        [.0, .0],
        [.0, 0.2],
        [.0, 0.2],
        [.0, 0.2],
    ],
    path=snakemake.params['outdir'],
)

# Scenario 2B
sc_dict['2B'] = Scenario(
    name="2B",
    N_anc=2,
    pop_sizes=[10_000, 10_000, 10_000],
    pulses=[
        [.0, 0.2],
        [.0, 0.2],
        [.0, 0.2],
        [.2, .0],
        [.2, .0],
        [.0, 0.2],
        [.0, 0.2],
        [.0, 0.2],
    ],
    path=snakemake.params['outdir'],
)

# Scenario 2C
sc_dict['2C'] = Scenario(
    name="2C",
    N_anc=2,
    pop_sizes=[10_000, 1_000, 5_000],
    pulses=[
        [.0, 0.2],
        [.0, 0.2],
        [.0, 0.2],
        [.2, .0],
        [.2, .0],
        [.0, 0.2],
        [.0, 0.2],
        [.0, 0.2],
    ],
    path=snakemake.params['outdir'],
)

# Scenario 3A
sc_dict['3A'] = Scenario(
    name="3A",
    N_anc=3,
    pop_sizes=[10_000, 10_000, 10_000, 10_000],
    pulses=[
        [.0, .2, .0],
        [.0, .2, .0],
        [.0, .2, .0],
        [.2, .0, .0],
        [.0, .0, .0],
        [.0, .0, .2],
        [.0, .0, .2],
        [.0, .0, .2],
    ],
    path=snakemake.params['outdir'],
)

# Scenario 3B
sc_dict['3B'] = Scenario(
    name="3B",
    N_anc=3,
    pop_sizes=[5_000, 1_000, 10_000, 5_000],
    pulses=[
        [.0, .2, .0],
        [.0, .0, .2],
        [.0, .2, .0],
        [.2, .0, .0],
        [.2, .0, .0],
        [.0, .0, .2],
        [.0, .2, .0],
        [.0, .0, .2],
    ],
    path=snakemake.params['outdir'],
)


sc_dict[sc].build()
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import demes
# import demesdraw
import msprime
# import math
import numpy as np
import stdpopsim
# import pickle

# import funcs as fn
# inputs
# Paths and parameters are read from the `snakemake` object, which is not
# defined in this script (presumably injected by Snakemake's `script:` directive).
demes_file = snakemake.input['demes_file']
# outputs
trees_file = snakemake.output['trees_file']
# model_plot = snakemake.output['model_plot']
# rate_map_pickle = snakemake.output['rate_map_pickle']
# params
# census_time = snakemake.params['census_time']
# The same sample size is used for every sample set created below.
n_sample = snakemake.params['n_sample']


graph = demes.load(demes_file)
# ax = demesdraw.tubes(graph, log_time=True)
# ax.figure.savefig(model_plot)

# Time at which each reference population gets a census event and an extra sample.
census_times = {'WHG': 200, 'ANA': 200, 'YAM': 150}

demography = msprime.Demography.from_demes(graph)
# One census event per distinct time (200 is shared by WHG and ANA).
for ct in np.unique(list(census_times.values())):
	demography.add_census(time=ct)
# Re-sort the event list after adding the census events.
demography.sort_events()

# cohorts = [
# 	'England.and.Wales_N',
# 	'England.and.Wales_C.EBA',
# 	'England.and.Wales_MBA',
# 	'England.and.Wales_LBA',
# 	'England.and.Wales_IA',
# 	'England.and.Wales_PostIA',
# 	'England.and.Wales_Modern',
# ]

# sampling
# The listed demes are sampled at every sampling time at which they exist:
# strictly after the start of their first epoch (further back in time) and
# no later than the end of their last epoch.
sampling_times = [150, 130, 110, 90, 70, 50, 0]
sampled_demes = ['NEO', 'WHG', 'ANA', 'YAM']
samples = []
for d in graph.demes:
    if d.name not in sampled_demes:
        continue
    oldest = d.epochs[0].start_time
    youngest = d.epochs[-1].end_time
    for t in sampling_times:
        if youngest <= t < oldest:
            samples.append(
                msprime.SampleSet(n_sample, population=d.name, time=t)
            )
# One extra sample set per reference population at its census time.
samples.extend(
    msprime.SampleSet(n_sample, population=pop, time=ct)
    for pop, ct in census_times.items()
)


# Contig setup
# Only chr1 is requested; its recombination map and mutation rate drive the
# simulation below.
species = stdpopsim.get_species("HomSap")
contigs = [species.get_contig(chrom) for chrom in ['chr1']]
contig = contigs[0]

# Simulation
# Ancestry first, then mutations layered on top of the resulting trees.
ts = msprime.sim_ancestry(
    samples=samples,
    ploidy=2,
    recombination_rate=contig.recombination_map,
    demography=demography,
)
ts = msprime.sim_mutations(ts, rate=contig.mutation_rate)

# drop sites with recurrent mutations
is_recurrent = [len(site.mutations) > 1 for site in ts.sites()]
ts = ts.delete_sites(np.where(is_recurrent)[0])

ts.dump(trees_file)
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import demes
import demesdraw
import msprime
# import math
import numpy as np
import stdpopsim
# import pickle

# import funcs as fn
# inputs
# Paths and parameters are read from the `snakemake` object, which is not
# defined in this script (presumably injected by Snakemake's `script:` directive).
demes_file = snakemake.input['demes_file']
# outputs
trees_file = snakemake.output['trees_file']
# params
census_time = snakemake.params['census_time']
# The same sample size is used for every sample set created below.
n_sample = snakemake.params['n_sample']
sampling_times = snakemake.params['sampling_times']

graph = demes.load(demes_file)

demography = msprime.Demography.from_demes(graph)
# A single census event at the census time, then re-sort the event list.
demography.add_census(time=census_time)
demography.sort_events()

# sampling
# Every deme is sampled at each sampling time at which it exists: strictly
# after the start of its first epoch (further back in time) and no later
# than the end of its last epoch. It gets one more sample set at the census
# time under the same condition.
samples = []
for d in graph.demes:
    oldest = d.epochs[0].start_time
    youngest = d.epochs[-1].end_time
    for t in sampling_times:
        if youngest <= t < oldest:
            samples.append(
                msprime.SampleSet(n_sample, population=d.name, time=t)
            )
    if youngest <= census_time < oldest:
        samples.append(
            msprime.SampleSet(n_sample, population=d.name, time=census_time)
        )

# Contig setup
# The species object is still created, but the simulation below uses fixed
# rates and a fixed sequence length instead of a stdpopsim contig.
species = stdpopsim.get_species("HomSap")

# Simulation
# Ancestry first, with a uniform recombination rate over a 1e8 bp sequence.
ancestry_options = dict(
    samples=samples,
    ploidy=2,
    sequence_length=1e8,
    recombination_rate=2e-8,
    demography=demography,
)
ts = msprime.sim_ancestry(**ancestry_options)

# Mutations are layered on top of the simulated trees.
ts = msprime.sim_mutations(ts, rate=1e-8)

# drop sites with recurrent mutations
is_recurrent = [len(site.mutations) > 1 for site in ts.sites()]
ts = ts.delete_sites(np.where(is_recurrent)[0])

ts.dump(trees_file)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import tskit
import demes
import pyslim
import msprime
import numpy as np

# Paths and parameters supplied through the `snakemake` object.
in_trees = snakemake.input['trees_file']
demes_file = snakemake.input['demes_file']
out_trees = snakemake.output['trees_file']
mut_rate = snakemake.params['neutral_mut_rate']

# as we use SLiM v3.7
raw_ts = tskit.load(in_trees)
ts = pyslim.update(raw_ts)
graph = demes.load(demes_file)

# Ancestral size for recapitation: starting size of the first epoch of the
# first deme in the graph.
first_deme = graph.demes[0]
Ne = first_deme.epochs[0].start_size

# recap and add neutral mutations
# NOTE(review): no recombination rate is passed to recapitate - confirm this
# is intended for the recapitated part of the trees.
ts_recap = pyslim.recapitate(ts, ancestral_Ne=Ne)
# keep=False: discard existing mutations
ts_mut = msprime.sim_mutations(ts_recap, rate=mut_rate, keep=False)

# Drop the sites that carry more than one mutation.
is_recurrent = [len(site.mutations) > 1 for site in ts_mut.sites()]
recurrent_sites = np.where(is_recurrent)[0]
ts_mut = ts_mut.delete_sites(recurrent_sites)

ts_mut.dump(out_trees)
33
34
script:
	"scripts/prep_scenario.py"
44
45
script:
	"scripts/main_figure.py"
ShowHide 15 more snippets with no or duplicated tags.

Login to post a comment if you would like to share your experience with this workflow.

Do you know this workflow well? If so, you can request seller status , and start supporting this workflow.

Free

Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/alxsimon/admixcov_sims
Name: admixcov_sims
Version: v1.0.2
Badge:
workflow icon

Insert copied code into your website to add a link to this workflow.

Downloaded: 0
Copyright: Public Domain
License: MIT License
  • Future updates

Related Workflows

cellranger-snakemake-gke
snakemake workflow to run cellranger on a given bucket using gke.
A Snakemake workflow for running cellranger on a given bucket using Google Kubernetes Engine. The usage of this workflow ...