Snakemake Workflow for tsABC Method with A. thaliana Data Application


This repository provides a Snakemake workflow to run tsABC, the method introduced in our paper (doi: https://doi.org/10.1101/2022.07.29.502030).

The workflow requires some familiarity with Snakemake (https://snakemake.readthedocs.io/en/stable/).

In principle, the whole pipeline can be run with $ snakemake -j 1. However, we suggest a more careful use, with an exact understanding of each step.

tsABC was first developed to run on A. thaliana; the data application module is therefore named "athal" (module04).

How to run

  • Provide the correct parameters, e.g. mutation rate, recombination rate, and nsim, then run everything.

IMPORTANT

  • The downsampling for model choice does NOT work correctly. There are two ways to handle this: first, run the same number of simulations for both proposed models; second, take care of the downsampling manually in the corresponding R script (or beforehand), as in the sketch below.
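
A minimal sketch of the manual option in Python (run before the model choice step); the feather paths are illustrative and stand for the aggregated sumstat tables of the two proposed models:

import numpy as np
import pandas as pd

# load the aggregated sumstats of both proposed models (illustrative paths)
model_a = pd.read_feather("results/abc/sumstats.model_a.feather")
model_b = pd.read_feather("results/abc/sumstats.model_b.feather")

# downsample the larger table so both models contribute equally many
# simulations to the model choice
rng = np.random.default_rng(42)
n = min(len(model_a), len(model_b))
model_a = model_a.sample(n=n, replace=False, random_state=rng).reset_index(drop=True)
model_b = model_b.sample(n=n, replace=False, random_state=rng).reset_index(drop=True)

model_a.to_feather("results/abc/sumstats.model_a.downsampled.feather")
model_b.to_feather("results/abc/sumstats.model_b.downsampled.feather")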

Code Snippets

script: "../scripts/00.calculate_theta_watterson.py"
script: "../scripts/01.create_abc_simulation_randints.py"
script: "../scripts/01.create_abc_pod_randints.py"
script: "../scripts/01.simulate_treeseqs.py"
script: "../scripts/01.confirm_params_for_loci_being_same_and_save_params.py"
script: "../scripts/01.simulate_treeseq_pods.py"
script: "../scripts/01.generate_discretizing_breakpoints_for_sumstats.py"
script: "../scripts/01.calculate_sumstats.py"
script: "../scripts/01.calculate_masked_sumstats.py"
script: "../scripts/01.aggregate_sumstats.py"
script: "../scripts/01.aggregate_sumstats.py"
script: "../scripts/01.create_abc_simulation_randints.py"
script: "../scripts/01.alternative_model_simulate_treeseqs.py"
script: "../scripts/01.calculate_sumstats.py"
script: "../scripts/01.calculate_masked_sumstats.py"
script: "../scripts/01.confirm_params_for_loci_being_same_and_save_params.py"
script: "../scripts/01.aggregate_sumstats.py"
script: "../scripts/01.aggregate_sumstats.py"
script: "../scripts/01.calculate_podstats.py"
script: "../scripts/01.aggregate_podstats.py"
script: "../scripts/01.calculate_masked_podstats.py"
script: "../scripts/01.aggregate_podstats.py"
script: "../scripts/02.model_choice.R"
script: "../scripts/02.model_choice.R"
script: "../scripts/02.aggregate_model_choice.R"
script: "../scripts/02.aggregate_model_choice.R"
script: "../scripts/03.parameter_estimation.R"
script: "../scripts/03.parameter_estimation.R"
script: "../scripts/03.aggregate_parameter_estimation.R"
script: "../scripts/03.aggregate_parameter_estimation.R"
script: "../scripts/01.create_abc_simulation_randints.py"
script: "../scripts/04.simulate_sixparmodel_treeseqs.py"
script: "../scripts/01.confirm_params_for_loci_being_same_and_save_params.py"
script: "../scripts/01.calculate_sumstats.py"
script: "../scripts/01.calculate_masked_sumstats.py"
script: "../scripts/01.aggregate_sumstats.py"
script: "../scripts/01.aggregate_sumstats.py"
script: "../scripts/04.calculate_athal_observations.py"
script: "../scripts/04.calculate_athal_observations.py"
script: "../scripts/04.calculate_athal_all_region_observations.py"
script: "../scripts/04.calculate_athal_all_region_observations.py"
script: "../scripts/04.aggregate_athal_sumstats.py"
script: "../scripts/04.aggregate_athal_sumstats.py"
script: "../scripts/04.aggregate_athal_sumstats.py"
script: "../scripts/04.aggregate_athal_sumstats.py"
script: "../scripts/04.parameter_estimation.R"
script: "../scripts/04.parameter_estimation.R"
script: "../scripts/04.parameter_estimation.R"
script: "../scripts/04.parameter_estimation.R"
script: "../scripts/04.aggregate_athal_parameter_estimation.R"
script: "../scripts/04.aggregate_athal_parameter_estimation.R"
run:
    sys.exit(
        "#" * 600 + "inside aggregate_athal_parameter_estimation_masked\n" + ""
    )
run:
    sys.exit(
        "#" * 600 + "inside aggregate_athal_parameter_estimation_masked\n" + ""
    )
script: "../scripts/05.visualize_podstats.R"
script: "../scripts/05.visualize_podstats.R"
script: "../scripts/05.visualize_model_choice.R"
script: "../scripts/05.visualize_model_choice.R"
script: "../scripts/05.visualize_parameter_estimation.R"
script: "../scripts/05.visualize_parameter_estimation.R"

../scripts/00.calculate_theta_watterson.py

import sys
import datetime
import numpy as np
import msprime
import tskit
import json
import pickle
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        f"start calculating theta Watterson with mode: {snakemake.wildcards.mode}",
        file=logfile,
    )


# random seed generator subsetting and simulations
rng = np.random.default_rng(snakemake.params.seed)


if snakemake.wildcards.mode == "region":
    # read tree sequence
    treeseq_athal = tskit.load(snakemake.input.treeseq_athal)

    # read sample names of first population
    sample_names = np.loadtxt(snakemake.input.samples, dtype=str)

    # find the node ids for the sample of the population
    population_sample = []
    for individual in treeseq_athal.individuals():
        if str(json.loads(individual.metadata)["id"]) in sample_names:
            population_sample.extend(individual.nodes)

    # sample treeseq to provided samples
    treeseq_athal_population = treeseq_athal.simplify(samples=population_sample)
    del treeseq_athal

    # get the chromosomal regions from the config file
    chromosome_regions = []
    for chromid, (start, stop) in enumerate(
        snakemake.config["ABC"]["athaliana"]["observations"]["treeseq_1001"][
            "chosen_region"
        ],
        start=1,
    ):
        start += snakemake.params.chrom_multiplier * chromid
        stop += snakemake.params.chrom_multiplier * chromid
        chromosome_regions.append((start, stop))

        # log
        with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
            print(datetime.datetime.now(), end="\t", file=logfile)
            print(
                f"chropped region {chromid} of "
                + f"""{len(snakemake.config["ABC"]["athaliana"]["observations"][
                    "treeseq_1001"]["chosen_region"])}""",
                file=logfile,
            )

    # chop down to regions
    treeseq_list = []
    for chromid, (start, stop) in enumerate(chromosome_regions):
        treeseq_list.append(
            treeseq_athal_population.keep_intervals([(start, stop)]).trim()
        )
    del treeseq_athal_population

    # create subsample from treesequence
    specs = {
        "num_observations": int(
            float(
                snakemake.config["ABC"]["athaliana"]["observations"]["num_observations"]
            )
        ),
        "nsam": int(float(snakemake.config["ABC"]["simulations"]["nsam"])),
    }
    tsl = pyfuncs.create_subsets_from_treeseqlist(
        treeseq_list, specs, rng, snakemake.log.log1
    )

elif snakemake.wildcards.mode == "genome":
    # read tree sequence
    treeseq_athal = tskit.load(snakemake.input.treeseq_athal)

    # read sample names of first population
    sample_names = np.loadtxt(snakemake.input.samples, dtype=str)

    # find the node ids for the sample of the population
    population_sample = []
    for individual in treeseq_athal.individuals():
        if str(json.loads(individual.metadata)["id"]) in sample_names:
            population_sample.extend(individual.nodes)

    # sample treeseq to provided samples
    treeseq_athal_population = treeseq_athal.simplify(samples=population_sample)
    del treeseq_athal

    # get the chromosomal regions from the config file
    chromosome_regions = []
    for chromid, start, stop in snakemake.config["ABC"]["athaliana"]["observations"][
        "treeseq_1001"
    ]["whole_gemome_approach"]:
        start += snakemake.params.chrom_multiplier * chromid
        stop += snakemake.params.chrom_multiplier * chromid
        chromosome_regions.append((start, stop))

    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print("prepared regions", file=logfile)

    # chop down to regions
    treeseq_list = []
    for start, stop in chromosome_regions:
        # calculate chromid from position in original treeseq
        chromid = int(start / snakemake.params.chrom_multiplier)

        # add tuple with chromosome/treeseq, as chromid is needed for masking
        treeseq_list.append(
            (chromid, treeseq_athal_population.keep_intervals([(start, stop)]).trim())
        )

        # log
        with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
            print(datetime.datetime.now(), end="\t", file=logfile)
            print(
                f"prepared regions for chromsome {chromid + 1} of {len(chromosome_regions)}",
                file=logfile,
            )

    del treeseq_athal_population

    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(
            "created region treeseq list (1 per region of each chromosome)",
            file=logfile,
        )

    # create subsample from treesequence
    specs = {
        "num_observations": int(
            float(
                snakemake.config["ABC"]["athaliana"]["observations"]["treeseq_1001"][
                    "whole_genome_approach_num_observations"
                ]
            )
        ),
        "nsam": int(float(snakemake.config["ABC"]["simulations"]["nsam"])),
    }

    # remove chromosome from treeseq_list
    treeseq_list = [treeseq for _, treeseq in treeseq_list]

    tsl = pyfuncs.create_subsets_from_treeseqlist(
        treeseq_list, specs, rng, snakemake.log.log1
    )

elif snakemake.wildcards.mode == "pod":
    # Loading list of tree sequences
    with open(snakemake.input.tsl_pod[0], "rb") as tsl_file:
        tsl = np.array(pickle.load(tsl_file), dtype=object)

else:
    assert False, "unknown mode to calclate theta Watterson"


# calculate theta_watterson per treeseq
theta_watterson = []
for treeseq in tsl.flatten():
    theta_watterson.append(
        treeseq.segregating_sites(span_normalise=True)
        / sum([1 / i for i in range(1, treeseq.num_samples)])
    )

theta_watterson = np.array(theta_watterson).mean()

# calculate expected 2N
two_N_zero = round(
    theta_watterson / (2 * float(snakemake.config["ABC"]["simulations"]["mutrate"]))
)


print("\n" + "_" * 80, file=sys.stderr)
print(f"theta_watterson\t{theta_watterson}", file=sys.stderr)
print(f"two_N_zero\t{two_N_zero}", file=sys.stderr)
print("=" * 80 + "\n", file=sys.stderr)

# print to output file
with open(snakemake.output.txt, "w", encoding="utf-8") as outfile:
    print(f"theta_watterson\t\t{theta_watterson}", file=outfile)
    print(f"two_N_zero\t\t{two_N_zero}", file=outfile)


# log the results
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"theta_watterson\t{theta_watterson}", file=logfile)
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"two_N_zero\t{two_N_zero}", file=logfile)

../scripts/01.aggregate_podstats.py

import re
import datetime
import numpy as np
import pandas as pd


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "start aggregating params and sumstats",
        file=logfile,
    )


# obtain podid from filenames
npodid = len(snakemake.params.podid_wc)

# create data list to read and store by locus
read_data_list = [[] for _ in range(npodid)]
for npy in snakemake.input.npys:
    split = re.split(r"_|\.|\/", npy)
    podid = int(split[np.where(np.array(split) == "podid")[0].max() + 1])
    read_data_list[
        np.where(podid == np.array(snakemake.params.podid_wc))[0][0]
    ] = np.load(npy)
read_data = np.array(read_data_list, dtype=float)

with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("worked out sumstats", file=logfile)


# read sumstat names
sumstat_names_list = []
for name_file in snakemake.input.sumstat_names:
    sumstat_names_list.append(np.load(name_file))
    if len(sumstat_names_list) >= 2:
        assert all(
            sumstat_names_list[-2] == sumstat_names_list[-1]
        ), "sumstats are not congruent"
sumstat_names = sumstat_names_list[0]


# read params from config
npodid = len(snakemake.config["ABC"]["performance"]["pods"][0])
param_values = [[] for _ in range(npodid)]
for podid in range(npodid):
    for param in snakemake.config["ABC"]["performance"]["pods"]:
        param_values[podid].append(param[podid])
param_values = np.array(param_values).astype(float)


# concatenate and parse data into pandas dataframe; add the parameters
result_table_list = []
result_podid_list = []
for podid in range(read_data.shape[0]):
    result_table_list.append(pd.DataFrame(data=read_data[podid], columns=sumstat_names))
    result_table_list[-1]["podid"] = podid

    parameter_columns_names = []
    for parid, param in enumerate(param_values[podid]):
        result_table_list[-1][f"param_{parid}"] = param
        parameter_columns_names.append(f"param_{parid}")

    # save pod index to list
    result_podid_list.extend([podid for _ in range(len(result_table_list[-1].index))])

    # move podid and parameter columns to front
    result_table_list[-1] = result_table_list[-1][
        parameter_columns_names
        + [
            col
            for col in result_table_list[-1].columns
            if col not in parameter_columns_names
        ]
    ]


pd.concat(result_table_list, ignore_index=True).to_feather(
    snakemake.output.sumstats, compression="lz4"
)

with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved dataframe to {snakemake.output.sumstats}", file=logfile)


# save podid according to config file
np.savetxt(snakemake.output.podid, np.array(result_podid_list).astype(int), fmt="%s")

with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved podid npy array", file=logfile)

../scripts/01.aggregate_sumstats.py

import numpy as np
import pandas as pd
import re
import datetime

# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "start aggregating params and sumstats",
        file=logfile,
    )


# obtain simid and locid from filenames
nloci = len(snakemake.params.locid_wc)
nsimid = len(snakemake.params.simid_wc)

# create data list to read and store by locus
read_data_list = [[[] for _ in range(nloci)] for _ in range(nsimid)]
for npyid, npy in enumerate(snakemake.input.npys):
    split = re.split(r"_|\.|\/", npy)
    simid = int(split[np.where(np.array(split) == "sim")[0].max() + 1])
    locid = int(split[np.where(np.array(split) == "locus")[0].max() + 1])
    read_data_list[np.where(simid == np.array(snakemake.params.simid_wc))[0][0]][
        locid
    ] = np.load(npy)
read_data = np.array(read_data_list, dtype=float)


# take mean value over the independent loci
read_data = read_data.mean(axis=1)


# reshape the data into a 2d-array, because each input file provides a cluster
# batch of independent simulations
read_data = read_data.reshape(
    (read_data.shape[0] * read_data.shape[1], read_data.shape[2])
)

with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("worked out sumstats", file=logfile)

# read parameters
read_param_list = [[] for _ in range(nsimid)]
for npy in snakemake.input.params:
    split = re.split(r"_|\.|\/", npy)
    simid = int(split[np.where(np.array(split) == "sim")[0].max() + 1])
    read_param_list[
        np.where(simid == np.array(snakemake.params.simid_wc))[0][0]
    ] = np.load(npy)
read_param = np.array(read_param_list)
num_provided_params = read_param.shape[2]
read_param = read_param.reshape(
    (read_param.shape[0] * read_param.shape[1], read_param.shape[2])
)

with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("worked out params", file=logfile)

# concatenate parameters and sumstats
result_table = np.concatenate((read_param, read_data), axis=1)

with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("concatenated params and sumstats", file=logfile)


# read sumstat names
sumstat_names_list = []
for name_file in snakemake.input.sumstat_names:
    sumstat_names_list.append(np.load(name_file))
    if len(sumstat_names_list) >= 2:
        assert all(
            sumstat_names_list[-2] == sumstat_names_list[-1]
        ), "sumstats are not congruent"
sumstat_names = sumstat_names_list[0]

# add param names
column_names = np.concatenate(
    (
        np.array([f"param_{parid}" for parid in range(num_provided_params)]),
        sumstat_names,
    ),
    axis=0,
)

result_table = pd.DataFrame(data=result_table, columns=column_names)
result_table.to_feather(snakemake.output.sumstats, compression="lz4")

with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved dataframe to {snakemake.output.sumstats}", file=logfile)

../scripts/01.alternative_model_simulate_treeseqs.py

import sys
import datetime
import pickle
import numpy as np
import tskit
import msprime
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start simulating tree sequences", file=logfile)


def main(simid):
    """Heart of this script

    The main function will create a single tree sequence. The main function has
    to be repeatedly executed until the list of treesequences has been created.
    """
    # provide the correct seed to create the random number generator
    seed = np.load(snakemake.input.npy)[simid][int(float(snakemake.wildcards.locid))]
    seed_params = np.load(snakemake.input.npy)[simid][
        0
    ]  # the parameters for the independent loci must be the same
    rng = np.random.default_rng(seed)
    rng_params = np.random.default_rng(seed_params)

    # draw the parameters of the model
    params = [
        pyfuncs.draw_parameter_from_prior(prior_definition, rng_params)
        for prior_definition in snakemake.params.model["priors"]
    ]
    del (
        seed_params,
        rng_params,
    )  # only the parameter drawing relies on the same seed, the simulations must be independent

    # simulate
    ts = pyfuncs.simulate_treesequence_under_alternative_model(
        params, snakemake.params, rng, snakemake.log.log1
    )

    return ts, params


# define the loop for the clustering of consecutive simulations
start = int(float(snakemake.wildcards.simid))
end = min(
    start + int(float(snakemake.config["ABC"]["jobluster_simulations"])),
    int(float(snakemake.config["ABC"]["alternative_model"]["nsim"])),
)

assert start < end, "simid (simulation IDs) ill-defined, nothing would be simulated"

# run the simulations
treesequence_list = []
parameters_list = []
for simid in range(start, end):
    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(f"creating treesequence no {simid}", file=logfile)

    treeseq, params = main(simid)
    treesequence_list.append(treeseq)
    parameters_list.append(params)


# save treeseq list
with open(snakemake.output.tsl, "wb") as tsl_file:
    pickle.dump(treesequence_list, tsl_file)


# save parameters list
np.save(snakemake.output.params, np.array(parameters_list))
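
Note the seeding scheme: the parameters of a simulation are always drawn from the seed of locus 0, while the coalescent simulation itself uses the per-locus seed, so the independent loci share parameters but not genealogies. A minimal standalone sketch of the same idea, with randints standing in for the seed matrix loaded from snakemake.input.npy:

import numpy as np

nsim, nloci = 3, 4
randints = np.random.default_rng(0).integers(0, 2**31, size=(nsim, nloci))

for simid in range(nsim):
    # identical parameter draw for every locus of this simulation
    rng_params = np.random.default_rng(randints[simid][0])
    param = rng_params.uniform(0, 1)
    for locid in range(nloci):
        # independent simulation seed per locus
        rng_sim = np.random.default_rng(randints[simid][locid])
        print(simid, locid, round(param, 3), rng_sim.integers(1000))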

../scripts/01.calculate_masked_podstats.py

import sys
import datetime
import warnings
import itertools
import pickle
import re
import math
import numpy as np
import tskit
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start calculating summarizing statistics", file=logfile)


# Loading list of tree sequences
with open(snakemake.input.tsl, "rb") as tsl_file:
    tsl = np.array(pickle.load(tsl_file))


# sample only one haplotype per individual
rng = np.random.default_rng(snakemake.params.seed)
tsl_haploid = np.empty(tsl.shape, dtype=tskit.TreeSequence)
for treeid, treeseq in np.ndenumerate(tsl):
    sample_set = [
        this_individual.nodes[rng.integers(low=0, high=2)]
        for this_individual in treeseq.individuals()
    ]
    tsl_haploid[treeid] = treeseq.simplify(samples=sample_set)
del tsl
tsl = tsl_haploid
del tsl_haploid


# read and prepare mask files
mask = []
for mask_file in snakemake.input.mask:
    # the id of the chromosome from input file
    split = re.split(r"_|\.|\/", mask_file)
    chromid = int(split[np.where(np.array(split) == "locus")[0].max() + 1])

    # load mask file
    this_mask = np.loadtxt(mask_file, dtype=int)
    region_start, region_end = snakemake.config["ABC"]["athaliana"]["observations"][
        "treeseq_1001"
    ]["chosen_region"][chromid]
    region_mask = this_mask[
        (region_start <= this_mask[:, 0]) & (region_end > this_mask[:, 1])
    ]
    this_mask = region_mask - region_start

    # filter mask for disjoint intervals
    this_mask = pyfuncs.filter_mask_for_disjoint_intervals(
        this_mask, log=snakemake.log.log1
    )
    mask.append(this_mask)
    del region_mask, this_mask


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("loaded all masks for pods", file=logfile)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read and prepared the mask file", file=logfile)


tsl_masked = np.empty(tsl.shape, dtype=tskit.TreeSequence)
for treeid, treeseq in np.ndenumerate(tsl):
    chromid = treeid[1]
    tsl_masked[treeid] = treeseq.delete_intervals(
        mask[chromid], simplify=True, record_provenance=True
    )

del tsl
tsl = tsl_masked
del tsl_masked


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read and masked the tree sequences", file=logfile)


# read the breakpoint files
for breaks_filename in snakemake.input.breakpoints:
    if "SFS" in breaks_filename:
        with warnings.catch_warnings():  # will prevent warning if file is empty
            warnings.simplefilter("ignore")
            breaks_sfs = np.loadtxt(breaks_filename, dtype=float)
    elif "LD" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_ld = np.loadtxt(breaks_filename, dtype=float)
    elif "TM_WIN" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_tm_win = np.loadtxt(breaks_filename, dtype=float)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the breaks for binning", file=logfile)


# which sumstats to calculate
listed_sumstats = set(
    itertools.chain.from_iterable(
        [sumstat_set.split("/") for sumstat_set in snakemake.params.sumstats]
    )
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculating summarizing statistics",
        file=logfile,
    )

# calculate the SFS
if "SFS" in listed_sumstats:
    breaks = breaks_sfs

    dataframe_sumstats_sfs = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_sfs[treeid] = treeseq.allele_frequency_spectrum(
            sample_sets=None,
            windows=None,
            mode="site",
            span_normalise=False,
            polarised=snakemake.config["ABC"]["sumstats_specs"]["SFS"]["polarised"],
        )

    # check if the sfs shall be discretized as well; this usually is not
    # necessary as the sfs is a discrete statistic already
    if len(breaks):
        # for each sfs sum up the values in between the bin_edges (=breaks)
        assert (
            len(breaks) >= 2
        ), "expect closed interval for bin_edges; at least 2 breaks must be provided"
        new_sfs_dataframe = []
        for sfs in dataframe_sumstats_sfs:
            sfs_vals = []
            for breakid in range(1, len(breaks)):
                low = int(breaks[breakid - 1])
                high = breaks[breakid]
                if np.isinf(high):
                    high = len(sfs) + 1
                else:
                    high = math.ceil(high)
                sfs_vals.append(sum(sfs[low:high]))
            new_sfs_dataframe.append(sfs_vals)

        # reassign to the dataframe sumstat
        dataframe_sumstats_sfs = np.array(new_sfs_dataframe)

    # get average of sfs over the different loci
    dataframe_sumstats_sfs = dataframe_sumstats_sfs.mean(axis=1)

    dataframe_sumstats_sfs = np.concatenate(dataframe_sumstats_sfs, axis=0).reshape(
        (len(dataframe_sumstats_sfs), dataframe_sumstats_sfs[0].shape[0])
    )
else:
    dataframe_sumstats_sfs = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated SFS statistics",
        file=logfile,
    )


# calculate the LD
if "LD" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["LD"]
    breaks = breaks_ld

    dataframe_sumstats_ld = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_ld[treeid] = pyfuncs.calculate_ld(
            treeseq, specs, breaks, rng, snakemake.log.log1
        )

    # get average of ld over the different loci
    dataframe_sumstats_ld = dataframe_sumstats_ld.mean(axis=1)

    dataframe_sumstats_ld = np.concatenate(dataframe_sumstats_ld, axis=0).reshape(
        (len(dataframe_sumstats_ld), dataframe_sumstats_ld[0].shape[0])
    )
else:
    dataframe_sumstats_ld = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated LD statistics",
        file=logfile,
    )


# calculate the TM_WIN
if "TM_WIN" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]
    breaks = breaks_tm_win

    dataframe_sumstats_tm_win = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_tm_win[treeid] = pyfuncs.calculate_tm_win(
            treeseq, specs, breaks_tm_win
        )

    # get average of tm_win over the different loci
    dataframe_sumstats_tm_win = dataframe_sumstats_tm_win.mean(axis=1)
    dataframe_sumstats_tm_win = np.concatenate(
        dataframe_sumstats_tm_win, axis=0
    ).reshape((len(dataframe_sumstats_tm_win), dataframe_sumstats_tm_win[0].shape[0]))
else:
    dataframe_sumstats_tm_win = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated TM_WIN statistics",
        file=logfile,
    )


# fuse the summarizing stats list into a 2d-np.array, each row containing the
# summarizing stats of a single simulated tree sequence
dataframe_sumstats = np.concatenate(
    [
        npy
        for npy in (
            dataframe_sumstats_sfs,
            dataframe_sumstats_ld,
            dataframe_sumstats_tm_win,
        )
        if npy.size != 0
    ],
    axis=1,
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("calculated summarizing statisitcs", file=logfile)


# save summary statistics
np.save(snakemake.output.npy, dataframe_sumstats)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved sumstats to {snakemake.output.npy}", file=logfile)


# save the names of the sumstats
sumstat_type = (
    ["sfs"] * dataframe_sumstats_sfs.shape[1]
    + ["ld"] * dataframe_sumstats_ld.shape[1]
    + ["tm_win"] * dataframe_sumstats_tm_win.shape[1]
)
sumstat_type_id = [
    str(element)
    for element in list(range(dataframe_sumstats_sfs.shape[1]))
    + list(range(dataframe_sumstats_ld.shape[1]))
    + list(range(dataframe_sumstats_tm_win.shape[1]))
]

sumstat_names = np.array(list(map("_".join, zip(sumstat_type, sumstat_type_id))))


assert (
    len(sumstat_names) == dataframe_sumstats.shape[1]
), "name vector for summary statistics has the wrong dimension"


np.save(snakemake.output.sumstat_count, sumstat_names)

../scripts/01.calculate_masked_sumstats.py

import sys
import datetime
import warnings
import itertools
import pickle
import math
import numpy as np
import tskit
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start calculating summarizing statisticss", file=logfile)


# Loading list of tree sequences
with open(snakemake.input.tsl, "rb") as tsl_file:
    tsl = pickle.load(tsl_file)

# sample only one haplotype per individual
rng = np.random.default_rng(snakemake.params.seed)
tsl_haploid = []
for treeseq in tsl:
    sample_set = [
        this_individual.nodes[rng.integers(low=0, high=2)]
        for this_individual in treeseq.individuals()
    ]
    tsl_haploid.append(treeseq.simplify(samples=sample_set))
del tsl
tsl = tsl_haploid
del tsl_haploid


# read and prepare mask files
mask = np.loadtxt(snakemake.input.mask, dtype=int)
region_start, region_end = snakemake.config["ABC"]["athaliana"]["observations"][
    "treeseq_1001"
]["chosen_region"][int(float(snakemake.wildcards.locid))]
region_mask = mask[(region_start <= mask[:, 0]) & (region_end > mask[:, 1])]
mask = region_mask - region_start
del region_mask

# filter mask for disjoint intervals
mask = pyfuncs.filter_mask_for_disjoint_intervals(mask, log=snakemake.log.log1)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "proportion of exons to mask in chromosome(1-indexed) "
        + f"{int(float(snakemake.wildcards.locid)) + 1}: "
        + f"{(mask[:, 1] - mask[:, 0]).sum()/(region_end-region_start)}",
        file=logfile,
    )


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read and prepared the mask file", file=logfile)


# mask for regions provided
tsl_masked = []
for treeseq in tsl:
    tsl_masked.append(
        treeseq.delete_intervals(mask, simplify=True, record_provenance=True)
    )
tsl = tsl_masked
del tsl_masked


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read and masked the tree sequences", file=logfile)


# read the breakpoint files
for breaks_filename in snakemake.input.breakpoints:
    if "SFS" in breaks_filename:
        with warnings.catch_warnings():  # will prevent warning if file is empty
            warnings.simplefilter("ignore")
            breaks_sfs = np.loadtxt(breaks_filename, dtype=float)
    elif "LD" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_ld = np.loadtxt(breaks_filename, dtype=float)
    elif "TM_WIN" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_tm_win = np.loadtxt(breaks_filename, dtype=float)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the breaks for binning", file=logfile)


# which sumstats to calculate
listed_sumstats = set(
    itertools.chain.from_iterable(
        [sumstat_set.split("/") for sumstat_set in snakemake.params.sumstats]
    )
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculating summarizing statistics",
        file=logfile,
    )

# calculate the SFS
if "SFS" in listed_sumstats:
    breaks = breaks_sfs
    dataframe_sumstats_sfs = np.array(
        [
            treeseq.allele_frequency_spectrum(
                sample_sets=None,
                windows=None,
                mode="site",
                span_normalise=False,
                polarised=snakemake.config["ABC"]["sumstats_specs"]["SFS"]["polarised"],
            )
            for treeseq in tsl
        ]
    )

    # check if the sfs shall be discretized as well; this usually is not
    # necessary as the sfs is a discrete statistic already
    if len(breaks):
        # for each sfs sum up the values in between the bin_edges (=breaks)
        assert (
            len(breaks) >= 2
        ), "expect closed interval for bin_edges; at least 2 breaks must be provided"
        new_sfs_dataframe = []
        for sfs in dataframe_sumstats_sfs:
            sfs_vals = []
            for breakid in range(1, len(breaks)):
                low = int(breaks[breakid - 1])
                high = breaks[breakid]
                if np.isinf(high):
                    high = len(sfs) + 1
                else:
                    high = math.ceil(high)
                sfs_vals.append(sum(sfs[low:high]))
            new_sfs_dataframe.append(sfs_vals)

        # reassign to the dataframe sumstat
        dataframe_sumstats_sfs = np.array(new_sfs_dataframe)
else:
    dataframe_sumstats_sfs = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated SFS statistics",
        file=logfile,
    )


# calculate the LD
if "LD" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["LD"]
    breaks = breaks_ld

    dataframe_sumstats_ld = np.empty(np.array(tsl).shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_ld[treeid] = pyfuncs.calculate_ld(
            treeseq, specs, breaks, rng, snakemake.log.log1
        )

    dataframe_sumstats_ld = np.concatenate(dataframe_sumstats_ld, axis=0).reshape(
        (len(dataframe_sumstats_ld), dataframe_sumstats_ld[0].shape[0])
    )
else:
    dataframe_sumstats_ld = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated LD statistics",
        file=logfile,
    )


# calculate the TM_WIN
if "TM_WIN" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]
    breaks = breaks_tm_win
    dataframe_sumstats_tm_win = np.array(
        [pyfuncs.calculate_tm_win(treeseq, specs, breaks_tm_win) for treeseq in tsl]
    )
else:
    dataframe_sumstats_tm_win = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated TM_WIN statistics",
        file=logfile,
    )


# fuse the summarizing stats list into a 2d-np.array, each row containing the
# summarizing stats of a single simulated tree sequence
dataframe_sumstats = np.concatenate(
    [
        npy
        for npy in (
            dataframe_sumstats_sfs,
            dataframe_sumstats_ld,
            dataframe_sumstats_tm_win,
        )
        if npy.size != 0
    ],
    axis=1,
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("calculated summarizing statisitcs", file=logfile)


# save summary statistics
np.save(snakemake.output.npy, dataframe_sumstats)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved sumstats to {snakemake.output.npy}", file=logfile)


# save the names of the sumstats
sumstat_type = (
    ["sfs"] * dataframe_sumstats_sfs.shape[1]
    + ["ld"] * dataframe_sumstats_ld.shape[1]
    + ["tm_win"] * dataframe_sumstats_tm_win.shape[1]
)
sumstat_type_id = [
    str(element)
    for element in list(range(dataframe_sumstats_sfs.shape[1]))
    + list(range(dataframe_sumstats_ld.shape[1]))
    + list(range(dataframe_sumstats_tm_win.shape[1]))
]

sumstat_names = np.array(list(map("_".join, zip(sumstat_type, sumstat_type_id))))


assert (
    len(sumstat_names) == dataframe_sumstats.shape[1]
), "name vector for summary statistics has the wrong dimension"


np.save(snakemake.output.sumstat_count, sumstat_names)
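
The masking step itself is a single tskit operation; a minimal standalone sketch on a toy msprime simulation (interval coordinates are illustrative):

import msprime

ts = msprime.sim_ancestry(
    samples=5, sequence_length=100_000, recombination_rate=1e-8, random_seed=1
)
ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)

# drop masked (e.g. exonic) intervals before computing sumstats
mask = [(10_000, 20_000), (50_000, 60_000)]
ts_masked = ts.delete_intervals(mask, simplify=True)
print(ts.num_sites, "->", ts_masked.num_sites)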

../scripts/01.calculate_podstats.py

import sys
import datetime
import warnings
import itertools
import pickle
import math
import numpy as np
import tskit
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start calculating summarizing statistics", file=logfile)


# Loading list of tree sequences
with open(snakemake.input.tsl, "rb") as tsl_file:
    tsl = np.array(pickle.load(tsl_file))


# sample only one haplotype per individual
rng = np.random.default_rng(snakemake.params.seed)
tsl_haploid = np.empty(tsl.shape, dtype=tskit.TreeSequence)
for treeid, treeseq in np.ndenumerate(tsl):
    sample_set = [
        this_individual.nodes[rng.integers(low=0, high=2)]
        for this_individual in treeseq.individuals()
    ]
    tsl_haploid[treeid] = treeseq.simplify(samples=sample_set)
del tsl
tsl = tsl_haploid
del tsl_haploid


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the tree sequences", file=logfile)


# read the breakpoint files
for breaks_filename in snakemake.input.breakpoints:
    if "SFS" in breaks_filename:
        with warnings.catch_warnings():  # will prevent warning if file is empty
            warnings.simplefilter("ignore")
            breaks_sfs = np.loadtxt(breaks_filename, dtype=float)
    elif "LD" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_ld = np.loadtxt(breaks_filename, dtype=float)
    elif "TM_WIN" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_tm_win = np.loadtxt(breaks_filename, dtype=float)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the breaks for binning", file=logfile)


# which sumstats to calculate
listed_sumstats = set(
    itertools.chain.from_iterable(
        [sumstat_set.split("/") for sumstat_set in snakemake.params.sumstats]
    )
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculating summarizing statistics",
        file=logfile,
    )

# calculate the SFS
if "SFS" in listed_sumstats:
    breaks = breaks_sfs

    dataframe_sumstats_sfs = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_sfs[treeid] = treeseq.allele_frequency_spectrum(
            sample_sets=None,
            windows=None,
            mode="site",
            span_normalise=False,
            polarised=snakemake.config["ABC"]["sumstats_specs"]["SFS"]["polarised"],
        )

    # check if the sfs shall be discretized as well; this usually is not
    # necessary as the sfs is a discrete statistic already
    if len(breaks):
        # for each sfs sum up the values in between the bin_edges (=breaks)
        assert (
            len(breaks) >= 2
        ), "expect closed interval for bin_edges; at least 2 breaks must be provided"
        new_sfs_dataframe = []
        for sfs in dataframe_sumstats_sfs:
            sfs_vals = []
            for breakid in range(1, len(breaks)):
                low = int(breaks[breakid - 1])
                high = breaks[breakid]
                if np.isinf(high):
                    high = len(sfs) + 1
                else:
                    high = math.ceil(high)
                sfs_vals.append(sum(sfs[low:high]))
            new_sfs_dataframe.append(sfs_vals)

        # reassign to the dataframe sumstat
        dataframe_sumstats_sfs = np.array(new_sfs_dataframe)

    # get average of sfs over the different loci
    dataframe_sumstats_sfs = dataframe_sumstats_sfs.mean(axis=1)

    dataframe_sumstats_sfs = np.concatenate(dataframe_sumstats_sfs, axis=0).reshape(
        (len(dataframe_sumstats_sfs), dataframe_sumstats_sfs[0].shape[0])
    )
else:
    dataframe_sumstats_sfs = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated SFS statistics",
        file=logfile,
    )


# calculate the LD
if "LD" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["LD"]
    breaks = breaks_ld

    dataframe_sumstats_ld = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_ld[treeid] = pyfuncs.calculate_ld(
            treeseq, specs, breaks, rng, snakemake.log.log1
        )

    # get average of ld over the different loci
    dataframe_sumstats_ld = dataframe_sumstats_ld.mean(axis=1)

    dataframe_sumstats_ld = np.concatenate(dataframe_sumstats_ld, axis=0).reshape(
        (len(dataframe_sumstats_ld), dataframe_sumstats_ld[0].shape[0])
    )
else:
    dataframe_sumstats_ld = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated LD statistics",
        file=logfile,
    )


# calculate the TM_WIN
if "TM_WIN" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]
    breaks = breaks_tm_win

    dataframe_sumstats_tm_win = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_tm_win[treeid] = pyfuncs.calculate_tm_win(
            treeseq, specs, breaks_tm_win
        )

    # get average of tm_win over the different loci
    dataframe_sumstats_tm_win = dataframe_sumstats_tm_win.mean(axis=1)
    dataframe_sumstats_tm_win = np.concatenate(
        dataframe_sumstats_tm_win, axis=0
    ).reshape((len(dataframe_sumstats_tm_win), dataframe_sumstats_tm_win[0].shape[0]))
else:
    dataframe_sumstats_tm_win = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated TM_WIN statistics",
        file=logfile,
    )


# fuse the summarizing stats list into a 2d-np.array, each row containing the
# summarizing stats of a single simulated tree sequence
dataframe_sumstats = np.concatenate(
    [
        npy
        for npy in (
            dataframe_sumstats_sfs,
            dataframe_sumstats_ld,
            dataframe_sumstats_tm_win,
        )
        if npy.size != 0
    ],
    axis=1,
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("calculated summarizing statisitcs", file=logfile)


# save summary statistics
np.save(snakemake.output.npy, dataframe_sumstats)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved sumstats to {snakemake.output.npy}", file=logfile)


# save the names of the sumstats
sumstat_type = (
    ["sfs"] * dataframe_sumstats_sfs.shape[1]
    + ["ld"] * dataframe_sumstats_ld.shape[1]
    + ["tm_win"] * dataframe_sumstats_tm_win.shape[1]
)
sumstat_type_id = [
    str(element)
    for element in list(range(dataframe_sumstats_sfs.shape[1]))
    + list(range(dataframe_sumstats_ld.shape[1]))
    + list(range(dataframe_sumstats_tm_win.shape[1]))
]

sumstat_names = np.array(list(map("_".join, zip(sumstat_type, sumstat_type_id))))


assert (
    len(sumstat_names) == dataframe_sumstats.shape[1]
), "name vector for summary statistics has the wrong dimension"


np.save(snakemake.output.sumstat_count, sumstat_names)

../scripts/01.calculate_sumstats.py

import sys
import datetime
import warnings
import itertools
import pickle
import math
import numpy as np
import tskit
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start calculating summarizing statisticss", file=logfile)


# Loading list of tree sequences
with open(snakemake.input.tsl, "rb") as tsl_file:
    tsl = pickle.load(tsl_file)

# sample only one haplotype per individual
rng = np.random.default_rng(snakemake.params.seed)
tsl_haploid = []
for treeseq in tsl:
    sample_set = [
        this_individual.nodes[rng.integers(low=0, high=2)]
        for this_individual in treeseq.individuals()
    ]
    tsl_haploid.append(treeseq.simplify(samples=sample_set))
del tsl
tsl = tsl_haploid
del tsl_haploid


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the tree sequences", file=logfile)


# read the breakpoint files
for breaks_filename in snakemake.input.breakpoints:
    if "SFS" in breaks_filename:
        with warnings.catch_warnings():  # will prevent warning if file is empty
            warnings.simplefilter("ignore")
            breaks_sfs = np.loadtxt(breaks_filename, dtype=float)
    elif "LD" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_ld = np.loadtxt(breaks_filename, dtype=float)
    elif "TM_WIN" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_tm_win = np.loadtxt(breaks_filename, dtype=float)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the breaks for binning", file=logfile)


# which sumstats to calculate
listed_sumstats = set(
    itertools.chain.from_iterable(
        [sumstat_set.split("/") for sumstat_set in snakemake.params.sumstats]
    )
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculating summarizing statistics",
        file=logfile,
    )

# calculate the SFS
if "SFS" in listed_sumstats:
    breaks = breaks_sfs
    dataframe_sumstats_sfs = np.array(
        [
            treeseq.allele_frequency_spectrum(
                sample_sets=None,
                windows=None,
                mode="site",
                span_normalise=False,
                polarised=snakemake.config["ABC"]["sumstats_specs"]["SFS"]["polarised"],
            )
            for treeseq in tsl
        ]
    )

    # check if the sfs shall be discretized as well; this usually is not
    # necessary as the sfs is a discrete statistic already
    if len(breaks):
        # for each sfs sum up the values in between the bin_edges (=breaks)
        assert (
            len(breaks) >= 2
        ), "expect closed interval for bin_edges; at least 2 breaks must be provided"
        new_sfs_dataframe = []
        for sfs in dataframe_sumstats_sfs:
            sfs_vals = []
            for breakid in range(1, len(breaks)):
                low = int(breaks[breakid - 1])
                high = breaks[breakid]
                if np.isinf(high):
                    high = len(sfs) + 1
                else:
                    high = math.ceil(high)
                sfs_vals.append(sum(sfs[low:high]))
            new_sfs_dataframe.append(sfs_vals)

        # reassign to the dataframe sumstat
        dataframe_sumstats_sfs = np.array(new_sfs_dataframe)
else:
    dataframe_sumstats_sfs = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated SFS statistics",
        file=logfile,
    )


# calculate the LD
if "LD" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["LD"]
    breaks = breaks_ld

    dataframe_sumstats_ld = np.empty(np.array(tsl).shape, dtype=tskit.TreeSequence)
    for treeid, treeseq in np.ndenumerate(tsl):
        dataframe_sumstats_ld[treeid] = pyfuncs.calculate_ld(
            treeseq, specs, breaks, rng, snakemake.log.log1
        )

    dataframe_sumstats_ld = np.concatenate(dataframe_sumstats_ld, axis=0).reshape(
        (len(dataframe_sumstats_ld), dataframe_sumstats_ld[0].shape[0])
    )
else:
    dataframe_sumstats_ld = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated LD statistics",
        file=logfile,
    )


# calculate the TM_WIN
if "TM_WIN" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]
    breaks = breaks_tm_win
    dataframe_sumstats_tm_win = np.array(
        [pyfuncs.calculate_tm_win(treeseq, specs, breaks_tm_win) for treeseq in tsl]
    )
else:
    dataframe_sumstats_tm_win = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated TM_WIN statistics",
        file=logfile,
    )


# fuse the summarizing stats list into a 2d-np.array, each row containing the
# summarizing stats of a single simulated tree sequence
dataframe_sumstats = np.concatenate(
    [
        npy
        for npy in (
            dataframe_sumstats_sfs,
            dataframe_sumstats_ld,
            dataframe_sumstats_tm_win,
        )
        if npy.size != 0
    ],
    axis=1,
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("calculated summarizing statisitcs", file=logfile)


# save the fused summary statistics array
np.save(snakemake.output.npy, dataframe_sumstats)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved sumstats to {snakemake.output.npy}", file=logfile)


# save the names of the sumstats
sumstat_type = (
    ["sfs"] * dataframe_sumstats_sfs.shape[1]
    + ["ld"] * dataframe_sumstats_ld.shape[1]
    + ["tm_win"] * dataframe_sumstats_tm_win.shape[1]
)
sumstat_type_id = [
    str(element)
    for element in list(range(dataframe_sumstats_sfs.shape[1]))
    + list(range(dataframe_sumstats_ld.shape[1]))
    + list(range(dataframe_sumstats_tm_win.shape[1]))
]

sumstat_names = np.array(list(map("_".join, zip(sumstat_type, sumstat_type_id))))


assert (
    len(sumstat_names) == dataframe_sumstats.shape[1]
), "name vector for summary statistics has the wrong dimension"


np.save(snakemake.output.sumstat_count, sumstat_names)
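
The binning loop above sums SFS entries between consecutive break values, treating an infinite last break as open-ended. A minimal, self-contained sketch of the same logic with hypothetical toy values:

import math
import numpy as np

sfs = np.array([0, 12, 7, 4, 3, 2, 1])       # toy unfolded SFS
breaks = np.array([0.0, 2.0, 4.0, np.inf])   # bin edges, closed interval

binned = []
for breakid in range(1, len(breaks)):
    low = int(breaks[breakid - 1])
    high = breaks[breakid]
    high = len(sfs) + 1 if np.isinf(high) else math.ceil(high)
    binned.append(sfs[low:high].sum())

print(binned)  # [12, 11, 6]: entries 0-1, 2-3, and 4 to the end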
import datetime
import numpy as np


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "start confirming the parameters of the independent loci being equal",
        file=logfile,
    )


# read params
params_per_locus = [np.load(params_file) for params_file in snakemake.input.params]
LOCUS_INDEX = 0
for param_index in range(1, len(params_per_locus)):
    params1, params2 = params_per_locus[param_index - 1 : param_index + 1]
    assert (
        params1 == params2
    ).all(), "Parameters for the different loci are not the same!"
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(f"{LOCUS_INDEX} ?== {LOCUS_INDEX+1}: True", file=logfile)
        LOCUS_INDEX += 1


with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "confirmed the parameters of the independent loci being equal",
        file=logfile,
    )


# save one copy of the params into a .npy file
np.save(
    snakemake.output.params,
    # params1 holds the second-to-last array from the read files; since all
    # of them were checked for equality, it does not matter which one is saved
    params1,
)


with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved params to .npy file: {snakemake.output.params}", file=logfile)
import datetime
import numpy as np


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start creating random seeds for simulations", file=logfile)


# create new random seed generator
rng = np.random.default_rng(snakemake.params.seed)

# draw seeds and save as .npy to file
np.save(
    snakemake.output.npy,
    rng.integers(low=0, high=np.iinfo(int).max, size=snakemake.params.nseed),
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("created random seeds for simulations", file=logfile)
import sys
import datetime
import itertools
import numpy as np
import warnings
import pickle
import tskit
import json
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "start creating breaks for discretization of summarizing statistics",
        file=logfile,
    )


# get unique sumstats from sumstat sets
unique_sumstats = set(
    itertools.chain.from_iterable(
        [sumstat_set.split("/") for sumstat_set in snakemake.config["ABC"]["sumstats"]]
    )
)


# random number generator to choose the haplotype per individual and subsampling
rng = np.random.default_rng(snakemake.params.seed)


# loop through sumstats
for sumstat in unique_sumstats:
    breaks = None
    if sumstat == "SFS":
        breaks = np.empty(shape=0)  # we do not discretize SFS
    elif sumstat == "LD":
        breaks = np.array(
            [
                float(this_break)
                for this_break in snakemake.config["ABC"]["sumstats_specs"]["LD"][
                    "breaks"
                ]
            ]
        )
    elif sumstat == "TM_WIN":
        if snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]["breaks_mode"] == "pod":
            # Loading list of tree sequences
            with open(snakemake.input.tsl[0], "rb") as tsl_file:
                tsl = pickle.load(tsl_file)

            # calculate the data-based breakpoints
            breaks = pyfuncs.find_breakpoints_for_TM_WIN(
                tsl=tsl,
                specs=snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"],
                rng=rng,
                log=snakemake.log.log1,
            )
        elif (
            snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]["breaks_mode"]
            == "athal"
        ):
            # read tree sequence
            treeseq_athal = tskit.load(snakemake.input.tsl[0])

            # read sample names of first population
            sample_names = np.loadtxt(snakemake.input.tsl[1], dtype=str)

            # find the node ids for the sample of the population
            population_sample = []
            for individual in treeseq_athal.individuals():
                if str(json.loads(individual.metadata)["id"]) in sample_names:
                    population_sample.extend(individual.nodes)

            # sample treeseq to provided samples
            treeseq_athal_population = treeseq_athal.simplify(samples=population_sample)
            del treeseq_athal

            # get the chromosomal regions from the config file
            chromosome_regions = []
            for chromid, (start, stop) in enumerate(
                snakemake.config["ABC"]["athaliana"]["observations"]["treeseq_1001"][
                    "chosen_region"
                ],
                start=1,
            ):
                start += snakemake.params.chrom_multiplier * chromid
                stop += snakemake.params.chrom_multiplier * chromid
                chromosome_regions.append((start, stop))

            # chop down to regions
            treeseq_list = []
            for chromid, (start, stop) in enumerate(chromosome_regions):
                treeseq_list.append(
                    treeseq_athal_population.keep_intervals([(start, stop)]).trim()
                )
            del treeseq_athal_population

            # create subsample from treesequence
            specs = {
                "num_observations": int(
                    float(
                        snakemake.config["ABC"]["athaliana"]["observations"][
                            "num_observations"
                        ]
                    )
                ),
                "nsam": int(float(snakemake.config["ABC"]["simulations"]["nsam"])),
            }
            tsl = pyfuncs.create_subsets_from_treeseqlist(
                treeseq_list, specs, rng, snakemake.log.log1
            )

            # calculate the data-based breakpoints
            breaks = pyfuncs.find_breakpoints_for_TM_WIN(
                tsl=tsl,
                specs=snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"],
                rng=rng,
                log=snakemake.log.log1,
            )
        elif (
            snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]["breaks_mode"]
            == "expected"
        ):
            discretized_times = pyfuncs.discretized_times(
                n=snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]["classes"], M=2
            )
            mutrate = float(snakemake.config["ABC"]["simulations"]["mutrate"])
            two_N_zero = int(
                float(
                    snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"][
                        "two_N_zero_if_expected"
                    ]
                )
            )
            window_size = int(
                float(snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]["winsize"])
            )
            breaks = pyfuncs.snp_freq_from_times(
                discretized_times, two_N_zero, mutrate, window_size
            )
        else:
            sys.exit(
                "#" * 600
                + " inside generate_discretizing_breakpoints_for_sumstats\n"
                + f'your breaks_mode is misconfigured: {snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]["breaks_mode"]}\n'
                + "YOU SHOULD NEVER REACH HERE!"
            )

    # test if they belong to the sumstats that are implemented; may need to
    # be co-checked with the config check rule in module_00
    if breaks is None:
        warnings.warn(
            "".join(
                [
                    "some sumstats are not taken into account",
                    " for the creation of discretization breaks",
                ]
            )
        )
        # skip writing a breaks file for unimplemented sumstats; np.savetxt
        # would fail on None
        continue

    outfile_name = f"resources/discretization/sumstat_{sumstat}.npytxt"
    np.savetxt(outfile_name, breaks)

    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(f"created breaks for discretization of {sumstat}", file=logfile)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("created breaks for discretization of summarizing statistics", file=logfile)
import sys
import datetime
import pickle
import numpy as np
import tskit
import msprime
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start simulating tree sequences", file=logfile)


def main(simid):
    """Heart of this script

    The main function will create a single tree sequence. The main function has
    to be repeatedly executed until the list of treesequences has been created.
    """
    # provide the correct seed to create the random number generator
    seed = np.load(snakemake.input.npy)[
        simid[0],
        simid[1],
        int(float(snakemake.wildcards.podid)),
    ]
    rng = np.random.default_rng(seed)

    # draw the parameters of the model
    params = (
        int(
            float(snakemake.params["params"][0][int(float(snakemake.wildcards.podid))])
        ),
        float(snakemake.params["params"][1][int(float(snakemake.wildcards.podid))]),
        float(snakemake.params["params"][2][int(float(snakemake.wildcards.podid))]),
        int(
            float(snakemake.params["params"][3][int(float(snakemake.wildcards.podid))])
        ),
    )

    # simulate
    ts = pyfuncs.simulate_treesequence_under_model(
        params, snakemake.params, rng, snakemake.log.log1
    )

    return ts


# run the simulations
pod_list_of_treesequence_list_per_locus = []
for repid in range(snakemake.params.nsim):
    treesequence_list = []
    for locid in range(snakemake.params.nloci):
        # log
        with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
            print(datetime.datetime.now(), end="\t", file=logfile)
            print(
                f"creating treeseq for pod {snakemake.wildcards.podid} locus {locid} of rep {repid}",
                file=logfile,
            )

        treeseq = main((repid, locid))
        treesequence_list.append(treeseq)
    pod_list_of_treesequence_list_per_locus.append(treesequence_list)


# save treeseq list
with open(snakemake.output.tsl, "wb") as tsl_file:
    pickle.dump(pod_list_of_treesequence_list_per_locus, tsl_file)
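
pyfuncs.simulate_treesequence_under_model belongs to the workflow's own helper module; as a rough sketch only, a helper like this typically wraps msprime's two-step ancestry-plus-mutations API (hypothetical parameter values, not the tsABC model itself):

import msprime

def simulate_one_locus(pop_size, mutrate, recrate, seqlen, seed):
    # simulate the genealogy, then overlay mutations on it
    ts = msprime.sim_ancestry(
        samples=10,
        population_size=pop_size,
        recombination_rate=recrate,
        sequence_length=seqlen,
        random_seed=seed,
    )
    return msprime.sim_mutations(ts, rate=mutrate, random_seed=seed)

ts = simulate_one_locus(1e4, 7e-9, 3.5e-9, 1e6, seed=42)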
import sys
import datetime
import pickle
import numpy as np
import tskit
import msprime
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start simulating tree sequences", file=logfile)


def main(simid):
    """Heart of this script

    The main function will create a single tree sequence. The main function has
    to be repeatedly executed until the list of treesequences has been created.
    """
    # provide the correct seed to create the random number generator
    seed = np.load(snakemake.input.npy)[simid][int(float(snakemake.wildcards.locid))]
    seed_params = np.load(snakemake.input.npy)[simid][
        0
    ]  # the parameters for the independent loci must be the same
    rng = np.random.default_rng(seed)
    rng_params = np.random.default_rng(seed_params)

    # draw the parameters of the model
    params = [
        pyfuncs.draw_parameter_from_prior(prior_definition, rng_params)
        for prior_definition in snakemake.params.model["priors"]
    ]
    del (
        seed_params,
        rng_params,
    )  # only the parameter drawing relies on the same seed, the simulations must be independent

    # simulate
    ts = pyfuncs.simulate_treesequence_under_model(
        params, snakemake.params, rng, snakemake.log.log1
    )

    return ts, params


# define the loop for the clustering of consequential simulations
start = int(float(snakemake.wildcards.simid))
end = min(
    start + int(float(snakemake.config["ABC"]["jobluster_simulations"])),
    int(float(snakemake.config["ABC"]["simulations"]["nsim"])),
)

assert start < end, "simid (Simulation IDs) maldefined, you will not simulate anything"

# run the simulations
treesequence_list = []
parameters_list = []
for simid in range(start, end):
    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(f"creating treesequence no {simid}", file=logfile)

    treeseq, params = main(simid)
    treesequence_list.append(treeseq)
    parameters_list.append(params)


# save treeseq list
with open(snakemake.output.tsl, "wb") as tsl_file:
    pickle.dump(treesequence_list, tsl_file)


# save parameters list
np.save(snakemake.output.params, np.array(parameters_list))
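
The start/end arithmetic above batches consecutive simulation IDs into one Snakemake job; a small sketch of the chunking with hypothetical numbers:

nsim = 100    # ABC.simulations.nsim
cluster = 30  # ABC.jobluster_simulations

for start in range(0, nsim, cluster):
    end = min(start + cluster, nsim)
    print(f"one job simulates simids {start}..{end - 1}")
# -> 0..29, 30..59, 60..89, 90..99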
LOG <- snakemake@log$log1
cat("creating new log file\n", file = LOG, append = F)


# read data
bayes_factors <- list()
for (infileid in 1:length(snakemake@input$bayes_factors)) {
  infile.name <- snakemake@input$bayes_factors[infileid]
  bayes_factors[[infileid]] <- readRDS(infile.name)
  split <-
    strsplit(infile.name, split = "_|\\.|/")[[1]]  # read wc from filename
  bayes_factors[[infileid]]$statcomposition <-
    rep(as.numeric(split[which(split == "statcomp") + 1]), length(bayes_factors[[infileid]]$podid))
  bayes_factors[[infileid]]$pls <-
    rep(as.numeric(split[which(split == "pls") + 1]), length(bayes_factors[[infileid]]$podid))
  bayes_factors[[infileid]]$tolid <-
    rep(as.numeric(split[which(split == "tolid") + 1]), length(bayes_factors[[infileid]]$podid))
}
bayes_factors <- do.call(rbind.data.frame, bayes_factors)


# log
cat("aggregated data\n",
    file = LOG,
    append = T)


# save file
saveRDS(bayes_factors, snakemake@output$bayes_factors)


# log
cat(
  "saved RDS",
  snakemake@output$bayes_factors,
  "\n",
  file = LOG,
  append = T
)
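
Several aggregation scripts recover wildcard values by splitting the file path on '_', '.', and '/' and taking the token after a keyword; the same parsing in a short Python sketch (hypothetical path):

import re

path = "results/modelchoice_statcomp_2_pls_10_tolid_500.rds"
tokens = re.split(r"_|\.|/", path)
tolid = int(tokens[tokens.index("tolid") + 1])  # -> 500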
library(abc)
library(tidyr)
library(pbapply)

# save.image(file = "rdev.RData")
# stop("saved rdev.RData")
# setwd(
#   "/Users/struett/MPIPZ/netscratch-2/dep_tsiantis/grp_laurent/struett/git_tsABC/tsABC/"
# )
# load("rdev.RData")

# perform a model choice between two proposed models

# read data
sumstats <- read.csv(snakemake@input$sumstats, sep = "\t")
sumstats <-
  sumstats[, grep("LinearCombination", colnames(sumstats))]
alternative_sumstats <-
  read.csv(snakemake@input$alternative_sumstats, sep = "\t")
alternative_sumstats <-
  alternative_sumstats[, grep("LinearCombination", colnames(alternative_sumstats))]
podstats <- read.csv(snakemake@input$podstats, sep = "\t")
podparam <- podstats[, grep("param", colnames(podstats))]
podstats <-
  podstats[, grep("LinearCombination", colnames(podstats))]


# make podindex being same as in config file
podconfig <-
  matrix(
    data = as.numeric(snakemake@config$ABC$performance$pods),
    ncol = ncol(podparam)
  )
podindex <- apply(podparam, 1, function(x) {
  this_podindex <- numeric()
  for (i in 1:nrow(podconfig)) {
    if (all(podconfig[i, ] == x)) {
      this_podindex <- c(this_podindex, i)
    }
  }
  return(this_podindex)
})

# test if it is consistent with provided podid
podid <- read.table(snakemake@input$podid, header = F)[[1]]
stopifnot(all(podid+1 == podindex))

# read parameters for inference
ntolerated <- as.numeric(snakemake@wildcards$tolid)
npls <- as.numeric(snakemake@wildcards$plsid)


# prepare data
sumstats$model = "transition"
alternative_sumstats$model = "constant_selfing"
sumstats_model_choice <-
  rbind.data.frame(sumstats, alternative_sumstats)
model_indices <- sumstats_model_choice$model
sumstats_model_choice$model <- NULL
sumstats_model_choice <-
  sumstats_model_choice[, which(colnames(sumstats_model_choice) %in% paste("LinearCombination", 1:npls -
                                                                             1, sep = "_"))]
podstats <-
  podstats[, which(colnames(podstats) %in% paste("LinearCombination", 1:npls - 1, sep = "_"))]


model_choice_result <- pbapply(podstats, 1, function(x) {
  a <- postpr(
    x,
    model_indices,
    sumstats_model_choice,
    tol = ntolerated / nrow(sumstats),
    method = "mnlogistic",
    corr = TRUE   # corr seems not to work
  )
  return(summary(a, rejection = T, print = F))
})

bayes_rejection <- numeric(length = length(model_choice_result))
bayes_mnlogistic <- numeric(length = length(model_choice_result))
for (i in 1:length(model_choice_result)) {
  bayes_rejection[i] <-
    model_choice_result[[i]]$rejection$BayesF["transition", "constant_selfing"]
  bayes_mnlogistic[i] <-
    model_choice_result[[i]]$mnlogistic$BayesF["transition", "constant_selfing"]
}
bayes_results <- list("podid" = podindex,
                      "rejection" = bayes_rejection,
                      "mnlogistic" = bayes_mnlogistic)

# save Bayes Factors RDS
saveRDS(bayes_results, file = snakemake@output$bayes_factors)


# plot
pdf(snakemake@output$bayes_plot)
plot(podindex, bayes_rejection, log = "y")
plot(podindex, bayes_mnlogistic, log = "y")
dev.off()
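
postpr's rejection step amounts to keeping the ntolerated simulations closest to the observed statistics and counting how often each model occurs among them; the Bayes factor is the ratio of those counts. A hedged sketch of that idea with toy data (plain Euclidean distance; the abc package additionally normalises each statistic):

import numpy as np

rng = np.random.default_rng(1)
sumstats = rng.normal(size=(1000, 5))  # pooled simulations of both models
models = np.array(["transition"] * 500 + ["constant_selfing"] * 500)
target = rng.normal(size=5)            # one pseudo-observed dataset

ntolerated = 100
dist = np.linalg.norm(sumstats - target, axis=1)
kept = models[np.argsort(dist)[:ntolerated]]

n_trans = (kept == "transition").sum()
n_const = (kept == "constant_selfing").sum()
bayes_factor = n_trans / n_const  # toy estimate; undefined if n_const == 0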
library(tidyverse)
library(modeest)


# logfile
LOG <- snakemake@log$log1
cat("creating new log file\n", file = LOG, append = F)


# read in files
max_tolid <- 0
parameter.estimations <- list()
for (infileid in 1:length(snakemake@input$estims)) {
  infile.name <- snakemake@input$estims[infileid]
  parameter.estimations[[infileid]] <- readRDS(infile.name)
  parameter.estimations[[infileid]]$filename <- infile.name
  split <-
    strsplit(infile.name, split = "_|\\.|/")[[1]]  # read wc from filename
  parameter.estimations[[infileid]]$plsid <-
    as.numeric(split[which(split == "pls") + 1])
  parameter.estimations[[infileid]]$tolid <-
    as.numeric(split[which(split == "tolid") + 1])
  max_tolid <-
    max(c(max_tolid, as.numeric(split[which(split == "tolid") + 1])))
  parameter.estimations[[infileid]]$statcomposition <-
    as.numeric(split[which(split == "statcomp") + 1])
}


# log
cat("read data\n",
    file = LOG,
    append = T)

# loop through all estimates
df_collector <- list()
for (parestid in 1:length(parameter.estimations)) {
  parest <- parameter.estimations[[parestid]]

  # get values that are not posteriors; afterwards we can loop through the
  # posteriors
  parest.filename <- parest$filename
  parest.plsid <- parest$plsid
  parest.tolid <- parest$tolid
  parest.statcomposition <- parest$statcomposition

  # remove from list; afterwards we can loop through
  parest$filename <- NULL
  parest$plsid <- NULL
  parest$tolid <- NULL
  parest$statcomposition <- NULL
  parest$prior <- NULL


  # extract true params
  for (parest.id in 1:length(parest)) {
    true_params <- parest[[parest.id]]$true_params
    podid <- parest[[parest.id]]$podid

    # loop through parameters and extract the posterior for each parameter
    for (param.name in colnames(true_params)) {
      post.rej <- parest[[parest.id]]$rej[, param.name]
      post.adj <- parest[[parest.id]]$adj[, param.name]

      if (parest.tolid != length(post.rej)) {
        post.rej <- c(post.rej, rep(NA, parest.tolid - length(post.rej)))
      }
      stopifnot(parest.tolid == length(post.rej))

      if (parest.tolid != length(post.adj)) {
        post.adj <- c(post.adj, rep(NA, parest.tolid - length(post.adj)))
      }
      stopifnot(parest.tolid == length(post.adj))


      df_0 <- data.frame(
        param = param.name,
        true_value = true_params[[param.name]],
        statcomposition = parest.statcomposition,
        pls = parest.plsid,
        tol = parest.tolid,
        regression = c("rej", "adj"),
        podid = podid
      )
      df_1 <-
        rbind.data.frame(c(post.rej, rep(NA, max_tolid - parest.tolid)),
                         c(post.adj, rep(NA, max_tolid - parest.tolid)))
      colnames(df_1) <- paste0("acc_", 1:parest.tolid)

      # collect df
      df_collector[[length(df_collector) + 1]] <-
        cbind.data.frame(df_0, df_1)

      # log
      cat(
        "collected data: ",
        parestid,
        "of",
        length(parameter.estimations),
        ";",
        parest.id,
        "of",
        length(parest),
        ";",
        param.name,
        "of",
        length(colnames(true_params)),
        ";",
        "\n",
        file = LOG,
        append = T
      )

      rm(df_0, df_1)
    }
  }
}
rm(parest, parameter.estimations, true_params)


# put into a single tibble
df <- do.call(rbind.data.frame, df_collector)

# log
cat("extracted and pasted data\n",
    file = LOG,
    append = T)


# add mode, mean, median
posteriors <- df %>%
  select(starts_with("acc_"))

df$mean <- apply(posteriors, 1, function(x) {
  return(mean(as.numeric(x), na.rm = TRUE))
})
df$mode <- apply(posteriors, 1, function(x) {
  return(mlv(x, method = "meanshift", na.rm = TRUE))
})
df$median <- apply(posteriors, 1, function(x) {
  return(median(as.numeric(x), na.rm = TRUE))
})

# log
cat(
  "calculated mean, mode, median of each posterior\n",
  file = LOG,
  append = T
)


# rearrange columns
df <- df %>% select(!starts_with("acc_"))
df <- cbind.data.frame(df,
                       posteriors)


# save to file
saveRDS(object = df, file = snakemake@output$estims)


# log
cat("saved RDS",
    snakemake@output$estims,
    "\n",
    file = LOG,
    append = T)
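
mlv(..., method = "meanshift") estimates the posterior mode; an equivalent hedged sketch in Python takes the argmax of a kernel density estimate:

import numpy as np
from scipy.stats import gaussian_kde

posterior = np.random.default_rng(2).normal(loc=1.5, scale=0.3, size=500)

grid = np.linspace(posterior.min(), posterior.max(), 512)
density = gaussian_kde(posterior)(grid)

mode = grid[np.argmax(density)]
mean, median = posterior.mean(), np.median(posterior)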
library(locfit)
library("MASS")
library(abc)


# Obtain parameters for estimate
{
  OUTFILE_ESTIMS  <- snakemake@output$estims
  INFILE_ABC  <- snakemake@input$sumstats
  INFILE_POD  <- snakemake@input$podstats
  INFILE_PODID <- snakemake@input$podid
  NTOLERATED <-
    as.numeric(snakemake@params$ntol)
  STATCOMP_NO <- as.numeric(snakemake@params$statcomposition_no)
  STATCOMP <- snakemake@params$statcompositions[STATCOMP_NO]
  PLS_COMP <- as.numeric(snakemake@params$pls)
  PLOT_DIR <- snakemake@output$postplots
  REGRESSION <- snakemake@params$regression
  LOG <- snakemake@log$log1
}

cat("creating new log file\n", file = LOG, append = F)

# Read abc sim data and subset to pls
{
  # read in table
  df_abc <- data.frame(read.table(INFILE_ABC, header = T))


  # subset stats (remove parameters)
  all_sumstats_abc <-
    df_abc[, grep("LinearCombination", names(df_abc))]
  if (ncol(all_sumstats_abc) <  PLS_COMP) {
    cat(
      "  changed PLS set from ",
      PLS_COMP,
      " to ",
      ncol(all_sumstats_abc),
      "\n",
      file = LOG,
      append = T
    )
    PLS_COMP <- ncol(all_sumstats_abc)
  }
  sumstats_abc <-
    all_sumstats_abc[, paste("LinearCombination", 0:(PLS_COMP - 1), sep = "_")]


  # Get params and remove the columns that do not have variance
  params_abc_all <- df_abc[, grep("param", names(df_abc))]
  params_abc <- params_abc_all[, apply(params_abc_all, 2, var) != 0]

  cat("  read and prepared simulations",
      "\n",
      file = LOG,
      append = T)
}


# Read pod data
{
  # read in table
  df_pod <- data.frame(read.table(INFILE_POD, header = T))


  # read pod index
  podid <- read.csv(INFILE_PODID, header = FALSE)[[1]] + 1  # podid are 0-based


  # subset stats (remove parameters)
  all_sumstats_pods <-
    df_pod[, grep("LinearCombination", names(df_pod))]
  sumstats_pods <-
    all_sumstats_pods[paste("LinearCombination", 0:(PLS_COMP - 1), sep = "_")]

  # read pods_params
  pod_params <- df_pod[, grep("param_", names(df_pod))]

  cat("  read and prepared pods and pod_params",
      "\n",
      file = LOG,
      append = T)
}


# Produce logit boundaries from actual priors
{
  logit_boundaries <- t(apply(params_abc, 2, range))
  colnames(logit_boundaries) <- c("minimal", "maximal")
}

# results lists
{
  prior_and_posteriors <- list()
}


{
  # parameter preparation
  my_tolerance <- NTOLERATED / nrow(sumstats_abc)
}


# make estimates with library abc
{
  # make an estimate for each set of sumstats in the table (i. e. the samples)
  for (i in 1:nrow(sumstats_pods)) {
    # extract summary statistic to estimate from
    my_target <- as.numeric(sumstats_pods[i, ])

    # extract true parameters
    true_params <- pod_params[i, ]


    # create estimate; drop any result from a previous iteration so the
    # exists() check below reliably detects a failed abc() call
    if (exists("abc_result"))
      rm(abc_result)
    tryCatch(
      expr = {
        abc_result <- abc(
          target = my_target,
          param = params_abc,
          sumstat = sumstats_abc,
          tol = my_tolerance,
          method = REGRESSION,
          MaxNWts = 5000,
          transf = "logit",
          logit.bounds = logit_boundaries
        )
      },
      error = function(e) {
        cat(
          "* Caught an abc error on iteration ",
          i,
          "\n",
          file = LOG,
          append = T
        )
      }
    )
    if (!exists("abc_result"))
      next


    # clean the data from na's
    rejection_values <- na.omit(abc_result$unadj.values)
    adjusted_values <- na.omit(abc_result$adj.values)


    # save results into list
    prior_and_posteriors[[i]] <-
      list(rej = rejection_values,
           adj = adjusted_values,
           true_params = true_params,
           podid = podid[i]
           )



    cat(
      "  saved priors and posteriors to list: ",
      i,
      " of ",
      nrow(sumstats_pods),
      "\n",
      file = LOG,
      append = T
    )

    # plot diagnostics
    tryCatch(
      expr = {
        dir.create(PLOT_DIR, showWarnings = FALSE)
        pdf(paste0(PLOT_DIR, "/diagnostic_plots_", i, ".pdf", collapse = ""))
        plot(abc_result, param = params_abc, ask = F)
        dev.off()
      },
      error = function(e) {
        cat(
          "* Caught a plotting error on itertion ",
          i,
          "\n",
          file = LOG,
          append = T
        )
      }
    )
  }
}


# save results list
{
  # add prior to the result list
  prior_and_posteriors[["prior"]] <- params_abc

  saveRDS(prior_and_posteriors, file = OUTFILE_ESTIMS)
}
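
abc's transf = "logit" maps every parameter onto the real line using its prior range before the local regression, which guarantees that adjusted posterior values stay inside the prior support. A small sketch of the transform pair, assuming bounds (a, b):

import numpy as np

def logit(x, a, b):
    # map x in (a, b) onto the real line
    return np.log((x - a) / (b - x))

def inv_logit(y, a, b):
    # map back; results always stay inside (a, b)
    return a + (b - a) / (1.0 + np.exp(-y))

a, b = 0.0, 1.0
assert np.isclose(inv_logit(logit(0.9, a, b), a, b), 0.9)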
library(tidyverse)
library(modeest)


# logfile
LOG <- snakemake@log$log1
cat("creating new log file\n", file = LOG, append = F)


# read in files
max_tolid <- 0
parameter.estimations <- list()
for (infileid in 1:length(snakemake@input$estims)) {
  infile.name <- snakemake@input$estims[infileid]
  parameter.estimations[[infileid]] <- readRDS(infile.name)
  parameter.estimations[[infileid]]$filename <- infile.name
  split <-
    strsplit(infile.name, split = "_|\\.|/")[[1]]  # read wc from filename
  parameter.estimations[[infileid]]$plsid <-
    as.numeric(split[which(split == "pls") + 1])
  parameter.estimations[[infileid]]$tolid <-
    as.numeric(split[which(split == "tolid") + 1])
  max_tolid <-
    max(c(max_tolid, as.numeric(split[which(split == "tolid") + 1])))
  parameter.estimations[[infileid]]$statcomposition <-
    as.numeric(split[which(split == "statcomp") + 1])
  parameter.estimations[[infileid]]$allregions <- "allregions" %in% split
}


# log
cat("read data\n",
    file = LOG,
    append = T)

# loop through all estimates
df_collector <- list()
for (parestid in 1:length(parameter.estimations)) {
  parest <- parameter.estimations[[parestid]]

  # get values that are not posteriors; afterwards we can loop through the
  # posteriors
  parest.filename <- parest$filename
  parest.plsid <- parest$plsid
  parest.tolid <- parest$tolid
  parest.statcomposition <- parest$statcomposition
  parest.allregions <- parest$allregions

  # remove from list; afterwards we can loop through
  parest$filename <- NULL
  parest$plsid <- NULL
  parest$tolid <- NULL
  parest$statcomposition <- NULL
  parest$allregions <- NULL
  parest$prior <- NULL


  # extract true params
  for (parest.id in 1:length(parest)) {
    podid <- parest[[parest.id]]$obsid

    # loop through parameters and extract the posterior for each parameter
    for (param.name in colnames(parest[[parest.id]]$rej)) {
      post.rej <- parest[[parest.id]]$rej[, param.name]
      post.adj <- parest[[parest.id]]$adj[, param.name]

      if (parest.tolid != length(post.rej)) {
        post.rej <- c(post.rej, rep(NA, parest.tolid - length(post.rej)))
      }
      stopifnot(parest.tolid == length(post.rej))

      if (parest.tolid != length(post.adj)) {
        post.adj <- c(post.adj, rep(NA, parest.tolid - length(post.adj)))
      }
      stopifnot(parest.tolid == length(post.adj))


      df_0 <- data.frame(
        param = param.name,
        statcomposition = parest.statcomposition,
        pls = parest.plsid,
        tol = parest.tolid,
        allregions = parest.allregions,
        regression = c("rej", "adj"),
        podid = podid
      )
      df_1 <-
        rbind.data.frame(c(post.rej, rep(NA, max_tolid - parest.tolid)),
                         c(post.adj, rep(NA, max_tolid - parest.tolid)))
      colnames(df_1) <- paste0("acc_", 1:parest.tolid)

      # collect df
      df_collector[[length(df_collector) + 1]] <-
        cbind.data.frame(df_0, df_1)

      # log
      cat(
        "collected data: ",
        parestid,
        "of",
        length(parameter.estimations),
        ";",
        parest.id,
        "of",
        length(parest),
        ";",
        param.name,
        "of",
        length(colnames(parest[[parest.id]]$rej)),
        ";",
        "\n",
        file = LOG,
        append = T
      )

      rm(df_0, df_1)
    }
  }
}
rm(parest, parameter.estimations)


# put into a single tibble
df <- do.call(rbind.data.frame, df_collector)

# log
cat("extracted and pasted data\n",
    file = LOG,
    append = T)


# add mode, mean, median
posteriors <- df %>%
  select(starts_with("acc_"))

df$mean <- apply(posteriors, 1, function(x) {
  return(mean(as.numeric(x), na.rm = TRUE))
})
df$mode <- apply(posteriors, 1, function(x) {
  return(mlv(x, method = "meanshift", na.rm = TRUE))
})
df$median <- apply(posteriors, 1, function(x) {
  return(median(as.numeric(x), na.rm = TRUE))
})

# log
cat(
  "calculated mean, mode, median of each posterior\n",
  file = LOG,
  append = T
)


# rearrange columns
df <- df %>% select(!starts_with("acc_"))
df <- cbind.data.frame(df,
                       posteriors)


# save to file
saveRDS(object = df, file = snakemake@output$estims)


# log
cat("saved RDS",
    snakemake@output$estims,
    "\n",
    file = LOG,
    append = T)
import re
import datetime
import numpy as np
import pandas as pd


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "start aggregating observed sumstats",
        file=logfile,
    )


# read dataframes and obtain popid from filenames
obsstats_list = []
identifier = []
for obsstats_filename in snakemake.input.observations:
    # obtain population identifier
    split = re.split(r"_|\.|\/", obsstats_filename)
    popid = split[np.where(np.array(split) == "population")[0].max() + 1]

    # read in data
    pddf = pd.read_feather(obsstats_filename)

    # create the map
    identifier.extend([popid] * len(pddf.index))

    obsstats_list.append(pddf)


pd.concat(obsstats_list, ignore_index=True).to_feather(
    snakemake.output.sumstats, compression="lz4"
)


with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved dataframe to {snakemake.output.sumstats}", file=logfile)


pd.DataFrame(identifier, columns=["population"]).to_feather(
    snakemake.output.identifier, compression="lz4"
)


with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved identifier to {snakemake.output.identifier}", file=logfile)
import sys
import datetime
import warnings
import itertools
import pickle
import math
import numpy as np
import pandas as pd
import tskit
import json
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "start calculating observed summarizing statistics on Arabidopsis thaliana data",
        file=logfile,
    )


# read sample names
sample_names = np.loadtxt(snakemake.input.sample_list, dtype=str)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        f"read sample names from {snakemake.input.sample_list}",
        file=logfile,
    )


# read tree sequence
treeseq_athal = tskit.load(snakemake.input.athal_treeseq)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        f"read treeseq of Arabidopsis thaliana from '{snakemake.input.athal_treeseq}'",
        file=logfile,
    )


# find the node ids for the sample of the population
population_sample = []
for individual in treeseq_athal.individuals():
    if str(json.loads(individual.metadata)["id"]) in sample_names:
        population_sample.extend(individual.nodes)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("created the node id sample set for the given population", file=logfile)


# sample treeseq to provided samples
treeseq_athal_population = treeseq_athal.simplify(samples=population_sample)
del treeseq_athal


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("simplified the tree sequence to the given population", file=logfile)


# get the chromosomal regions from the config file
chromosome_regions = []
for chromid, start, stop in snakemake.config["ABC"]["athaliana"]["observations"][
    "treeseq_1001"
]["whole_gemome_approach"]:
    start += snakemake.params.chrom_multiplier * chromid
    stop += snakemake.params.chrom_multiplier * chromid
    chromosome_regions.append((start, stop))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("prepared regions", file=logfile)


# chop down to regions
treeseq_list = []
for start, stop in chromosome_regions:
    # calculate chromid from position in original treeseq
    chromid = int(start / snakemake.params.chrom_multiplier)

    # add tuple with chromosome/treeseq, as chromid is needed for masking
    treeseq_list.append(
        (chromid, treeseq_athal_population.keep_intervals([(start, stop)]).trim())
    )

    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(
            f"prepared regions for chromsome {chromid + 1} of {len(chromosome_regions)}",
            file=logfile,
        )

del treeseq_athal_population


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("created region treeseq list (1 per region of each chromosome)", file=logfile)


# mask if asked; make sure the parameter is a boolean and then mask or not
do_mask = pyfuncs.check_masking_parameter(snakemake.params.masked)

# read and prepare mask files
if do_mask:
    # provide mask in parallel to the regions of the treeseq_list
    mask = []
    for treeseqid, (chromid, _) in enumerate(treeseq_list):
        this_mask = np.loadtxt(snakemake.input.maskfiles[chromid - 1], dtype=int)
        region_start, region_end = chromosome_regions[treeseqid]
        region_start -= chromid * snakemake.params.chrom_multiplier
        region_end -= chromid * snakemake.params.chrom_multiplier
        region_mask = this_mask[
            (region_start <= this_mask[:, 0]) & (region_end > this_mask[:, 1])
        ]
        this_mask = region_mask - region_start
        del region_mask

        # filter mask for disjoint intervals
        mask.append(
            pyfuncs.filter_mask_for_disjoint_intervals(
                this_mask, log=snakemake.log.log1
            )
        )

    # remove chromid from treeseq_list
    treeseq_list = [treeseq for _, treeseq in treeseq_list]

    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(
            f"prepared coordinates to mask the exons for {len(mask)} regions",
            file=logfile,
        )

    treeseq_list_masked = []
    for treeid, (treeseq_chrom, mask_chrom) in enumerate(
        zip(treeseq_list, mask), start=1
    ):
        treeseq_list_masked.append(
            treeseq_chrom.delete_intervals(
                mask_chrom, simplify=True, record_provenance=True
            )
        )

        # log
        with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
            print(datetime.datetime.now(), end="\t", file=logfile)
            print(
                f"masked {treeid} of {len(mask)} chromosomes",
                file=logfile,
            )

    treeseq_list = treeseq_list_masked
    del treeseq_list_masked
else:
    # remove chromid from treeseq_list
    treeseq_list = [treeseq for _, treeseq in treeseq_list]

    # log, create log file
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print("will not mask for exons", file=logfile)


# create subsample from treesequence
rng = np.random.default_rng(snakemake.params.seed)
specs = {
    "num_observations": int(
        float(
            snakemake.config["ABC"]["athaliana"]["observations"]["treeseq_1001"][
                "whole_genome_approach_num_observations"
            ]
        )
    ),
    "nsam": int(float(snakemake.config["ABC"]["simulations"]["nsam"])),
}
tsl = pyfuncs.create_subsets_from_treeseqlist(
    treeseq_list, specs, rng, snakemake.log.log1
)


# read the breakpoint files
for breaks_filename in snakemake.input.breakpoints:
    if "SFS" in breaks_filename:
        with warnings.catch_warnings():  # will prevent warning if file is empty
            warnings.simplefilter("ignore")
            breaks_sfs = np.loadtxt(breaks_filename, dtype=float)
    elif "LD" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_ld = np.loadtxt(breaks_filename, dtype=float)
    elif "TM_WIN" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_tm_win = np.loadtxt(breaks_filename, dtype=float)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the breaks for binning", file=logfile)


# which sumstats to calculate
listed_sumstats = set(
    itertools.chain.from_iterable(
        [sumstat_set.split("/") for sumstat_set in snakemake.params.sumstats]
    )
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculating summarizing statistics",
        file=logfile,
    )

# calculate the SFS
if "SFS" in listed_sumstats:
    breaks = breaks_sfs

    dataframe_sumstats_sfs = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for my_id, (treeid, treeseq) in enumerate(np.ndenumerate(tsl), start=1):
        dataframe_sumstats_sfs[treeid] = treeseq.allele_frequency_spectrum(
            sample_sets=None,
            windows=None,
            mode="site",
            span_normalise=False,
            polarised=snakemake.config["ABC"]["sumstats_specs"]["SFS"]["polarised"],
        )

        # log
        if not my_id % 100:
            with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
                print(datetime.datetime.now(), end="\t", file=logfile)
                print(
                    f"calculating SFS: {my_id} of {tsl.size}",
                    file=logfile,
                )

    # check if the sfs shall be discretized as well; this usually is not
    # necessary as the sfs is a discrete statistic already
    if len(breaks):
        # for each sfs sum up the values in between the bin_edges (=breaks)
        assert (
            len(breaks) >= 2
        ), "expect closed interval for bin_edges; at least 2 breaks must be provided"
        new_sfs_dataframe = []
        for sfs in dataframe_sumstats_sfs:
            sfs_vals = []
            for breakid in range(1, len(breaks)):
                low = int(breaks[breakid - 1])
                high = breaks[breakid]
                if np.isinf(high):
                    high = len(sfs) + 1
                else:
                    high = math.ceil(high)
                sfs_vals.append(sum(sfs[low:high]))
            new_sfs_dataframe.append(sfs_vals)

        # reassign to the dataframe sumstat
        dataframe_sumstats_sfs = np.array(new_sfs_dataframe)

    # get average of sfs over the different loci
    dataframe_sumstats_sfs = dataframe_sumstats_sfs.mean(axis=1)

    dataframe_sumstats_sfs = np.concatenate(dataframe_sumstats_sfs, axis=0).reshape(
        (len(dataframe_sumstats_sfs), dataframe_sumstats_sfs[0].shape[0])
    )
else:
    dataframe_sumstats_sfs = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated SFS statistics",
        file=logfile,
    )


# calculate the LD
if "LD" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["LD"]
    breaks = breaks_ld

    dataframe_sumstats_ld = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for my_id, (treeid, treeseq) in enumerate(np.ndenumerate(tsl), start=1):
        dataframe_sumstats_ld[treeid] = pyfuncs.calculate_ld(
            treeseq, specs, breaks, rng, snakemake.log.log1
        )

        # log
        if not my_id % 5:
            with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
                print(datetime.datetime.now(), end="\t", file=logfile)
                print(
                    f"calculating LD: {my_id} of {tsl.size}",
                    file=logfile,
                )

    # get average of ld over the different loci
    dataframe_sumstats_ld = dataframe_sumstats_ld.mean(axis=1)

    dataframe_sumstats_ld = np.concatenate(dataframe_sumstats_ld, axis=0).reshape(
        (len(dataframe_sumstats_ld), dataframe_sumstats_ld[0].shape[0])
    )
else:
    dataframe_sumstats_ld = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated LD statistics",
        file=logfile,
    )


# calculate the TM_WIN
if "TM_WIN" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]
    breaks = breaks_tm_win

    dataframe_sumstats_tm_win = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for my_id, (treeid, treeseq) in enumerate(np.ndenumerate(tsl), start=1):
        dataframe_sumstats_tm_win[treeid] = pyfuncs.calculate_tm_win(
            treeseq, specs, breaks_tm_win
        )

        # log
        if not my_id % 100:
            with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
                print(datetime.datetime.now(), end="\t", file=logfile)
                print(
                    f"calculating TM_WIN: {my_id} of {tsl.size}",
                    file=logfile,
                )

    # get average of tm_win over the different loci
    dataframe_sumstats_tm_win = dataframe_sumstats_tm_win.mean(axis=1)
    dataframe_sumstats_tm_win = np.concatenate(
        dataframe_sumstats_tm_win, axis=0
    ).reshape((len(dataframe_sumstats_tm_win), dataframe_sumstats_tm_win[0].shape[0]))
else:
    dataframe_sumstats_tm_win = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated TM_WIN statistics",
        file=logfile,
    )


# fuse the summarizing stats list into a 2d-np.array, each row containing the
# summarizing stats of a single simulated tree sequence
dataframe_sumstats = np.concatenate(
    [
        npy
        for npy in (
            dataframe_sumstats_sfs,
            dataframe_sumstats_ld,
            dataframe_sumstats_tm_win,
        )
        if npy.size != 0
    ],
    axis=1,
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("calculated summarizing statisitcs", file=logfile)


# save the names of the sumstats
sumstat_type = (
    ["sfs"] * dataframe_sumstats_sfs.shape[1]
    + ["ld"] * dataframe_sumstats_ld.shape[1]
    + ["tm_win"] * dataframe_sumstats_tm_win.shape[1]
)
sumstat_type_id = [
    str(element)
    for element in list(range(dataframe_sumstats_sfs.shape[1]))
    + list(range(dataframe_sumstats_ld.shape[1]))
    + list(range(dataframe_sumstats_tm_win.shape[1]))
]

sumstat_names = np.array(list(map("_".join, zip(sumstat_type, sumstat_type_id))))


pd.DataFrame(dataframe_sumstats, columns=sumstat_names).to_feather(
    snakemake.output.sumstats, compression="lz4"
)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved sumstats to {snakemake.output.sumstats}", file=logfile)
import sys
import datetime
import warnings
import itertools
import pickle
import math
import numpy as np
import pandas as pd
import tskit
import json
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "start calculating observed summarizing statistics on Arabidopsis thaliana data",
        file=logfile,
    )


# read sample names
sample_names = np.loadtxt(snakemake.input.sample_list, dtype=str)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        f"read sample names from {snakemake.input.sample_list}",
        file=logfile,
    )


# read tree sequence
treeseq_athal = tskit.load(snakemake.input.athal_treeseq)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        f"read treeseq of Arabidopsis thaliana from '{snakemake.input.athal_treeseq}'",
        file=logfile,
    )


# find the node ids for the sample of the population
population_sample = []
for individual in treeseq_athal.individuals():
    if str(json.loads(individual.metadata)["id"]) in sample_names:
        population_sample.extend(individual.nodes)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("created the node id sample set for the given population", file=logfile)


# sample treeseq to provided samples
treeseq_athal_population = treeseq_athal.simplify(samples=population_sample)
del treeseq_athal


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("simplified the tree sequence to the given population", file=logfile)


# get the chromosomal regions from the config file
chromosome_regions = []
for chromid, (start, stop) in enumerate(
    snakemake.config["ABC"]["athaliana"]["observations"]["treeseq_1001"][
        "chosen_region"
    ],
    start=1,
):
    start += snakemake.params.chrom_multiplier * chromid
    stop += snakemake.params.chrom_multiplier * chromid
    chromosome_regions.append((start, stop))
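# note: the A. thaliana tree sequence concatenates all chromosomes on a
# single coordinate axis, offset by params.chrom_multiplier per chromosome;
# e.g. with chrom_multiplier = 1e8 (an assumed value), a region (100, 200)
# on chromosome 2 maps to (2e8 + 100, 2e8 + 200)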


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("prepared regions", file=logfile)


# chop down to regions
treeseq_list = []
for chromid, (start, stop) in enumerate(chromosome_regions):
    treeseq_list.append(treeseq_athal_population.keep_intervals([(start, stop)]).trim())

    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(
            f"prepared regions for chromsome {chromid + 1} of {len(chromosome_regions)}",
            file=logfile,
        )

del treeseq_athal_population


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("created region treeseq list (1 per region of each chromosome)", file=logfile)


# mask if asked; make sure the parameter is a boolean and then mask or not
do_mask = pyfuncs.check_masking_parameter(snakemake.params.masked)

# read and prepare mask files
if do_mask:
    mask = []
    for region_id, maskfile in enumerate(snakemake.input.maskfiles):
        this_mask = np.loadtxt(maskfile, dtype=int)
        region_start, region_end = snakemake.config["ABC"]["athaliana"]["observations"][
            "treeseq_1001"
        ]["chosen_region"][region_id]
        region_mask = this_mask[
            (region_start <= this_mask[:, 0]) & (region_end > this_mask[:, 1])
        ]
        this_mask = region_mask - region_start
        del region_mask

        # filter mask for disjoint intervals
        mask.append(
            pyfuncs.filter_mask_for_disjoint_intervals(
                this_mask, log=snakemake.log.log1
            )
        )

    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(
            f"prepared coordinates to mask the exons for {len(mask)} chromosomes",
            file=logfile,
        )

    treeseq_list_masked = []
    for treeid, (treeseq_chrom, mask_chrom) in enumerate(
        zip(treeseq_list, mask), start=1
    ):
        treeseq_list_masked.append(
            treeseq_chrom.delete_intervals(
                mask_chrom, simplify=True, record_provenance=True
            )
        )

        # log
        with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
            print(datetime.datetime.now(), end="\t", file=logfile)
            print(
                f"masked {treeid} of {len(mask)} chromosomes",
                file=logfile,
            )

    treeseq_list = treeseq_list_masked
    del treeseq_list_masked
else:
    # log, create log file
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print("will not mask for exons", file=logfile)


# create subsample from treesequence
rng = np.random.default_rng(snakemake.params.seed)
specs = {
    "num_observations": int(
        float(snakemake.config["ABC"]["athaliana"]["observations"]["num_observations"])
    ),
    "nsam": int(float(snakemake.config["ABC"]["simulations"]["nsam"])),
}
tsl = pyfuncs.create_subsets_from_treeseqlist(
    treeseq_list, specs, rng, snakemake.log.log1
)
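# tsl appears to be a 2d array of tree sequences with shape
# (num_observations, num_regions); the region axis (axis 1) is averaged
# over when the summary statistics are computed below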


# read the breakpoint files
for breaks_filename in snakemake.input.breakpoints:
    if "SFS" in breaks_filename:
        with warnings.catch_warnings():  # will prevent warning if file is empty
            warnings.simplefilter("ignore")
            breaks_sfs = np.loadtxt(breaks_filename, dtype=float)
    elif "LD" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_ld = np.loadtxt(breaks_filename, dtype=float)
    elif "TM_WIN" in breaks_filename:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            breaks_tm_win = np.loadtxt(breaks_filename, dtype=float)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("read the breaks for binning", file=logfile)


# which sumstats to calculate
listed_sumstats = set(
    itertools.chain.from_iterable(
        [sumstat_set.split("/") for sumstat_set in snakemake.params.sumstats]
    )
)
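# e.g. params.sumstats = ["SFS/LD", "SFS/LD/TM_WIN"] (assumed values)
# yields the set {"SFS", "LD", "TM_WIN"}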


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculating summarizing statistics",
        file=logfile,
    )

# calculate the SFS
if "SFS" in listed_sumstats:
    breaks = breaks_sfs

    dataframe_sumstats_sfs = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for my_id, (treeid, treeseq) in enumerate(np.ndenumerate(tsl), start=1):
        dataframe_sumstats_sfs[treeid] = treeseq.allele_frequency_spectrum(
            sample_sets=None,
            windows=None,
            mode="site",
            span_normalise=False,
            polarised=snakemake.config["ABC"]["sumstats_specs"]["SFS"]["polarised"],
        )

        # log
        if not my_id % 100:
            with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
                print(datetime.datetime.now(), end="\t", file=logfile)
                print(
                    f"calculating SFS: {my_id} of {tsl.size}",
                    file=logfile,
                )

    # check if the sfs shall be discretized as well; this usually is not
    # necessary as the sfs is a discrete statistic already
    if len(breaks):
        # for each sfs sum up the values in between the bin_edges (=breaks)
        assert (
            len(breaks) >= 2
        ), "expect closed interval for bin_edges; at least 2 breaks must be provided"
        new_sfs_dataframe = []
        for sfs in dataframe_sumstats_sfs:
            sfs_vals = []
            for breakid in range(1, len(breaks)):
                low = int(breaks[breakid - 1])
                high = breaks[breakid]
                if np.isinf(high):
                    high = len(sfs) + 1
                else:
                    high = math.ceil(high)
                sfs_vals.append(sum(sfs[low:high]))
            new_sfs_dataframe.append(sfs_vals)

        # reassign to the dataframe sumstat
        dataframe_sumstats_sfs = np.array(new_sfs_dataframe)
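        # worked example: with breaks = [0, 2, inf] and an sfs of length 5,
        # the two bins are sum(sfs[0:2]) and sum(sfs[2:6]) == sum(sfs[2:])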


    # get average of sfs over the different loci
    dataframe_sumstats_sfs = dataframe_sumstats_sfs.mean(axis=1)

    dataframe_sumstats_sfs = np.concatenate(dataframe_sumstats_sfs, axis=0).reshape(
        (len(dataframe_sumstats_sfs), dataframe_sumstats_sfs[0].shape[0])
    )
else:
    dataframe_sumstats_sfs = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated SFS statistics",
        file=logfile,
    )


# calculate the LD
if "LD" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["LD"]
    breaks = breaks_ld

    dataframe_sumstats_ld = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for my_id, (treeid, treeseq) in enumerate(np.ndenumerate(tsl), start=1):
        dataframe_sumstats_ld[treeid] = pyfuncs.calculate_ld(
            treeseq, specs, breaks, rng, snakemake.log.log1
        )

        # log
        if not my_id % 100:
            with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
                print(datetime.datetime.now(), end="\t", file=logfile)
                print(
                    f"calculating LD: {my_id} of {tsl.size}",
                    file=logfile,
                )

    # get average of ld over the different loci
    dataframe_sumstats_ld = dataframe_sumstats_ld.mean(axis=1)

    dataframe_sumstats_ld = np.concatenate(dataframe_sumstats_ld, axis=0).reshape(
        (len(dataframe_sumstats_ld), dataframe_sumstats_ld[0].shape[0])
    )
else:
    dataframe_sumstats_ld = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated LD statistics",
        file=logfile,
    )


# calculate the TM_WIN
if "TM_WIN" in listed_sumstats:
    specs = snakemake.config["ABC"]["sumstats_specs"]["TM_WIN"]
    breaks = breaks_tm_win

    dataframe_sumstats_tm_win = np.empty(tsl.shape, dtype=tskit.TreeSequence)
    for my_id, (treeid, treeseq) in enumerate(np.ndenumerate(tsl), start=1):
        dataframe_sumstats_tm_win[treeid] = pyfuncs.calculate_tm_win(
            treeseq, specs, breaks
        )

        # log
        if not my_id % 100:
            with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
                print(datetime.datetime.now(), end="\t", file=logfile)
                print(
                    f"calculating TM_WIN: {my_id} of {tsl.size}",
                    file=logfile,
                )

    # get average of tm_win over the different loci
    dataframe_sumstats_tm_win = dataframe_sumstats_tm_win.mean(axis=1)
    dataframe_sumstats_tm_win = np.concatenate(
        dataframe_sumstats_tm_win, axis=0
    ).reshape((len(dataframe_sumstats_tm_win), dataframe_sumstats_tm_win[0].shape[0]))
else:
    dataframe_sumstats_tm_win = np.empty(shape=(0, 0))


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(
        "calculated TM_WIN statistics",
        file=logfile,
    )


# fuse the summarizing stats list into a 2d-np.array, each row containing the
# summarizing stats of a single simulated tree sequence
dataframe_sumstats = np.concatenate(
    [
        npy
        for npy in (
            dataframe_sumstats_sfs,
            dataframe_sumstats_ld,
            dataframe_sumstats_tm_win,
        )
        if npy.size != 0
    ],
    axis=1,
)


# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("calculated summarizing statisitcs", file=logfile)


# save the names of the sumstats
sumstat_type = (
    ["sfs"] * dataframe_sumstats_sfs.shape[1]
    + ["ld"] * dataframe_sumstats_ld.shape[1]
    + ["tm_win"] * dataframe_sumstats_tm_win.shape[1]
)
sumstat_type_id = [
    str(element)
    for element in list(range(dataframe_sumstats_sfs.shape[1]))
    + list(range(dataframe_sumstats_ld.shape[1]))
    + list(range(dataframe_sumstats_tm_win.shape[1]))
]

sumstat_names = np.array(list(map("_".join, zip(sumstat_type, sumstat_type_id))))


pd.DataFrame(dataframe_sumstats, columns=sumstat_names).to_feather(
    snakemake.output.sumstats, compression="lz4"
)

# log
with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print(f"saved sumstats to {snakemake.output.sumstats}", file=logfile)
library(locfit)
library("MASS")
library(abc)
library(arrow)


# Obtain parameters for estimate
{
  OUTFILE_ESTIMS  <- snakemake@output$estims
  INFILE_ABC  <- snakemake@input$sumstats
  INFILE_OBS  <- snakemake@input$observed
  INFILE_ID <- snakemake@input$identifier
  NTOLERATED <-
    as.numeric(snakemake@params$ntol)
  PLS_COMP <- as.numeric(snakemake@params$pls)
  PLOT_DIR <- snakemake@output$postplots
  REGRESSION <- snakemake@params$regression
  LOG <- snakemake@log$log1
}

cat("creating new log file\n", file = LOG, append = F)

# Read abc sim data and subset to pls
{
  # read in table
  df_abc <- data.frame(read.table(INFILE_ABC, header = T))


  # subset stats (remove parameters)
  all_sumstats_abc <-
    df_abc[, grep("LinearCombination", names(df_abc))]
  if (ncol(all_sumstats_abc) <  PLS_COMP) {
    cat(
      "  changed PLS set from ",
      PLS_COMP,
      " to ",
      ncol(all_sumstats_abc),
      "\n",
      file = LOG,
      append = T
    )
    PLS_COMP <- ncol(all_sumstats_abc)
  }
  sumstats_abc <-
    all_sumstats_abc[, paste("LinearCombination", 0:(PLS_COMP - 1), sep = "_")]


  # Get params and remove the columns that do not have variance
  params_abc_all <- df_abc[, grep("param", names(df_abc))]
  params_abc <- params_abc_all[, apply(params_abc_all, 2, var) != 0]

  cat("  read and prepared simulations",
      "\n",
      file = LOG,
      append = T)
}


# Read obs data
{
  # read in table
  df_obs <- data.frame(read.table(INFILE_OBS, header = T))


  # read obs index
  obsid <- read_feather(INFILE_ID)$population


  # subset stats (remove parameters)
  all_sumstats_obs <-
    df_obs[, grep("LinearCombination", names(df_obs))]
  sumstats_obs <-
    all_sumstats_obs[paste("LinearCombination", 0:(PLS_COMP - 1), sep = "_")]

  # read obs_params
  obs_params <- df_obs[, grep("param_", names(df_obs))]

  cat("  read and prepared obs and obs_params",
      "\n",
      file = LOG,
      append = T)
}


# Produce logit boundaries from actual priors
{
  logit_boundaries <- t(apply(params_abc, 2, range))
  colnames(logit_boundaries) <- c("minimal", "maximal")
}

# results lists
{
  prior_and_posteriors <- list()
}


{
  # parameter preparation
  my_tolerance <- NTOLERATED / nrow(sumstats_abc)
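  # abc()'s 'tol' argument is a proportion, so this accepts NTOLERATED of
  # the nrow(sumstats_abc) simulated parameter sets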
}


# make estimates with library abc
{
  # make an estimate for each set of sumstats in the table (i. e. the samples)
  for (i in 1:nrow(sumstats_obs)) {
    # extract summary statistic to estimate from
    my_target <- as.numeric(sumstats_obs[i, ])

    # extract true parameters
    true_params <- obs_params[i, ]


    # create estimate
    flag <- TRUE
    tryCatch(
      expr = {
        abc_result <- abc(
          target = my_target,
          param = params_abc,
          sumstat = sumstats_abc,
          tol = my_tolerance,
          method = REGRESSION,
          MaxNWts = 5000,
          transf = "logit",
          logit.bounds = logit_boundaries
        )
      },
      error = function(e) {
        cat(
          "* Caught an estimation error on itertion ",
          i,
          "\n",
          file = LOG,
          append = T
        )
        flag <<- FALSE  # assign in the enclosing scope, not the handler's
      }
    )
    if (!exists("abc_result"))
      next


    # clean the data from na's
    rejection_values <- na.omit(abc_result$unadj.values)
    adjusted_values <- na.omit(abc_result$adj.values)


    # save results into list
    prior_and_posteriors[[i]] <-
      list(rej = rejection_values,
           adj = adjusted_values,
           true_params = true_params,
           obsid = obsid[i]
           )


    cat(
      "  saved priors and posteriors to list: ",
      i,
      " of ",
      nrow(sumstats_obs),
      "\n",
      file = LOG,
      append = T
    )


    # plot diagnostics
    tryCatch(
      expr = {
        dir.create(PLOT_DIR, showWarnings = FALSE)
        pdf(paste0(PLOT_DIR, "/diagnostic_plots_", i, ".pdf", collapse = ""))
        plot(abc_result, param = params_abc, ask = F)
        dev.off()
      },
      error = function(e) {
        cat(
          "* Caught a plotting error on itertion ",
          i,
          "\n",
          file = LOG,
          append = T
        )
      }
    )
  }
}


# save results list
{
  # add prior to the result list
  prior_and_posteriors[["prior"]] <- params_abc

  saveRDS(prior_and_posteriors, file = OUTFILE_ESTIMS)
}
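
The RDS written above bundles one posterior per observation plus the prior; a minimal sketch for inspecting it interactively, assuming a placeholder file name:

# hypothetical path; the actual file is snakemake@output$estims
prior_and_posteriors <- readRDS("estims.rds")
str(prior_and_posteriors[["prior"]])  # prior sample shared by all estimates
str(prior_and_posteriors[[1]]$adj)    # adjusted posterior of observation 1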
import sys
import datetime
import pickle
import numpy as np
import tskit
import msprime
import pyfuncs  # from file


# log, create log file
with open(snakemake.log.log1, "w", encoding="utf-8") as logfile:
    print(datetime.datetime.now(), end="\t", file=logfile)
    print("start simulating tree sequences", file=logfile)


def main(simid):
    """Heart of this script

    The main function will create a single tree sequence. The main function has
    to be repeatedly executed until the list of treesequences has been created.
    """
    # provide the correct seed to create the random number generator
    seed = np.load(snakemake.input.npy)[simid][int(float(snakemake.wildcards.locid))]
    seed_params = np.load(snakemake.input.npy)[simid][
        0
    ]  # the parameters for the independent loci must be the same
    rng = np.random.default_rng(seed)
    rng_params = np.random.default_rng(seed_params)

    # draw the parameters of the model
    params = [
        pyfuncs.draw_parameter_from_prior(prior_definition, rng_params)
        for prior_definition in snakemake.params.model["priors"]
    ]
    del (
        seed_params,
        rng_params,
    )  # only the parameter drawing relies on the same seed, the simulations must be independent


    # simulate
    ts = pyfuncs.simulate_treesequence_under_six_parameter_model(
        params, snakemake.params, rng, snakemake.log.log1
    )


    return ts, params


# define the loop for clustering consecutive simulations into one job
start = int(float(snakemake.wildcards.simid))
end = min(
    start + int(float(snakemake.config["ABC"]["jobluster_simulations"])),
    int(float(snakemake.config["ABC"]["simulations"]["nsim"])),
)
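# e.g. with simid = 100, jobluster_simulations = 50 and nsim = 120
# (assumed config values), this job simulates replicates 100..119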

assert start < end, "simid (Simulation IDs) maldefined, you will not simulate anything"

# run the simulations
treesequence_list = []
parameters_list = []
for simid in range(start, end):
    # log
    with open(snakemake.log.log1, "a", encoding="utf-8") as logfile:
        print(datetime.datetime.now(), end="\t", file=logfile)
        print(f"creating treesequence no {simid}", file=logfile)

    treeseq, params = main(simid)
    treesequence_list.append(treeseq)
    parameters_list.append(params)


# save treeseq list
with open(snakemake.output.tsl, "wb") as tsl_file:
    pickle.dump(treesequence_list, tsl_file)


# save parameters list
np.save(snakemake.output.params, np.array(parameters_list))
library(ggplot2)
library(cowplot)
library(tidyverse)
library(wesanderson)


# save.image(file = "rdev.RData")
# stop("saved rdev.RData")
# setwd("/Users/struett/MPIPZ/netscratch-2/dep_tsiantis/grp_laurent/struett/git_tsABC/tsABC")
# load(file="rdev.RData")


# set theme
theme_set(theme_cowplot())


# helper functions and values
discretize_bayes_factor <- function(bf, do.factor = T) {
  if (bf <= 10 ** 0)
    dbf = "Negative"
  else if (bf <= 10 ** 0.5)
    dbf = "Barely worth mentioning"
  else if (bf <= 10 ** 1)
    dbf = "Substantial"
  else if (bf <= 10 ** 1.5)
    dbf = "Strong"
  else if (bf <= 10 ** 2)
    dbf = "Very strong"
  else if (bf > 10 ** 2)
    dbf = "Decisive"
  else
    stop(paste0(c("unknown bayes: ", as.character(bf)), collapse = ""))

  dbf <- factor(
    dbf,
    levels = c(
      "Negative",
      "Barely worth mentioning",
      "Substantial",
      "Strong",
      "Very strong",
      "Decisive"
    )
  )

  if (do.factor) {
    return(dbf)
  } else {
    return(as.numeric(dbf) - 1)
  }
}
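# e.g. discretize_bayes_factor(5) returns "Substantial", since
# 10^0.5 < 5 <= 10^1; the labels follow Jeffreys' scale of evidence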
logbreak <- sapply(10 ** (-10:10), function(x) {
  x * (1:10)
}) %>% unique() %>% sort()
loglabel <- sapply(logbreak, function(x) {
  if (x %in% 10 ** (-10:10))
    return(as.character(log10(x)))
  else
    return("")
})
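# logbreak places minor ticks at 1..10 times each power of ten; loglabel
# labels only the powers of ten themselves, with their log10 value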


# log
LOG <- snakemake@log$log1
cat("creating new log file\n", file = LOG, append = F)


# create color palette
mcol <- wes_palette("Cavalcanti1")[c(5, 2)]
mcol <- c(mcol[1], "gray90", mcol[2])
colfunc <- colorRampPalette(mcol)
mcol <- colfunc(9)[c(1, 5:9)]  # keep 6 of the 9 interpolated colors


# log
cat("creating color palette\n", file = LOG, append = T)


# read in data
df <- readRDS(snakemake@input$model_choice) %>%
  tibble() %>%
  pivot_longer(
    cols = c("rejection", "mnlogistic"),
    names_to = "regression_method",
    values_to = "bf"
  )

# log
cat("read data\n", file = LOG, append = T)


# find tsigma per podid
tsigma <- snakemake@params$tsigma_per_podid %>% as.numeric()
df$tsigma <- sapply(df$podid, function(x)
  return(tsigma[x]))

# discretize bf
df$bf_discrete <- sapply(df$bf, discretize_bayes_factor)

# log
cat("discretized bayes factor\n",
    file = LOG,
    append = T)

# create empty plot list
plot_list <- list()
plot_index <- 0


# make as bar plot per podid
plot_index <- plot_index + 1
plot_list[[plot_index]] <- df %>%
  ggplot(aes(podid, fill = bf_discrete)) +
  geom_bar(position = "fill") +
  facet_grid(statcomposition ~ regression_method + pls) +
  scale_fill_manual(values = mcol) +
  theme(axis.text.x = element_text(angle = 60, hjust = 1))


# calculate percentage of discrete bayes factor
df <- df %>%
  group_by(podid, statcomposition, pls, tolid, regression_method, tsigma) %>%
  mutate(N = n()) %>%
  group_by(podid,
           statcomposition,
           pls,
           tolid,
           regression_method,
           tsigma,
           bf_discrete) %>%
  mutate(n = n()) %>%
  mutate(bf_proportion = n / N) %>%
  ungroup() %>%
  select(-c(bf, N, n)) %>%
  distinct()
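# N counts rows per group ignoring bf_discrete, n counts them per evidence
# class, so bf_proportion is the share of each class within its group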


# check if they sum up to 1
a <- df %>%
  group_by(podid, statcomposition, pls, tolid, regression_method, tsigma) %>%
  summarise(total_proportion = sum(bf_proportion))
stopifnot(all(abs(1 - a$total_proportion) < 1e-6))  # allow tiny numerical error


# remove column
df <- df %>% select(-podid)  # no podid needed as we use tsigma


# add explicit zero rows for missing category combinations (needed for proper area plots)
zcounter = 0
for (a in unique(df$statcomposition)) {
  for (b in unique(df$regression_method)) {
    for (c in unique(df$tsigma)) {
      for (d in unique(df$bf_discrete)) {
        for (e in unique(df$tolid)) {
          for (f in unique(df$pls)) {
            sdf <- df %>%
              subset(statcomposition == a) %>%
              subset(regression_method == b) %>%
              subset(tsigma == c) %>%
              subset(bf_discrete == d)
            n <-  sdf %>%
              nrow()
            if (n == 0) {
              zcounter = zcounter + 1
              df = rbind.data.frame(
                df,
                data.frame(
                  statcomposition = a,
                  pls = f,
                  tolid = e,
                  regression_method = b,
                  tsigma = c,
                  bf_discrete = d,
                  bf_proportion = 0
                )
              )
            }

          }
        }
      }
    }
  }
}


# log
cat(
  "calculated proportion per group and added zeros\n",
  file = LOG,
  append = T
)
cat("there were",
    zcounter,
    " zeros to be added\n",
    file = LOG,
    append = T)


# create plot as proportional area plot over time
for (a in unique(df$statcomposition)) {
  for (b in unique(df$pls)) {
    for (c in unique(df$tolid)) {
      for (d in unique(df$regression_method)) {
        sdf <- df %>%
          subset(statcomposition == a) %>%
          subset(pls == b) %>%
          subset(tolid == c) %>%
          subset(regression_method == d)

        if (nrow(sdf) > 0) {
          plot_index <- plot_index + 1
          plot_list[[plot_index]] <- sdf %>%
            ggplot(aes(
              x = tsigma,
              y = bf_proportion * 100,
              fill = bf_discrete
            )) +
            geom_area() +
            geom_hline(yintercept = 5,
                       col = "white",
                       size = 0.1) +
            geom_hline(yintercept = 20,
                       col = "white",
                       size = 0.1) +
            geom_hline(yintercept = 50,
                       col = "white",
                       size = 0.1) +
            geom_hline(yintercept = 80,
                       col = "white",
                       size = 0.1) +
            geom_hline(yintercept = 95,
                       col = "white",
                       size = 0.1) +
            geom_vline(xintercept = 200000,
                       col = "white",
                       size = 0.1) +
            # facet_grid(regression_method+pls~statcomposition)+
            scale_x_continuous(trans = "log",
                               breaks = logbreak,
                               labels = loglabel) +
            scale_y_continuous(breaks = c(5, 95),
                               labels = c("0%", "100%")) +
            scale_fill_manual(values = mcol) +
            theme(
              # axis.text.x = element_text(angle = 60, hjust = 1),
              aspect.ratio = 1,
              strip.background = element_blank(),
              axis.title.y = element_blank(),
              # axis.text = element_blank(),
              axis.line = element_blank(),
              axis.ticks.y = element_blank(),
              axis.ticks.length.x = unit(-0.1, "lines"),
              # panel.spacing = unit(-0.9, "lines"),
              # panel.border = element_rect(colour = "black")
            ) +
            labs(
              x = "time after transition [log10]",
              fill = "model support",
              title = paste0(
                "statcomposition: ",
                a,
                "\npls: ",
                b,
                "\ntolid: ",
                c,
                "\nregression_method: ",
                d,
                collapse = ""
              )
            )
        }
      }
    }
  }
}




# log
cat("created plots\n",
    file = LOG,
    append = T)


# log
cat("start printing to file\n", file = LOG, append = T)


# print to file
pdf(snakemake@output$pdf)
for (my_plot in 1:length(plot_list)) {
  show(plot_list[[my_plot]])

  # log
  cat(
    "successfully printed plot to file:",
    my_plot,
    "of",
    length(plot_list),
    "\n",
    file = LOG,
    append = T
  )
}
dev.off()

# log
cat("successfully printed plots to file\n",
    file = LOG,
    append = T)
library(ggplot2)
library(cowplot)
library(tidyverse)
library(wesanderson)


# set theme
theme_set(theme_cowplot())


# log
LOG <- snakemake@log$log1
cat("creating new log file\n", file = LOG, append = F)


# read parameter estimation file
data <- readRDS(snakemake@input$estimations)


# get the variable parameter
# log
cat(
  "for this plotting we make the assumption that only param_3 (t_sigma) is varying\n",
  file = LOG,
  append = T
)
cat(
  "script will run also otherwise, but plots won't be useful\n",
  file = LOG,
  append = T
)


# define interquantile ranges
interquantile.ranges <-
  sort(unique(as.numeric(as.character(
    snakemake@params$iqr
  ))))


# create color palette
mcol <- wes_palette("Darjeeling1")[1:2]
mcol <- c(mcol[1], mcol[2])
colfunc <- colorRampPalette(mcol)
mcol <- colfunc(length(interquantile.ranges))
if (0 %in% interquantile.ranges) {
  mcol <- c(mcol, rev(mcol)[2:length(mcol)])
} else {
  mcol <- c(mcol, rev(mcol))
}



# helper functions
calc_quants <-
  function(posterior.data.frame,
           interquantile.ranges = c(0.99, 0.95, 0.9, 0.8, 0.5, 0.25, 0.1, 0)) {
    # Computes the quantiles row wise
    #
    # Args:
    #   posterior.data.frame: data.frame; should contain a single explicit
    #     distribution per row
    #   interquantile.ranges: numeric vector; which interquantiles to calculate,
    #     e. g. 0.99 calculates the 0.005-th and 0.995-th quantiles; e. g. 0
    #     calculates the median (the 0.5 quantile)
    #
    # Returns:
    #   data.frame with quantiles
    iqr <- sort(interquantile.ranges)
    stopifnot(all(iqr >= 0))

    qr <- sort(unlist(sapply(iqr, function(x) {
      return(unique(c(0.5 - x / 2, 0.5 + x / 2)))
    })))

    apply.f <- function(x) {
      quantile(x, probs = qr, na.rm = T)
    }

    return(t(apply(posterior.data.frame, 1, apply.f)))
  }
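# e.g. calc_quants(post, interquantile.ranges = c(0, 0.5)) returns, per
# row, the 25%, 50% and 75% quantiles (0 contributes the median once)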
logbreak <- sapply(10 ** (-10:10), function(x) {
  x * (1:10)
}) %>% unique() %>% sort()
loglabel <- sapply(logbreak, function(x) {
  if (x %in% 10 ** (-10:10))
    return(as.character(log10(x)))
  else
    return("")
})


# calculate quantiles of posteriors; remove accepted posteriors
data.quantiles <-
  calc_quants(data %>% select(starts_with("acc_")), interquantile.ranges = interquantile.ranges)
data.index <- data %>% select(!starts_with("acc_"))
data <- cbind.data.frame(data.index,
                         data.quantiles) %>% tibble()


# find tsigma per podid
tsigma <- snakemake@params$tsigma_per_podid %>% as.numeric()
data$tsigma <- sapply(data$podid, function(x)
  return(tsigma[x]))


# create empty plot list
plot_list <- list()
plot_index <- 0


# create plots
msize <- 2


for (a in unique(data$param)) {
  for (b in unique(data$statcomposition)) {
    for (c in unique(data$pls)) {
      for (d in unique(data$tol)) {
        for (e in unique(data$regression)) {
          sdf <- data %>%
            subset(param == a) %>%
            subset(statcomposition == b) %>%
            subset(pls == c) %>%
            subset(tol == d) %>%
            subset(regression == e)

          if (nrow(sdf) > 0) {
            sdf <- sdf %>%
              pivot_longer(
                -c(
                  param,
                  true_value,
                  statcomposition,
                  pls,
                  tol,
                  regression,
                  podid,
                  mean,
                  mode,
                  median,
                  tsigma
                ),
                names_to = "quantile",
                values_to = "value"
              ) %>%
              group_by(
                param,
                true_value,
                statcomposition,
                pls,
                tol,
                regression,
                podid,
                quantile,
                tsigma
              ) %>%
              summarise(mean.iqr = mean(value))

            sdf$quantile <- as.numeric(sub("%", "", sdf$quantile))
            mlevels <- sort(unique(sdf$quantile))
            sdf$quantile <- factor(sdf$quantile, levels = mlevels)

            mxlim <- range(sdf$tsigma)
            mylim <- c(NA, NA)

            plot_index <- plot_index + 1
            plot_list[[plot_index]] <- sdf %>%
              ggplot(aes(
                x = tsigma,
                y = mean.iqr,
                col = quantile
              )) +
              geom_line(aes(tsigma, true_value),
                        col = "black",
                        size = msize) +
              geom_line(show.legend = F, size = msize) +
              # geom_point(position = position_jitter(height = 0, width = 0.01), alpha=0.02, show.legend = FALSE)+
              # geom_text() %>%
              # facet_grid(plsComp ~ sscomp) +
              scale_x_continuous(
                trans = "log10",
                limits = mxlim,
                breaks = logbreak,
                labels = loglabel
              ) +
              scale_color_manual(values = mcol) +
              theme(
                legend.position = "none",
                aspect.ratio = 1,
                panel.border = element_rect(colour = "black", size = msize),
                # text = element_text(size = 12),
                #   strip.background = element_blank(),
                # axis.title.y = element_blank(),
                # axis.text = element_blank(),
                # axis.text.y = element_blank(),
                axis.line = element_blank(),
                axis.title = element_blank(),
                axis.ticks = element_line(size = msize * 0.8),
                # axis.ticks.y = element_blank(),
                # axis.ticks.x = element_line(size = 1),
                # axis.ticks.length.x = unit(-1, "lines"),
                # panel.spacing = unit(-0.9, "lines")
              ) +
              labs(
                title = paste0(
                  "param: ",
                  a,
                  "\nstatcomposition: ",
                  b,
                  "\npls: ",
                  c,
                  "\ntol: ",
                  d,
                  "\nregression_method: ",
                  e,
                  collapse = ""
                )
              )

            # make log scale for some parameters
            if (a %in% c("param_0", "param_3"))
              plot_list[[plot_index]] <- plot_list[[plot_index]] +
              scale_y_continuous(
                trans = "log10",
                limits = mylim,
                breaks = logbreak,
                labels = loglabel
              )
          }
        }
      }
    }
  }
}

# log
cat("created plots\n",
    file = LOG,
    append = T)


# log
cat("start printing to file\n", file = LOG, append = T)


# print to file
pdf(snakemake@output$pdf)
for (my_plot in 1:length(plot_list)) {
  show(plot_list[[my_plot]])

  # log
  cat(
    "successfully printed plot to file:",
    my_plot,
    "of",
    length(plot_list),
    "\n",
    file = LOG,
    append = T
  )
}
dev.off()

# log
cat("successfully printed plots to file\n",
    file = LOG,
    append = T)
library(ggplot2)
library(cowplot)
library(wesanderson)
library(tidyverse)


# theme set
theme_set(theme_cowplot())
logbreak <- sapply(10 ** (-10:10), function(x) {
  x * (1:10)
}) %>% unique() %>% sort()
loglabel <- sapply(logbreak, function(x) {
  if (x %in% 10 ** (-10:10))
    return(as.character(log10(x)))
  else
    return("")
})


# small functions
find_max_pls  <- function(my_cols) {
  colslist <- strsplit(my_cols, "LinearCombination_")
  plsno <- integer(length = length(colslist))
  for (colid in 1:length(colslist)) {
    plsno[colid] <- as.numeric(colslist[[colid]][2])
  }
  return(max(plsno))
}
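# e.g. find_max_pls(c("LinearCombination_0", "LinearCombination_7"))
# returns 7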


# log
LOG <- snakemake@log$log1
cat("creating new log file\n", file = LOG, append = F)


# read data; pseudo-observed data sets
pod.stats <- list()
for (infileid in 1:length(snakemake@input$transformed)) {
  infile.name <- snakemake@input$transformed[infileid]
  df_raw <- read.csv(infile.name, sep = "\t")


  # select all columns that aren't stats and add identifiers
  nostat <- df_raw %>%
    select(!starts_with("LinearCombination_"))
  split <-
    strsplit(infile.name, split = "_|\\.|/")[[1]]  # read wildcards from the file name
  nostat$statcomposition <-
    as.numeric(split[which(split == "statcomp") + 1])

  # add pod index
  nostat$podid <-
    as.numeric(interaction(nostat[, grep("^param_", colnames(nostat))]))


  # collect all cols that are stats and reduce or extend to 20 columns
  stats <- df_raw %>%
    select(starts_with("LinearCombination_"))
  if (ncol(stats) > 20) {
    wanted_cols <- paste("LinearCombination", 1:20 - 1, sep = "_")
    colindexes <- integer(length = 20)
    for (mycolid in 1:20) {
      mycolname <- wanted_cols[mycolid]
      colindexes[mycolid] <- which(mycolname == colnames(stats))
    }
    stats <- stats[, colindexes]

    # log
    cat("read data and reduced stats to 20 columns\n",
        file = LOG,
        append = T)

  } else if (ncol(stats) < 20) {
    max_pls <- find_max_pls(colnames(stats))
    extender <-
      setNames(data.frame(# LinearCombinations are zero-based
        matrix(
          ncol = 20 - max_pls - 1, nrow = nrow(stats)
        )),
        paste("LinearCombination", (max_pls + 1):(20 - 1), sep = "_"))
    stats <- cbind.data.frame(stats, extender)

    # log
    cat("read data and extended stats to 20 columns\n",
        file = LOG,
        append = T)
  }
  stopifnot(ncol(stats) == 20)


  # read into list
  pod.stats[[infileid]] <- cbind.data.frame(nostat, stats)
}


# log
cat("read data\n", file = LOG, append = T)
cat("note, pod indexes may differ from config file\n",
    file = LOG,
    append = T)


# put into a single tibble
df <- do.call(rbind.data.frame, pod.stats) %>% tibble()



# log
cat("start plotting\n", file = LOG, append = T)


# collect plots in list
plot_list <- list()
plot.id <- 0

# parameters
plot.id <- plot.id + 1
plot_list[[plot.id]] <- df %>%
  select(starts_with("param_"),
         starts_with("podid")) %>%
  pivot_longer(cols = -c("podid"),
               names_to = "param",
               values_to = "value") %>%
  ggplot(aes(podid, value, fill = param)) +
  geom_point(shape = 23) +
  facet_grid(param ~ ., scales = "free") +
  scale_y_continuous(trans = "log",
                     breaks = logbreak,
                     labels = round(log10(logbreak), 1)) +
  scale_fill_manual(values = wes_palette("Darjeeling1")) +
  theme(
    aspect.ratio = 0.707 ,
    legend.position = "none",
    panel.background = element_rect(fill = "gray99"),
    strip.background = element_blank()
  )

# log
cat("plotted parameters\n", file = LOG, append = T)


# param sumstat correlation
for (pls_name in colnames(df)[grep("^LinearCombination_", colnames(df))]) {
    plot.id <- plot.id + 1
    plot_list[[plot.id]] <- df %>%
      select(!starts_with("LinearCombination_"),
             starts_with(all_of(pls_name))) %>%
      rename(LinearCombination = pls_name) %>%
      pivot_longer(cols = colnames(df)[grep("^param_", colnames(df))],
                   names_to = "param",
                   values_to = "value") %>%
      ggplot(aes(podid, LinearCombination, fill = as.factor(statcomposition))) +
      geom_point(position = "jitter", shape = 23) +
      facet_grid(statcomposition ~ ., scales = "free") +
      scale_fill_manual(values = wes_palette("Darjeeling2")) +
      labs(title = paste(pls_name)) +
      theme(aspect.ratio = 0.707 / 2,
            legend.position = "none")


    # log
    cat("plotted",
        pls_name,
        "\n",
        file = LOG,
        append = T)
  }


# log
cat("start printing to file\n", file = LOG, append = T)


# print to file
pdf(snakemake@output$pdf)
for (my_plot in 1:length(plot_list)) {
  show(plot_list[[my_plot]])

  # log
  cat(
    "successfully printed plot to file:",
    my_plot,
    "of",
    length(plot_list),
    "\n",
    file = LOG,
    append = T
  )
}
dev.off()

# log
cat("successfully printed plots to file\n",
    file = LOG,
    append = T)
Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/sstruett/tsABC
Name: tsabc
Version: 1
Copyright: Public Domain
License: None