StratoMod: Quantifying the Difficulty of Variant Calling in Genomic Context

Version: v8.0.4

A model-based tool to quantify the difficulty of calling a variant given genomic context.

Background

Intuitively we understand that accurately calling variants in a genome can be more or less difficult depending on the context of that variant. For example, many sequencing technologies have higher error rates in homopolymers, and this error rate generally increases as homopolymers get longer. However, precisely quantifying the relationship between these errors, the length of the homopolymer, and the impact on the resulting variant call remains challenging. Analogous arguments can be drawn for other "repetitive" regions in the genome, such as tandem repeats, segmental duplications, transposable elements, and difficult-to-map regions.

The solution we present here is to use an interpretable modeling framework called explainable boosting machines (EBMs) to predict variant calling errors as a function of genomic features (e.g., whether or not the variant is in a tandem repeat, homopolymer, etc). The interpretability of the model is important for allowing end users to understand the relationship each feature has to the prediction, which facilitates understanding (for example) at what homopolymer lengths the likelihood of incorrectly calling a variant increases drastically. This precision is an improvement over existing methods we have developed for stratifying the genome by difficulty into discrete bins. Furthermore, this modeling framework captures interactions between different genomic contexts, which is important because many repetitive characteristics do not occur in isolation.
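
As a rough illustration of the modeling approach (this is not StratoMod's actual training code; the feature names and data below are made up), an EBM can be fit and interrogated with the interpret library as follows:

import numpy as np
from interpret.glassbox import ExplainableBoostingClassifier

# fake data: two "genomic context" features and a binary error label
rng = np.random.default_rng(0)
X = np.column_stack([
    rng.integers(0, 30, 1000),   # e.g. homopolymer length
    rng.random(1000),            # e.g. fraction of a tandem repeat covered
])
y = (X[:, 0] > 15).astype(int)   # pretend long homopolymers cause errors

ebm = ExplainableBoostingClassifier(
    feature_names=["homopolymer_length", "tandem_repeat_frac"],
    interactions=1,              # also learn pairwise interaction terms
).fit(X, y)

# per-feature shape functions show how the predicted error likelihood
# changes across each feature's range (and across interaction pairs)
explanation = ebm.explain_global()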

We anticipate StratoMod will be useful for both method developers and clinicians who wish to better understand variant calling error modalities. In the case of method development, StratoMod can be used to accurately compare the error modalities of different sequencing technologies. For clinicians, it can be used to determine the regions or genes (which may be clinically interesting for a given study) in which variant calling errors are likely to occur, which in turn may inform which technologies and/or other mitigation strategies should be employed.

Further information can be found in our preprint.

User Guide

Pipeline steps

  1. Compare the user-supplied query vcf with a GIAB benchmark vcf to produce labels (true positive, false positive, false negative). These labels comprise the dependent variable used in model training downstream.

  2. Intersect the labeled comparison output with genomic features to produce the features (independent variables) used in model training.

  3. Train the EBM model with a random holdout for testing.

  4. If desired, test the model on other query vcfs (which may or may not also be labeled via a benchmark comparison).

  5. Inspect the output features (plots showing the profile of each feature and its effect on the label).

NOTE: currently only two labels can be compared at once since we use a binary classifier. This means either one of the three labels must be omitted or two must be combined into a single label (as sketched below).
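
As an illustrative sketch of this constraint (column and label names here are assumptions for the example, not values taken from the pipeline), collapsing three labels to two might look like:

import pandas as pd

# toy labels produced by the benchmark comparison
df = pd.DataFrame({"label": ["tp", "fp", "fn", "tp", "fn"]})

# option 1: omit one label entirely (e.g. keep only tp vs fp)
tp_vs_fp = df[df["label"].isin(["tp", "fp"])]

# option 2: combine two labels into one class (e.g. fp and fn are both "errors")
tp_vs_error = df.assign(error=df["label"].isin(["fp", "fn"]).astype(int))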

Data Inputs

The only mandatory user-supplied input is a query vcf. Optionally, one can supply other vcfs for testing the model.

Unless one is using esoteric references or benchmarks, the pipeline is preconfigured to retrieve commonly-used data defined by flags in the configuration. This includes:

  • a GIAB benchmark, including the vcf, bed, and reference fasta

  • reference-specific bed files which will provide "contextual" features for each variant call, including:

    • difficult-to-map regions (GIAB stratification bed file)

    • segmental duplications (UCSC superdups database)

    • tandem repeats (UCSC simple repeats database)

    • transposable elements (UCSC Repeat Masker)

Installation

This assumes the user has a working conda or mamba installation.

Run the following to set up the runtime environment.

mamba env create -f env.yml

Configuration

A sample configuration file can be found at config/dynamic-testing.yml; it may be copied as a starting point and modified to one's liking. This file is heavily annotated to explain all the options/flags and their purpose.

For a list of features which may be used, see FEATURES.md.

Running

Execute the pipeline using snakemake:

snakemake --use-conda -c <num_cores> --rerun-incomplete --configfile=config/<confname.yml> all

Output

Report

Each model has a report at results/model/<model_key>-<filter_key>-<run_key>/summary.html which contains model performance curves and feature plots (the latter of which allow model interpretation).

Here <model_key> is the key under the models section in the config, <filter_key> is either SNV or INDEL depending on what was requested, and <run_key> is the key under the models -> <model_key> -> runs section in the config.
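
For example, a rough sketch for listing all reports produced by a pipeline run (the directory layout is as described above; this assumes the keys themselves contain no hyphens):

from pathlib import Path

# each run directory is named <model_key>-<filter_key>-<run_key>
for report in Path("results/model").glob("*/summary.html"):
    model_key, filter_key, run_key = report.parent.name.split("-")
    print(model_key, filter_key, run_key, report)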

Train/test data

All raw data for the models will be saved alongside the model report (see above). This includes the input tsv used to train the EBM, a config yml file recording all settings used to train the EBM (for reference), and python pickles for the X/Y train/test datasets as well as a pickle for the final model itself.

Within the run directory will also be a test directory which will contain all test runs (e.g. the results of the model test and the input data used for the test).
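
As a rough sketch (the exact artifact file names below are assumptions; check the run directory for the actual names), these pickled artifacts can be reloaded for ad-hoc analysis:

import pickle

run_dir = "results/model/mymodel-SNV-run1"  # placeholder keys

# reload the trained EBM and its held-out test features
with open(f"{run_dir}/model.pickle", "rb") as f:
    ebm = pickle.load(f)
with open(f"{run_dir}/test_x.pickle", "rb") as f:
    X_test = pickle.load(f)

# class-1 probabilities (i.e. predicted error likelihood) for the test set;
# index/coordinate columns may need to be dropped from X_test first
probs = ebm.predict_proba(X_test)[:, 1]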

Raw input data

In addition to the model data itself, the raw input data (that is, the master dataframe with all features for each query vcf prior to filtering/transformation) can be found in results/annotated/{unlabeled,labeled}/<query_key> where <query_key> is the key under either labeled_queries or unlabeled_queries in the config.

Each of these directories contains the raw dataframe itself (for both SNVs and INDELs) as well as an HTML report summarizing the dataframe (statistics for each feature, distributions, correlations, etc).
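
For quick interactive inspection, a minimal sketch for loading one of these dataframes with pandas (the query key and the file name inside the directory are placeholders; the actual names depend on the config and pipeline version):

import pandas as pd

df = pd.read_table("results/annotated/labeled/my_query/SNV.tsv")

print(df.columns.tolist())  # all genomic features (plus the label, if labeled)
print(df.describe())        # rough equivalent of the summary in the HTML report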

Developer Guide

Environments

By convention, the conda environment specified by env.yml only has runtime dependencies for the pipeline itself.

To install development environments, run the following:

./setup_dev.sh

In addition to creating new environments, this script will update existing ones if they are changed during development.

Note that scripts in the pipeline are segregated by environment in order to prevent dependency hell while maintaining reproducible builds. When editing, one will need to switch between environments in the IDE in order to benefit from the features they provide. Further details on which environments correspond to which files can be found in workflow/scripts.

Note that this will only install environments necessary for running scripts (e.g. rules with a script directive).

Linting

All python code should be error free when finalizing any new features. Linting will be performed automatically as part of the CI/CD pipeline, but to run it manually, invoke the following:

./lint.sh

This assumes all development environments are installed (see above).

New Feature Workflow

There are two main development branches: master and develop.

Make a new branch off of develop for the new feature, then merge it into develop when done (note the --no-ff flag).

git checkout develop
git branch <new_feature>
git checkout <new_feature>
# do a bunch of stuff...
git checkout develop
git merge --no-ff <new_feature>

After the feature(s) have been added and all tests have succeeded, update the changelog, add a tag, and merge into master. Use semantic versioning for tags.

# update changelog
vim CHANGELOG.md
git add CHANGELOG.md
git commit
git tag vX.Y.Z
git checkout master
git merge --no-ff vX.Y.Z

NOTE: do not add an experiment-specific configuration to master or develop. The yml files in config for these branches are used for testing. See below for how to add an experiment.

Code Snippets

shell:
    "curl -sS -L -o {output} {params.url}"
shell:
    """
    mkdir {output} && \
    tar xzf {input} --directory {output} --strip-components=1
    """
shell:
    "make -C {input} > {log} && mv {input}/repseq {output}"
shell:
    """
    gunzip -c {input.ref} | \
    {input.bin} 1 4 - 2> {log} | \
    sed '/^#/d' | \
    gzip -c > {output}
    """
script:
    "../../scripts/python/bio/get_homopoly_features.py"
script:
    "../../scripts/python/bio/download_bed_or_vcf.py"
script:
    "../../scripts/python/bio/get_mappability_features.py"
script:
    "../../scripts/python/bio/get_repeat_masker_features.py"
script:
    "../../scripts/python/bio/get_segdup_features.py"
script:
    "../../scripts/python/bio/get_tandem_repeat_features.py"
script:
    "../scripts/python/bio/download_ref.py"
script:
    "../scripts/python/bio/filter_sort_ref.py"
shell:
    """
    samtools faidx {input} -o - 2> {log} | \
    cut -f1,2 > \
    {output}
    """
shell:
    "rtg format -o {output} {input} 2>&1 > {log}"
script:
    "../scripts/python/bio/write_bed.py"
From line 123 of rules/inputs.smk
script:
    "../scripts/python/bio/download_bed_or_vcf.py"
From line 141 of rules/inputs.smk
script:
    "../scripts/python/bio/standardize_vcf.py"
From line 164 of rules/inputs.smk
shell:
    "tabix -p vcf {input}"
script:
    "../scripts/python/bio/standardize_bed.py"
From line 247 of rules/inputs.smk
shell:
    """
    bedtools subtract -a {input.bed} -b {input.mhc} | \
    bgzip > {output}
    """
shell:
    """
    rm -rf {params.tmp_dir} && \

    rtg RTG_MEM=$(({resources.mem_mb}*80/100))M \
    vcfeval {params.extra} \
    --threads={threads} \
    -b {input.bench_vcf} \
    -e {input.bench_bed} \
    -c {input.query_vcf} \
    -o {params.tmp_dir} \
    -t {input.sdf} > {log} 2>&1 && \

    mv {params.tmp_dir}/* {params.output_dir} && \

    rm -r {params.tmp_dir}
    """
script:
    "../scripts/python/bio/vcf_to_bed.py"
From line 353 of rules/inputs.smk
script:
    "../scripts/python/bio/concat_tsv.py"
From line 376 of rules/inputs.smk
script:
    "../scripts/python/bio/annotate_variants.py"
script:
    "../scripts/rmarkdown/summary/input_summary.Rmd"
script:
    "../scripts/python/bio/prepare_train.py"
script:
    "../scripts/python/ebm/train_ebm.py"
script:
    "../scripts/python/ebm/decompose_model.py"
script:
    "../scripts/rmarkdown/summary/train_summary.Rmd"
script:
    "../scripts/python/bio/prepare_test.py"
script:
    "../scripts/python/ebm/test_ebm.py"
script:
    "../scripts/rmarkdown/summary/test_summary.Rmd"
from functools import reduce
import pandas as pd
from typing import Any, cast
import numpy as np
from common.tsv import write_tsv
from pybedtools import BedTool as bt  # type: ignore
from common.io import setup_logging
import common.config as cfg

logger = setup_logging(snakemake.log[0])  # type: ignore


def left_outer_intersect(left: pd.DataFrame, path: str) -> pd.DataFrame:
    logger.info("Adding annotations from %s", path)

    # Use bedtools to perform left-outer join of two bed/tsv files. Since
    # bedtools will join all columns from the two input files, keep track of the
    # width of the left input file so that the first three columns of the right
    # input (chr, chrStart, chrEnd, which are redundant) can be dropped.
    left_cols = left.columns.tolist()
    left_width = len(left_cols)
    right = pd.read_table(path)
    # ASSUME the first three columns are the bed index columns
    right_cols = ["_" + c if i < 3 else c for i, c in enumerate(right.columns.tolist())]
    right_bed = bt.from_dataframe(right)
    # prevent weird type errors when converted back to dataframe from bed
    dtypes = {right_cols[0]: str}
    # convert "." to NaN since "." is a string/object which will make pandas run
    # slower than an actual panda
    na_vals = {c: "." for c in left_cols + right_cols[3:]}
    new_df = cast(
        pd.DataFrame,
        bt.from_dataframe(left)
        .intersect(right_bed, loj=True)
        .to_dataframe(names=left_cols + right_cols, na_values=na_vals, dtype=dtypes),
    )
    # Bedtools intersect will use -1 for NULL in the case of numeric columns. I
    # suppose this makes sense since any "real" bed columns (according to the
    # "spec") will always be positive integers or strings. Since -1 might be a
    # real value and not a missing one in my case, use the chr field to figure
    # out if a row is "missing" and fill NaNs accordingly
    new_cols = new_df.columns[left_width:]
    new_pky = new_cols[:3]
    new_chr = new_pky[0]
    new_data_cols = new_cols[3:]
    new_df.loc[:, new_data_cols] = new_df[new_data_cols].where(
        new_df[new_chr] != ".", np.nan
    )

    logger.info("Annotations added: %s\n", ", ".join(new_data_cols.tolist()))

    return new_df.drop(columns=new_pky)


def intersect_tsvs(
    config: cfg.StratoMod,
    ifile: str,
    ofile: str,
    tsv_paths: list[str],
) -> None:
    target_df = pd.read_table(ifile)
    new_df = reduce(left_outer_intersect, tsv_paths, target_df)
    new_df.insert(loc=0, column=cfg.VAR_IDX, value=new_df.index)
    write_tsv(ofile, new_df)


def main(smk: Any, config: cfg.StratoMod) -> None:
    fs = smk.input.features
    vcf = smk.input.variants[0]
    logger.info("Adding annotations to %s\n", vcf)
    intersect_tsvs(config, vcf, smk.output[0], fs)


main(snakemake, snakemake.config)  # type: ignore
import pandas as pd
from common.tsv import write_tsv
from common.bed import sort_bed_numerically

# use pandas here since it will more reliably account for headers
df = pd.concat([pd.read_table(i, header=0) for i in snakemake.input])  # type: ignore
write_tsv(snakemake.output[0], sort_bed_numerically(df, 3))  # type: ignore
from pathlib import Path
import subprocess as sp
from typing import Callable
from typing_extensions import assert_never
from tempfile import NamedTemporaryFile as Tmp
import common.config as cfg
from common.io import get_md5, is_gzip, setup_logging

# hacky curl/gzip wrapper; this exists because I got tired of writing
# specialized rules to convert gzip/nozip files to bgzip and back :/
# Solution: force bgzip for references and gzip for bed

log = setup_logging(snakemake.log[0])  # type: ignore

GZIP = ["gzip", "-c"]
CURL = ["curl", "-Ss", "-L", "-q"]


def main(opath: Path, src: cfg.FileSrc | None) -> None:
    if isinstance(src, cfg.LocalSrc):
        # ASSUME these are already tested via the pydantic class for the
        # proper file format
        Path(opath).symlink_to(Path(src.filepath).resolve())

    elif isinstance(src, cfg.HTTPSrc):
        curlcmd = [*CURL, src.url]

        # to test the format of downloaded files, sample the first 65000 bytes
        # (which should be enough to get one block of a bgzip file, which will
        # allow us to test for it)
        curltestcmd = [*CURL, "-r", "0-65000", src.url]

        with open(opath, "wb") as f, Tmp() as tf:

            def curl() -> None:
                sp.Popen(curlcmd, stdout=f).wait()

            def curl_test(testfun: Callable[[Path], bool]) -> bool:
                sp.Popen(curltestcmd, stdout=tf).wait()
                return testfun(Path(tf.name))

            def curl_gzip(cmd: list[str]) -> None:
                p1 = sp.Popen(curlcmd, stdout=sp.PIPE)
                p2 = sp.Popen(cmd, stdin=p1.stdout, stdout=f)
                p2.wait()

            if curl_test(is_gzip):
                curl()
            else:
                curl_gzip(GZIP)

    elif src is None:
        assert False, "file src is null; this should not happen"
    else:
        assert_never(src)

    if src.md5 is not None and src.md5 != (actual := get_md5(opath)):
        log.error("md5s don't match; wanted %s, actual %s", src.md5, actual)
        exit(1)


main(Path(snakemake.output[0]), snakemake.params.src)  # type: ignore
from pathlib import Path
import subprocess as sp
from typing import Callable, Any, cast
from typing_extensions import assert_never
from tempfile import NamedTemporaryFile as Tmp
from common.io import is_gzip, setup_logging, get_md5, get_md5_dir
from common.bed import is_bgzip
import common.config as cfg


GUNZIP = ["gunzip", "-c"]
BGZIP = ["bgzip", "-c"]
CURL = ["curl", "-Ss", "-L", "-q"]

log = setup_logging(snakemake.log[0])  # type: ignore


def main(smk: Any, params: Any) -> None:
    src = cast(cfg.FileSrc, params.src)
    opath = Path(smk.output[0])
    is_fasta = smk.params.is_fasta

    if isinstance(src, cfg.LocalSrc):
        # ASSUME this is in the format we indicate (TODO be more paranoid)
        opath.symlink_to(Path(src.filepath).resolve())

    elif isinstance(src, cfg.HTTPSrc):
        curlcmd = [*CURL, src.url]

        if is_fasta:
            # to test the format of downloaded files, sample the first 65000 bytes
            # (which should be enough to get one block of a bgzip file, which will
            # allow us to test for it)
            curltestcmd = [*CURL, "-r", "0-65000", src.url]

            with open(opath, "wb") as f, Tmp() as tf:

                def curl() -> None:
                    sp.Popen(curlcmd, stdout=f).wait()

                def curl_test(testfun: Callable[[Path], bool]) -> bool:
                    sp.Popen(curltestcmd, stdout=tf).wait()
                    return testfun(Path(tf.name))

                def curl_gzip(cmd: list[str]) -> None:
                    p1 = sp.Popen(curlcmd, stdout=sp.PIPE)
                    p2 = sp.Popen(cmd, stdin=p1.stdout, stdout=f)
                    p2.wait()

                if curl_test(is_bgzip):
                    curl()
                elif curl_test(is_gzip):
                    p1 = sp.Popen(curlcmd, stdout=sp.PIPE)
                    p2 = sp.Popen(GUNZIP, stdin=p1.stdout, stdout=sp.PIPE)
                    p3 = sp.Popen(BGZIP, stdin=p2.stdout, stdout=f)
                    p3.wait()
                else:
                    curl_gzip(BGZIP)

        else:
            tarcmd = [
                *["bsdtar", "-xf", "-"],
                *["--directory", str(opath)],
                "--strip-component=1",
            ]

            opath.mkdir(parents=True)

            p1 = sp.Popen(curlcmd, stdout=sp.PIPE)
            p2 = sp.Popen(tarcmd, stdin=p1.stdout)
            p2.wait()

    else:
        assert_never(src)

    if src.md5 is not None:
        actual = get_md5(opath) if is_fasta else get_md5_dir(opath)
        if actual != src.md5:
            log.error("md5s don't match; wanted %s, actual %s", src.md5, actual)
            exit(1)


main(snakemake, snakemake.params)  # type: ignore
import re
from typing import Any
import subprocess as sp
import common.config as cfg
from Bio import bgzf  # type: ignore
from common.io import setup_logging

logger = setup_logging(snakemake.log[0])  # type: ignore


def stream_fasta(ipath: str, chr_names: list[str]) -> sp.Popen[bytes]:
    return sp.Popen(
        ["samtools", "faidx", ipath, *chr_names],
        stdout=sp.PIPE,
        stderr=sp.PIPE,
    )


def stream_sdf(ipath: str, chr_names: list[str]) -> sp.Popen[bytes]:
    return sp.Popen(
        [
            *["rtg", "sdf2fasta", "--no-gzip", "--line-length=70"],
            *["--input", ipath],
            *["--output", "-"],
            *["--names", *chr_names],
        ],
        stdout=sp.PIPE,
        stderr=sp.PIPE,
    )


def main(smk: Any, sconf: cfg.StratoMod) -> None:
    rsk = cfg.RefsetKey(smk.wildcards["refset_key"])
    cs = sconf.refsetkey_to_chr_indices(rsk)
    prefix = sconf.refsetkey_to_ref(rsk).sdf.chr_prefix

    chr_mapper = {c.chr_name_full(prefix): c.value for c in cs}
    chr_names = [*chr_mapper]

    # Read from a fasta or sdf depending on what we were given; in either
    # case, read only the chromosomes we want in sorted order and return a
    # fasta text stream
    def choose_input(i: Any) -> sp.Popen[bytes]:
        try:
            return stream_fasta(i.fasta[0], chr_names)
        except AttributeError:
            try:
                return stream_sdf(i.sdf[0], chr_names)
            except AttributeError:
                assert False, "unknown input key, this should not happen"

    p = choose_input(smk.input)

    if p.stdout is not None:
        # Stream the fasta and replace the chromosome names in the header with
        # its integer index
        with bgzf.open(smk.output[0], "w") as f:
            for i in p.stdout:
                if i.startswith(b">"):
                    m = re.match(">([^ \n]+)", i.decode())
                    if m is None:
                        logger.error("could get chrom name from FASTA header")
                        exit(1)
                    try:
                        f.write(f">{chr_mapper[m[1]]}\n")
                    except KeyError:
                        assert False, (
                            "could not convert '%s' to index, this should not happen"
                            % m[1]
                        )
                else:
                    f.write(i)
    else:
        assert False, "stdout not a pipe, this should not happen"

    p.wait()

    if p.returncode != 0:
        logger.error(p.stderr)
        exit(1)


main(snakemake, snakemake.config)  # type: ignore
from pathlib import Path
import pandas as pd
import common.config as cfg
from typing import Any, cast
from pybedtools import BedTool as bt  # type: ignore
from pybedtools import cleanup
from common.tsv import write_tsv
from common.io import setup_logging
from common.bed import read_bed


logger = setup_logging(snakemake.log[0])  # type: ignore

# temporary columns used for dataframe processing
BASE_COL = "_base"
PFCT_LEN_COL = "_perfect_length"

SLOP = 1


def read_input(path: Path) -> pd.DataFrame:
    logger.info("Reading dataframe from %s", path)
    return read_bed(path, more={3: BASE_COL})


def merge_base(
    config: cfg.StratoMod,
    df: pd.DataFrame,
    base: cfg.Base,
    genome: str,
) -> pd.DataFrame:
    logger.info("Filtering bed file for %ss", base)
    _df = df[df[BASE_COL] == f"unit={base.value}"].drop(columns=[BASE_COL])
    logger.info("Merging %s rows for %ss", len(_df), base)
    # Calculate the length of each "pure" homopolymer (eg just "AAAAAAAA").
    # Note that this is summed in the merge below, and the final length based
    # on start/end won't necessarily be this sum because of the -d 1 parameter
    _df[PFCT_LEN_COL] = _df[cfg.BED_END] - _df[cfg.BED_START]
    merged = cast(
        pd.DataFrame,
        bt.from_dataframe(_df)
        .merge(d=1, c=[4], o=["sum"])
        .slop(b=SLOP, g=genome)
        .to_dataframe(names=[*cfg.BED_COLS, PFCT_LEN_COL]),
    )
    # these files are huge; now that we have a dataframe, remove all the bed
    # files from tmpfs to prevent a run on downloadmoreram.com
    cleanup()

    hgroup = config.feature_definitions.homopolymers

    length_col = hgroup.fmt_name(base, lambda x: x.len)
    frac_col = hgroup.fmt_name(base, lambda x: x.imp_frac)

    merged[length_col] = merged[cfg.BED_END] - merged[cfg.BED_START] - SLOP * 2
    merged[frac_col] = 1 - (merged[PFCT_LEN_COL] / merged[length_col])
    return merged.drop(columns=[PFCT_LEN_COL])


def main(smk: Any, config: cfg.StratoMod) -> None:
    # ASSUME this file is already sorted
    simreps = read_input(smk.input["bed"][0])
    merged = merge_base(
        config,
        simreps,
        cfg.Base(smk.wildcards["base"]),
        smk.input["genome"][0],
    )
    write_tsv(smk.output[0], merged, header=True)


main(snakemake, snakemake.config)  # type: ignore
from pathlib import Path
import pandas as pd
import common.config as cfg
from typing import Any
from pybedtools import BedTool as bt  # type: ignore
from common.tsv import write_tsv
from common.bed import read_bed
from common.io import setup_logging

logger = setup_logging(snakemake.log[0])  # type: ignore


def main(smk: Any, config: cfg.StratoMod) -> None:
    rsk = cfg.RefsetKey(smk.wildcards["refset_key"])
    cs = config.refsetkey_to_chr_indices(rsk)
    mapconf = config.refsetkey_to_ref(rsk).feature_data.mappability
    mapmeta = config.feature_definitions.mappability

    def read_map_bed(p: Path, ps: cfg.BedFileParams, col: str) -> pd.DataFrame:
        logger.info("Reading mappability feature: %s", col)
        df = read_bed(p, ps, {}, cs)
        df[col] = 1
        return df

    high = read_map_bed(smk.input["high"][0], mapconf.high.params, mapmeta.high)
    low = read_map_bed(smk.input["low"][0], mapconf.low.params, mapmeta.low)
    # subtract high from low (since the former is a subset of the latter)
    new_low = (
        bt.from_dataframe(low)
        .subtract(bt.from_dataframe(high))
        .to_dataframe(names=low.columns.tolist())
    )
    write_tsv(smk.output["high"], high)
    write_tsv(smk.output["low"], new_low)


main(snakemake, snakemake.config)  # type: ignore
import pandas as pd
from pathlib import Path
from typing import Optional, Any
import common.config as cfg
from pybedtools import BedTool as bt  # type: ignore
from common.tsv import write_tsv
from common.io import setup_logging
from common.bed import read_bed

# The repeat masker database is documented here:
# https://genome.ucsc.edu/cgi-bin/hgTables?db=hg38&hgta_group=rep&hgta_track=rmsk&hgta_table=rmsk&hgta_doSchema=describe+table+schema

logger = setup_logging(snakemake.log[0])  # type: ignore

# both of these columns are temporary and used to make processing easier
CLASSCOL = "_repClass"
FAMCOL = "_repFamily"


def main(smk: Any, config: cfg.StratoMod) -> None:
    rsk = cfg.RefsetKey(smk.wildcards["refset_key"])
    rk = config.refsetkey_to_refkey(rsk)
    src = config.references[rk].feature_data.repeat_masker
    cs = config.refsetkey_to_chr_indices(rsk)

    def read_rmsk_df(path: Path) -> pd.DataFrame:
        cols = {11: CLASSCOL, 12: FAMCOL}
        return read_bed(path, src.params, cols, cs)

    def merge_and_write_group(
        df: pd.DataFrame,
        groupcol: str,
        clsname: str,
        famname: Optional[str] = None,
    ) -> None:
        groupname = clsname if famname is None else famname
        dropped = df[df[groupcol] == groupname].drop(columns=[groupcol])
        merged = bt.from_dataframe(dropped).merge().to_dataframe(names=cfg.BED_COLS)
        col = config.feature_definitions.repeat_masker.fmt_name(src, clsname, famname)
        merged[col] = merged[cfg.BED_END] - merged[cfg.BED_START]
        write_tsv(smk.output[0], merged, header=True)

    cls = smk.wildcards.rmsk_class
    df = read_rmsk_df(smk.input[0])

    try:
        fam = smk.wildcards.rmsk_family
        logger.info("Filtering/merging rmsk family %s/class %s", fam, cls)
        merge_and_write_group(df, FAMCOL, cls, fam)
    except AttributeError:
        logger.info("Filtering/merging rmsk class %s", cls)
        merge_and_write_group(df, CLASSCOL, cls)


main(snakemake, snakemake.config)  # type: ignore
import pandas as pd
from pathlib import Path
from typing import Any, cast
import common.config as cfg
from common.tsv import write_tsv
from common.bed import read_bed, merge_and_apply_stats
from common.io import setup_logging

# This database is documented here:
# http://genome.ucsc.edu/cgi-bin/hgTables?hgta_doSchemaDb=hg38&hgta_doSchemaTable=genomicSuperDups

# ASSUME segdups dataframe is fed into this script with the chromosome column
# standardized. The column numbers below are dictionary values, and the
# corresponding feature names are the dictionary keys. Note that many feature
# names don't match the original column names in the database.

logger = setup_logging(snakemake.log[0])  # type: ignore


def read_segdups(
    smk: Any,
    config: cfg.StratoMod,
    path: Path,
    fconf: cfg.SegDupsGroup,
) -> pd.DataFrame:
    rsk = cfg.RefsetKey(smk.wildcards["refset_key"])
    rk = config.refsetkey_to_refkey(rsk)
    s = config.references[rk].feature_data.segdups
    ocs = s.other_cols
    feature_cols = {
        ocs.align_L: str(fconf.fmt_col(lambda x: x.alignL)[0]),
        ocs.frac_match_indel: str(fconf.fmt_col(lambda x: x.fracMatchIndel)[0]),
    }
    cs = config.refsetkey_to_chr_indices(rsk)
    return read_bed(path, s.params, feature_cols, cs)


def merge_segdups(
    df: pd.DataFrame,
    fconf: cfg.SegDupsGroup,
) -> pd.DataFrame:
    bed, names = merge_and_apply_stats(fconf, df)
    return cast(pd.DataFrame, bed.to_dataframe(names=names))


def main(smk: Any, config: cfg.StratoMod) -> None:
    fconf = config.feature_definitions.segdups
    repeat_df = read_segdups(smk, config, smk.input[0], fconf)
    merged_df = merge_segdups(repeat_df, fconf)
    write_tsv(smk.output[0], merged_df, header=True)


# TODO make a stub so I don't need to keep repeating this
main(snakemake, snakemake.config)  # type: ignore
from pathlib import Path
import pandas as pd
from typing import Any, cast
import common.config as cfg
from common.tsv import write_tsv
from common.bed import read_bed, merge_and_apply_stats
from common.io import setup_logging

# Input dataframe documented here:
# https://genome.ucsc.edu/cgi-bin/hgTables?db=hg38&hgta_group=rep&hgta_track=simpleRepeat&hgta_table=simpleRepeat&hgta_doSchema=describe+table+schema
#
# ASSUME this dataframe is fed into this script as-is. The column numbers below
# are dictionary values, and the corresponding feature names are the dictionary
# keys. Note that many feature names don't match the original column names in
# the database.

logger = setup_logging(snakemake.log[0])  # type: ignore

SLOP = 5


def read_tandem_repeats(
    smk: Any,
    path: Path,
    fconf: cfg.TandemRepeatGroup,
    sconf: cfg.StratoMod,
) -> pd.DataFrame:
    rsk = cfg.RefsetKey(smk.wildcards["refset_key"])
    rk = sconf.refsetkey_to_refkey(rsk)
    ss = sconf.references[rk].feature_data.tandem_repeats
    ocs = ss.other_cols
    fmt_col = fconf.fmt_col
    perc_a_col = str(fconf.A[0])
    perc_t_col = str(fconf.T[0])
    perc_c_col = str(fconf.C[0])
    perc_g_col = str(fconf.G[0])
    unit_size_col = fmt_col(lambda x: x.period)[0]
    feature_cols = {
        ocs.period: unit_size_col,
        ocs.copy_num: fmt_col(lambda x: x.copyNum)[0],
        ocs.per_match: fmt_col(lambda x: x.perMatch)[0],
        ocs.per_indel: fmt_col(lambda x: x.perIndel)[0],
        ocs.score: fmt_col(lambda x: x.score)[0],
        ocs.per_A: perc_a_col,
        ocs.per_C: perc_c_col,
        ocs.per_G: perc_g_col,
        ocs.per_T: perc_t_col,
    }
    cs = sconf.refsetkey_to_chr_indices(rsk)
    df = read_bed(path, ss.params, feature_cols, cs)
    base_groups = [
        (fconf.AT[0], perc_a_col, perc_t_col),
        (fconf.AG[0], perc_a_col, perc_g_col),
        (fconf.CT[0], perc_c_col, perc_t_col),
        (fconf.GC[0], perc_c_col, perc_g_col),
    ]
    for double, single1, single2 in base_groups:
        df[double] = df[single1] + df[single2]
    # Filter out all TRs that have period == 1, since those by definition are
    # homopolymers. NOTE, there is a difference between period and consensusSize
    # in this database; however, it turns out that at least for GRCh38 that the
    # sets of TRs where either == 1 are identical, so just use period here
    # since I can easily refer to it.
    logger.info("Removing TRs with unitsize == 1")
    return df[df[unit_size_col] > 1]


def merge_tandem_repeats(
    gfile: str,
    df: pd.DataFrame,
    fconf: cfg.TandemRepeatGroup,
) -> pd.DataFrame:
    bed, names = merge_and_apply_stats(fconf, df)
    merged_df = cast(pd.DataFrame, bed.slop(b=SLOP, g=gfile).to_dataframe(names=names))
    len_col = fconf.length[0]
    merged_df[len_col] = merged_df[cfg.BED_END] - merged_df[cfg.BED_START] - SLOP * 2
    return merged_df


def main(smk: Any, sconf: cfg.StratoMod) -> None:
    i = smk.input
    fconf = sconf.feature_definitions.tandem_repeats
    repeat_df = read_tandem_repeats(smk, Path(i.src[0]), fconf, sconf)
    merged_df = merge_tandem_repeats(i.genome[0], repeat_df, fconf)
    write_tsv(smk.output[0], merged_df, header=True)


main(snakemake, snakemake.config)  # type: ignore
import pandas as pd
from typing import Any
from common.tsv import write_tsv
from common.io import setup_logging
import common.config as cfg
from common.prepare import process_labeled_data, process_unlabeled_data

logger = setup_logging(snakemake.log[0])  # type: ignore


def write_labeled(
    xpath: str,
    ypath: str,
    sconf: cfg.StratoMod,
    rconf: cfg.Model,
    df: pd.DataFrame,
) -> None:
    filter_col = sconf.feature_definitions.vcf.filter
    label_col = sconf.feature_definitions.label_name
    processed = process_labeled_data(
        rconf.features,
        rconf.error_labels,
        rconf.filtered_are_candidates,
        [cfg.FeatureKey(c) for c in cfg.IDX_COLS],
        filter_col,
        cfg.FeatureKey(label_col),
        df,
    )
    write_tsv(xpath, processed.drop([label_col], axis=1))
    write_tsv(ypath, processed[label_col].to_frame())


def write_unlabeled(
    xpath: str,
    sconf: cfg.StratoMod,
    rconf: cfg.Model,
    df: pd.DataFrame,
) -> None:
    processed = process_unlabeled_data(
        rconf.features,
        [cfg.FeatureKey(c) for c in cfg.IDX_COLS],
        df,
    )
    write_tsv(xpath, processed)


def main(smk: Any, sconf: cfg.StratoMod) -> None:
    sin = smk.input
    sout = smk.output
    wcs = smk.wildcards
    variables = sconf.testkey_to_variables(
        cfg.ModelKey(wcs["model_key"]),
        cfg.TestKey(wcs["test_key"]),
    )
    df = pd.read_table(sin["annotated"][0]).assign(
        **{str(k): v for k, v in variables.items()}
    )
    rconf = sconf.models[cfg.ModelKey(wcs.model_key)]
    if "test_y" in dict(sout):
        write_labeled(
            sout["test_x"],
            sout["test_y"],
            sconf,
            rconf,
            df,
        )
    else:
        write_unlabeled(sout["test_x"], sconf, rconf, df)


main(snakemake, snakemake.config)  # type: ignore
import pandas as pd
import common.config as cfg
from typing import Any
from common.tsv import write_tsv
from common.io import setup_logging
from common.prepare import process_labeled_data

logger = setup_logging(snakemake.log[0])  # type: ignore


def read_query(
    config: cfg.StratoMod, path: str, key: cfg.LabeledQueryKey
) -> pd.DataFrame:
    variables = config.querykey_to_variables(key)
    return pd.read_table(path).assign(**{str(k): v for k, v in variables.items()})


def read_queries(
    config: cfg.StratoMod,
    paths: dict[cfg.LabeledQueryKey, str],
) -> pd.DataFrame:
    # TODO this is weird, why do I need the [0] here?
    return pd.concat([read_query(config, path[0], key) for key, path in paths.items()])


def main(smk: Any, sconf: cfg.StratoMod) -> None:
    rconf = sconf.models[cfg.ModelKey(smk.wildcards.model_key)]
    fconf = sconf.feature_definitions
    raw_df = read_queries(sconf, smk.input)
    processed = process_labeled_data(
        rconf.features,
        rconf.error_labels,
        rconf.filtered_are_candidates,
        [cfg.FeatureKey(c) for c in cfg.IDX_COLS],
        cfg.FeatureKey(fconf.vcf.filter),
        cfg.FeatureKey(fconf.label_name),
        raw_df,
    )
    write_tsv(smk.output["df"], processed)


main(snakemake, snakemake.config)  # type: ignore
from typing import Any, TextIO
from common.config import StratoMod, RefsetKey
from common.io import with_gzip_maybe


def filter_file(smk: Any, config: StratoMod, fi: TextIO, fo: TextIO) -> None:
    rsk = RefsetKey(smk.wildcards["refset_key"])
    chr_prefix = smk.params.chr_prefix
    cs = config.refsetkey_to_chr_indices(rsk)

    chr_mapper = {c.chr_name_full(chr_prefix): c.value for c in cs}

    for ln in fi:
        if ln.startswith("#"):
            fo.write(ln)
        else:
            ls = ln.rstrip().split("\t")
            try:
                ls[0] = str(chr_mapper[ls[0]])
                fo.write("\t".join(ls) + "\n")
            except KeyError:
                pass


def main(smk: Any, config: StratoMod) -> None:
    with_gzip_maybe(
        lambda i, o: filter_file(smk, config, i, o),
        str(smk.input[0]),
        str(smk.output[0]),
    )


main(snakemake, snakemake.config)  # type: ignore
from typing import Any, cast, TextIO
from common.config import StratoMod, RefsetKey, VCFFile
from common.bed import with_bgzip_maybe


def fix_DV_refcall(filter_col: str, sample_col: str) -> str:
    return (
        sample_col.replace("./.", "0/1").replace("0/0", "0/1")
        if filter_col == "RefCall"
        else sample_col
    )


def strip_format_fields(
    fields: set[str],
    format_col: str,
    sample_col: str,
) -> tuple[str, str]:
    f, s = zip(
        *[
            (f, s)
            for f, s in zip(format_col.split(":"), sample_col.split(":"))
            if f not in fields
        ]
    )
    return (":".join(f), ":".join(s))


def filter_file(smk: Any, config: StratoMod, fi: TextIO, fo: TextIO) -> None:
    rsk = RefsetKey(smk.wildcards["refset_key"])
    vcf = cast(VCFFile, smk.params.vcf)
    chr_prefix = vcf.chr_prefix
    cs = config.refsetkey_to_chr_indices(rsk)

    chr_mapper = {c.chr_name_full(chr_prefix): c.value for c in cs}

    for ln in fi:
        if ln.startswith("#"):
            fo.write(ln)
        else:
            ls = ln.rstrip().split("\t")[:10]
            # CHROM = 0
            # POS = 1
            # ID = 2
            # REF = 3
            # ALT = 4
            # QUAL = 5
            # FILTER = 6
            # INFO = 7
            # FORMAT = 8
            # SAMPLE = 9
            try:
                ls[0] = str(chr_mapper[ls[0]])
                if vcf.corrections.fix_refcall_gt:
                    ls[9] = fix_DV_refcall(ls[6], ls[9])
                if len(vcf.corrections.strip_format_fields) > 0:
                    ls[8], ls[9] = strip_format_fields(
                        vcf.corrections.strip_format_fields,
                        ls[8],
                        ls[9],
                    )
                fo.write("\t".join(ls) + "\n")
            except KeyError:
                pass


def main(smk: Any, config: StratoMod) -> None:
    with_bgzip_maybe(
        lambda i, o: filter_file(smk, config, i, o),
        str(smk.input[0]),
        str(smk.output[0]),
    )


main(snakemake, snakemake.config)  # type: ignore
from typing import Any, TextIO
import common.config as cfg
from common.io import with_gzip_maybe, setup_logging

logger = setup_logging(snakemake.log[0])  # type: ignore


def is_real(s: str) -> bool:
    return s.removeprefix("-").replace(".", "", 1).isdigit()


def dot_to_blank(s: str) -> str:
    return "" if s == "." else s


def none_to_blank(s: str | None) -> str:
    return "" if s is None else s


def write_row(
    fo: TextIO,
    chrom: str,
    start: str,
    end: str,
    qual: str,
    filt: str,
    info: str,
    indel_length: str,
    parse_fields: list[str],
    const_fields: list[str],
    label: str | None,
) -> None:
    const_cols = [chrom, start, end, qual, info, filt, indel_length]
    label_col = [] if label is None else [label]
    cols = [*const_cols, *parse_fields, *const_fields, *label_col]
    fo.write("\t".join(cols) + "\n")


def lookup_field(f: cfg.FormatField, d: dict[str, str]) -> str:
    try:
        v = d[f.name]
        if len(f.mapper) == 0:
            return v if is_real(v) else ""
        try:
            return str(f.mapper[v])
        except KeyError:
            return ""
    except KeyError:
        return none_to_blank(f.missing)


def line_to_bed_row(
    fo: TextIO,
    ls: list[str],
    vcf: cfg.UnlabeledVCFQuery,
    vtk: cfg.VartypeKey,
    parse_fields: list[cfg.FormatField],
    const_field_values: list[str],
    label: str | None,
) -> bool:
    # CHROM = 0
    # POS = 1
    # ID = 2
    # REF = 3
    # ALT = 4
    # QUAL = 5
    # FILTER = 6
    # INFO = 7
    # FORMAT = 8
    # SAMPLE = 9

    chrom = int(ls[0])
    start = int(ls[1]) - 1  # bed's are 0-indexed and vcf's are 1-indexed

    # remove cases where ref and alt are equal (which is what "." means)
    if ls[4] == "." or ls[3] == ls[4]:
        logger.info("Skipping equal variant at %s, %s", chrom, start)
        return False

    # remove multiallelics
    if "," in ls[4]:
        logger.info("Skipping multiallelic variant at %s, %s", chrom, start)
        return False

    # remove anything that doesn't pass our length filters
    ref_len = len(ls[3])
    alt_len = len(ls[4])

    if len(ls[3]) > vcf.max_ref or len(ls[4]) > vcf.max_alt:
        logger.info("Skipping oversized variant at %s, %s", chrom, start)
        return False

    # keep only the variant type we care about
    is_snv = ref_len == alt_len == 1

    if is_snv and vtk is cfg.VartypeKey.SNV:
        indel_length = 0
    elif not is_snv and ref_len != alt_len and vtk is cfg.VartypeKey.INDEL:
        indel_length = alt_len - ref_len
    else:
        return False

    # parse the format/sample columns if desired
    if len(parse_fields) > 0:
        fmt_col = ls[8].split(":")
        smpl_col = ls[9].split(":")
        # ASSUME any FORMAT/SAMPLE columns with different lengths are screwed
        # up in some way
        if len(fmt_col) != len(smpl_col):
            logger.error("FORMAT/SAMPLE have different lengths at %s, %s", chrom, start)
            return True
        d = dict(zip(fmt_col, smpl_col))
        parsed_field_values = [lookup_field(f, d) for f in parse_fields]
    else:
        parsed_field_values = []

    write_row(
        fo,
        str(chrom),
        str(start),
        str(start + ref_len),
        dot_to_blank(ls[5]),
        dot_to_blank(ls[6]),
        dot_to_blank(ls[7]),
        str(indel_length),
        parsed_field_values,
        list(const_field_values),
        label,
    )

    return False


def parse(smk: Any, sconf: cfg.StratoMod, fi: TextIO, fo: TextIO) -> None:
    defs = sconf.feature_definitions
    vcf = sconf.querykey_to_vcf(cfg.LabeledQueryKey(smk.params.query_key))
    vtk = cfg.VartypeKey(smk.wildcards.vartype_key)
    found_error = False

    try:
        label = str(smk.wildcards.label)
    except AttributeError:
        label = None

    fields = [(str(defs.vcf.fmt_feature(k)), v) for k, v in vcf.format_fields.items()]
    parse_fields = [(k, v) for k, v in fields if isinstance(v, cfg.FormatField)]
    const_fields = [
        (k, none_to_blank(v)) for k, v in fields if not isinstance(v, cfg.FormatField)
    ]

    # write header
    write_row(
        fo,
        cfg.BED_CHROM,
        cfg.BED_START,
        cfg.BED_END,
        defs.vcf.qual[0],
        defs.vcf.filter,
        defs.vcf.info,
        defs.vcf.indel_length[0],
        [f[0] for f in parse_fields],
        [f[0] for f in const_fields],
        None if label is None else defs.label_name,
    )

    for ln in fi:
        if ln.startswith("#"):
            continue

        err = line_to_bed_row(
            fo,
            ln.rstrip().split("\t"),
            vcf,
            vtk,
            [f[1] for f in parse_fields],
            [f[1] for f in const_fields],
            label,
        )

        found_error = err or found_error

    if found_error is True:
        exit(1)


def main(smk: Any, config: cfg.StratoMod) -> None:
    with_gzip_maybe(
        lambda i, o: parse(smk, config, i, o),
        str(smk.input[0]),
        str(smk.output[0]),
    )


main(snakemake, snakemake.config)  # type: ignore
import common.config as cfg
from Bio import bgzf  # type: ignore


def main(opath: str, regions: list[cfg.BedRegion]) -> None:
    with bgzf.open(opath, "w") as f:
        for r in sorted(regions):
            f.write(r.fmt() + "\n")


main(snakemake.output[0], snakemake.params.regions)  # type: ignore
import json
import pandas as pd
import numpy as np
from numpy.typing import NDArray
from typing import Any, Hashable, cast, TypedDict
from common.io import setup_logging
from common.ebm import read_model
from common.tsv import write_tsv
import common.config as cfg
from interpret.glassbox import ExplainableBoostingClassifier  # type: ignore
from enum import Enum

setup_logging(snakemake.log[0])  # type: ignore


IndexedVectors = dict[int, NDArray[np.float64]]
NamedVectors = dict[str, NDArray[np.float64]]

EBMUniData = TypedDict(
    "EBMUniData",
    {"type": str, "names": list[float | str], "scores": NDArray[np.float64]},
)


class VarType(Enum):
    INT = "interaction"
    CNT = "continuous"
    CAT = "categorical"


AllFeatures = dict[str, tuple[VarType, int]]


Variable = TypedDict("Variable", {"name": str, "type": str})


BivariateData = TypedDict(
    "BivariateData",
    {
        "left": Variable,
        "right": Variable,
        "df": dict[Hashable, float],
    },
)


GlobalScoreData = TypedDict(
    "GlobalScoreData",
    {
        "variable": list[str],
        "score": list[float],
    },
)


UnivariateDF = TypedDict(
    "UnivariateDF",
    {
        "value": list[str | float],
        "score": list[float],
        "stdev": list[float],
    },
)


UnivariateData = TypedDict(
    "UnivariateData",
    {
        "name": str,
        "vartype": str,
        "df": UnivariateDF,
    },
)


ModelData = TypedDict(
    "ModelData",
    {
        "global_scores": GlobalScoreData,
        "intercept": float,
        "univariate": list[UnivariateData],
        "bivariate": list[BivariateData],
    },
)


# TODO there is no reason this can't be done immediately after training
# just to avoid the pickle thing


def array_to_list(arr: NDArray[np.float64], repeat_last: bool) -> list[float]:
    # cast needed since this can return a nested list depending on number of dims
    al = cast(list[float], arr.tolist())
    return al + [al[-1]] if repeat_last else al


def get_univariate_df(
    continuous: bool,
    feature_data: EBMUniData,
    stdev: NDArray[np.float64],
) -> UnivariateDF:
    def proc_scores(scores: NDArray[np.float64]) -> list[float]:
        return array_to_list(scores, continuous)

    return UnivariateDF(
        value=feature_data["names"],
        score=proc_scores(feature_data["scores"]),
        # For some reason, the standard deviations array has an extra 0 at
        # the front and thus is one longer than the scores array.
        stdev=proc_scores(stdev[1:]),
    )


def build_scores_array(
    arr: NDArray[np.float64],
    left_type: VarType,
    right_type: VarType,
) -> NDArray[np.float64]:
    # any continuous dimension is going to be one less than the names length,
    # so copy the last row/column to the end in these cases
    if left_type == VarType.CNT:
        arr = np.vstack((arr, arr[-1, :]))
    if right_type == VarType.CNT:
        arr = np.column_stack((arr, arr[:, -1]))
    return arr


def get_bivariate_df(
    all_features: AllFeatures,
    ebm_global: ExplainableBoostingClassifier,
    name: str,
    data_index: int,
    stdevs: IndexedVectors,
) -> BivariateData:
    def lookup_feature_type(name: str) -> VarType:
        return all_features[name][0]

    feature_data = ebm_global.data(data_index)
    # left is first dimension, right is second
    left_name, right_name = tuple(name.split(" x "))

    left_type = lookup_feature_type(left_name)
    right_type = lookup_feature_type(right_name)

    left_index = pd.Index(feature_data["left_names"], name="left_value")
    right_index = pd.Index(feature_data["right_names"], name="right_value")

    def stack_array(arr: NDArray[np.float64], name: str) -> "pd.Series[float]":
        return cast(
            "pd.Series[float]",
            pd.DataFrame(
                build_scores_array(arr, left_type, right_type),
                index=left_index,
                columns=right_index,
            ).stack(),
        ).rename(name)

    # the standard deviations are in an array that has 1 larger shape than the
    # scores array in both directions where the first row/column is all zeros.
    # Not sure why it is all zeros, but in order to make it line up with the
    # scores array we need to shave off the first row/column.
    return BivariateData(
        left=Variable(name=left_name, type=left_type.value),
        right=Variable(name=right_name, type=right_type.value),
        df=pd.concat(
            [
                stack_array(feature_data["scores"], "score"),
                stack_array(stdevs[data_index][1:, 1:], "stdev"),
            ],
            axis=1,
        )
        .reset_index()
        .to_dict(orient="list"),
    )


def get_global_scores(ebm_global: ExplainableBoostingClassifier) -> GlobalScoreData:
    glob = ebm_global.data()
    return GlobalScoreData(variable=glob["names"], score=glob["scores"])


def get_univariate_list(
    ebm_global: ExplainableBoostingClassifier,
    all_features: AllFeatures,
    stdevs: IndexedVectors,
) -> list[UnivariateData]:
    return [
        UnivariateData(
            name=name,
            vartype=vartype.value,
            df=get_univariate_df(
                vartype == VarType.CNT,
                ebm_global.data(i),
                stdevs[i],
            ),
        )
        for name, (vartype, i) in all_features.items()
        if vartype in [VarType.CNT, VarType.CAT]
    ]


def get_bivariate_list(
    ebm_global: ExplainableBoostingClassifier,
    all_features: AllFeatures,
    stdevs: IndexedVectors,
) -> list[BivariateData]:
    return [
        get_bivariate_df(all_features, ebm_global, name, i, stdevs)
        for name, (vartype, i) in all_features.items()
        if vartype == VarType.INT
    ]


def get_model(ebm: ExplainableBoostingClassifier) -> ModelData:
    ebm_global = ebm.explain_global()
    stdevs = cast(IndexedVectors, ebm.term_standard_deviations_)
    all_features = {
        cast(str, n): (VarType(t), i)
        for i, (n, t) in enumerate(
            map(tuple, ebm_global.selector[["Name", "Type"]].to_numpy())
        )
    }
    return ModelData(
        global_scores=get_global_scores(ebm_global),
        intercept=ebm.intercept_[0],
        univariate=get_univariate_list(ebm_global, all_features, stdevs),
        bivariate=get_bivariate_list(ebm_global, all_features, stdevs),
    )


def write_model_json(path: str, ebm: ExplainableBoostingClassifier) -> None:
    with open(path, "w") as f:
        json.dump(get_model(ebm), f)


def main(smk: Any, sconf: cfg.StratoMod) -> None:
    sin = smk.input
    sout = smk.output

    ebm = read_model(sin["model"])
    write_model_json(sout["model"], ebm)

    label = sconf.feature_definitions.label

    def write_predictions(xpath: str, ypath: str, out_path: str) -> None:
        X = pd.read_table(xpath).drop(columns=cfg.IDX_COLS)
        y = pd.read_table(ypath)
        y_pred = pd.DataFrame(
            {
                "prob": ebm.predict_proba(X)[::, 1],
                "label": y[label],
            }
        )
        write_tsv(out_path, y_pred)

    write_predictions(sin["train_x"], sin["train_y"], sout["train_predictions"])
    write_predictions(sin["test_x"], sin["test_y"], sout["predictions"])


main(snakemake, snakemake.config)  # type: ignore
import pandas as pd
from typing import Any
from common.io import setup_logging
from common.tsv import write_tsv
from common.ebm import read_model
import common.config as cfg
from interpret.glassbox import ExplainableBoostingClassifier  # type: ignore

setup_logging(snakemake.log[0])  # type: ignore


def _write_tsv(path: str, df: pd.DataFrame) -> None:
    write_tsv(path, df, header=True)


def predict_from_x(
    ebm: ExplainableBoostingClassifier,
    df: pd.DataFrame,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    probs, explanations = ebm.predict_and_contrib(df)
    return pd.DataFrame(probs), pd.DataFrame(explanations, columns=ebm.feature_names)


def main(smk: Any, sconf: cfg.StratoMod) -> None:
    sin = smk.input
    sout = smk.output
    ebm = read_model(sin["model"])
    predict_x = pd.read_table(sin["test_x"]).drop(columns=cfg.IDX_COLS)
    ps, xs = predict_from_x(ebm, predict_x)
    _write_tsv(sout["predictions"], ps)
    _write_tsv(sout["explanations"], xs)


main(snakemake, snakemake.config)  # type: ignore
import pandas as pd
import yaml
from typing import Any
from more_itertools import flatten
from sklearn.model_selection import train_test_split  # type: ignore
from interpret.glassbox import ExplainableBoostingClassifier  # type: ignore
from common.tsv import write_tsv
from common.io import setup_logging
from common.ebm import write_model
import common.config as cfg

logger = setup_logging(snakemake.log[0])  # type: ignore


def _write_tsv(smk: Any, key: str, df: pd.DataFrame) -> None:
    write_tsv(smk.output[key], df, header=True)


def dump_config(smk: Any, config: cfg.Model) -> None:
    with open(smk.output["config"], "w") as f:
        yaml.dump(config, f)


def get_interactions(
    df_columns: list[cfg.FeatureKey],
    iconfig: int | cfg.InteractionSpec,
) -> int | list[list[int]]:
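    # The interaction config is either an integer (the number of pairwise
    # interactions the EBM should detect on its own) or an explicit list of
    # specs; the latter is expanded here into pairs of column indices, where a
    # bare feature name is paired with every other column.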
    def expand_interactions(i: cfg.InteractionSpec_) -> list[list[int]]:
        if isinstance(i, str):
            return [
                [df_columns.index(i), c] for c, f in enumerate(df_columns) if f != i
            ]
        else:
            return [[df_columns.index(i.f1), df_columns.index(i.f2)]]

    if isinstance(iconfig, int):
        return iconfig
    else:
        return [*flatten(expand_interactions(i) for i in iconfig)]


def train_ebm(
    smk: Any,
    sconf: cfg.StratoMod,
    rconf: cfg.Model,
    df: pd.DataFrame,
) -> None:
    label = sconf.feature_definitions.label

    def strip_coords(df: pd.DataFrame) -> pd.DataFrame:
        return df.drop(columns=cfg.IDX_COLS)

    features = rconf.features
    feature_names = [
        k if v.alt_name is None else v.alt_name for k, v in features.items()
    ]
    misc_params = rconf.ebm_settings.misc_parameters

    if misc_params.downsample is not None:
        df = df.sample(frac=misc_params.downsample)

    train_cols = [c for c in df.columns if c != label]
    X = df[train_cols]
    y = df[label]

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        **rconf.ebm_settings.split_parameters.dict(),
    )

    cores = smk.threads

    logger.info(
        "Training EBM with %d features and %d cores",
        len(features),
        cores,
    )

    ebm = ExplainableBoostingClassifier(
        # NOTE the EBM docs show them explicitly adding interactions here like
        # 'F1 x F2' but it appears to work when I specify them separately via
        # the 'interactions' parameter
        feature_names=feature_names,
        feature_types=[f.feature_type.value for f in features.values()],
        interactions=get_interactions(feature_names, rconf.interactions),
        n_jobs=cores,
        **rconf.ebm_settings.classifier_parameters.mapping,
    )
    ebm.fit(strip_coords(X_train), y_train)

    write_model(smk.output["model"], ebm)
    _write_tsv(smk, "train_x", X_train)
    _write_tsv(smk, "train_y", y_train)
    _write_tsv(smk, "test_x", X_test)
    _write_tsv(smk, "test_y", y_test)


def main(smk: Any, sconf: cfg.StratoMod) -> None:
    rconf = sconf.models[smk.wildcards.model_key]
    df = pd.read_table(smk.input[0])
    train_ebm(smk, sconf, rconf, df)
    dump_config(smk, rconf)


main(snakemake, snakemake.config)  # type: ignore
library(tidyverse)
library(infotheo)

`:=` <- rlang::`:=`
`!!` <- rlang::`!!`

root <- snakemake@params$lib_path
source(file.path(root, "colocalization.r"))
source(file.path(root, "plots.r"))

format_perc <- function(x) {
    sprintf("%.4f", x)
}

format_exp <- function(x) {
    sprintf("%.1e", x)
}

make_stats_table <- function(df) {
    N <- nrow(df)
    gather(df, factor_key = TRUE) %>%
        group_by(key) %>%
        summarize(n_present = sum(!is.na(value)),
                  perc_present = 100 * n_present / N,
                  # prevent NULL error for zero length vectors in min/max
                  min = ifelse(n_present == 0, NA, min(value, na.rm = TRUE)),
                  max = ifelse(n_present == 0, NA, max(value, na.rm = TRUE)),
                  med = median(value, na.rm = TRUE),
                  mean = mean(value, na.rm = TRUE),
                  stdev = sd(value, na.rm = TRUE),
                  range = max - min) %>%
        rename(feature = key) %>%
        mutate(perc_present = format_perc(perc_present)) %>%
        mutate(across(c(min, max, med, mean, stdev, range), format_exp)) %>%
        arrange(desc(as.numeric(perc_present))) %>%
        knitr::kable(align = "r")
}

## TODO wet.....
make_feature_distribution <- function(x, labels) {
    infer_transform(x) %>%
        mutate(label = labels) %>%
        gather(-label, key = key, value = value) %>%
        ggplot() +
        aes(value, color = label) +
        geom_density() +
        xlab(NULL) +
        ylab(NULL) +
        facet_wrap(~key, scales = "free")
}

make_unlabeled_feature_distribution <- function(x) {
    infer_transform(x) %>%
        gather(key = key, value = value) %>%
        ggplot() +
        aes(value) +
        geom_density() +
        xlab(NULL) +
        ylab(NULL) +
        facet_wrap(~key, scales = "free")
}

label_summary_table <- function(y) {
    tibble(label = y) %>%
        group_by(label) %>%
        summarize(n = n(),
                  proportion = format_perc(n / N)) %>%
        knitr::kable()
}

columns <- snakemake@params[["columns"]]
query_key <- snakemake@params[["query_key"]]
label_col <- snakemake@params[["label_col"]]
path <- snakemake@input[[1]]
has_label <- !is.null(label_col)

all_columns <- if (is.null(label_col)) { columns } else { c(columns, label_col) }

x_col_types <- rep("-", length(all_columns)) %>%
  as.list() %>%
  setNames(all_columns) %>%
  c(list(".default"="d")) %>%
  do.call(cols, .)

df_x <- readr::read_tsv(path, col_types = x_col_types)
features <- names(df_x)
N <- nrow(df_x)
df_y <- readr::read_tsv(path, col_types = cols(
                                  !!label_col := "c",
                                  .default = "-")) %>%
    pull(!!label_col)

y_labels <- unique(df_y)
label_summary_table(df_y)
df_x %>%
    make_stats_table()
print_label_tbl <- function(label) {
    cat(sprintf("## %s Label\n\n", label))

    df_x %>%
        filter(df_y == label) %>%
        make_stats_table() %>%
        print()

    cat("\n\n")
}

walk(as.list(y_labels), print_label_tbl)
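## Estimate the entropy of two features and their mutual information by
## binning each into roughly sqrt(n) equal-width intervals (minimum 2) and
## passing the result to infotheo::mutinformation().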
information <- function(df, var1, var2) {
    .df <- df %>%
        select(all_of(c(var1, var2))) %>%
        drop_na()
    n <- nrow(.df)
    if (n > 0) {
        nbreaks <- max(sqrt(n), 2)
        mi <- .df %>%
            ## the discretize function doesn't seem to work the way I want, so
            ## just use 'cut' since I know what that does
            ## discretize() %>%
            mutate(across(everything(),
                          ~ cut(.x, breaks = nbreaks, labels = FALSE) %>%
                              as.vector())) %>%
            mutinformation()
        H1 <- mi[1, 1]
        H2 <- mi[2, 2]
        I <- mi[1, 2]
    } else {
        H1 <- NA
        H2 <- NA
        I <- NA
    }
    list(H1 = H1,
         H2 = H2,
         ## mutual information
         I = I,
         ## mutual information normalized to joint entropy
         Inorm = I / (H1 + H2 - I),
         ## mutual information normalized to the first feature
         I_H1 = I / H1,
         ## variation of information (if a metric is needed)
         VI = H1 + H2 - 2 * I)
}

info_df <- function(features, df_info) {
    features %>%
        as.list() %>%
        map(~ information(df_info, "label", .x)) %>%
        tibble(i = ., param = features) %>%
        unnest_wider(i) %>%
        drop_na()
}

info_plot <- function(df) {
    ggplot(df, aes(reorder(param, desc(I_H1)), I_H1)) +
        geom_col() +
        xlab(NULL) +
        ylab("Mutual Inf. (Normalized to Label)") +
        theme(axis.text.x = element_text(hjust = 1, vjust = 0.5, angle = 90))
}

print_info_plot <- function(df_info, features) {
    info_df(features, df_info) %>%
        info_plot() %>%
        print()
}

df_info_na <- df_x %>%
    mutate(label = as.integer(df_y == "tp"))

df_info_filled <- df_info_na %>%
    mutate(across(everything(), ~ if_else(is.na(.x), 0.0, as.double(.x))))
print_info_plot(df_info_filled, features)
print_info_plot(df_info_na, features)
print_coloc <- function(df_bool, comb_df) {
    mutate(comb_df, asymm_jaccard = ajaccard(df_bool, var.x, var.y)) %>%
        make_xy_tile_plot("var.x",
                          "var.y",
                          "asymm_jaccard",
                          "starting set",
                          "overlapping set") %>%
        print()

    cat("\n\n")
}

combinations <- df_x %>%
    names() %>%
    cross_tibble()

df_x_bool <- to_binary(df_x)
cat("## All labels\n\n")

print_coloc(df_x_bool, combinations)
print_label_coloc <- function(label) {
    cat(sprintf("## %s only\n\n", label))

    df_x_bool <- filter(df_x, df_y == label) %>%
        to_binary()
    print_coloc(df_x_bool, combinations)

    cat("\n\n")
}

walk(as.list(y_labels), print_label_coloc)
## ASSUME these will be the same for TP/FP/both
perfect_overlaps <- df_x_bool %>%
    perfect_overlapping_sets(combinations, "var.x", "var.y")

print_subset_cor_plot <- function(df, subset) {
    df %>%
        select(all_of(subset)) %>%
        drop_na() %>%
        make_cor_plot() %>%
        print()
}

print_cor_plots <- function(df, subsets) {
    cat(sprintf("number of rows: %s\n\n", nrow(df)))

    walk(as.list(subsets), ~ print_subset_cor_plot(df, .x))

    cat("\n\n")
}
cat("## All labels\n\n")

print_cor_plots(df_x, perfect_overlaps)
print_label_cor <- function(label) {
    cat(sprintf("## %s only\n\n", label))
    filter(df_x, df_y == label) %>% print_cor_plots(perfect_overlaps)
    cat("\n\n")
}

walk(as.list(y_labels), print_label_cor)
print_labeled_plot <- function(x, name) {
    cat(sprintf("## %s\n\n", name))

    .df <- tibble(x = x, y = df_y) %>%
        filter(!is.na(x))

    label_summary_table(.df$y) %>% print()

    cat("\n\n")

    .x <- .df$x
    .y <- .df$y

    if (length(.x) == 0) {
        cat("Feature has no values")
    } else if (max(.x) - min(.x) == 0) {
        cat(sprintf("Feature has one value: %.1f", max(.x)))
    } else {
        print(make_feature_distribution(.x, .y))
    }

    cat("\n\n")
}

print_unlabeled_plot <- function(x, name) {
    cat(sprintf("## %s\n\n", name))

    .df <- tibble(x = x) %>%
        filter(!is.na(x))

    .x <- .df$x

    if (length(.x) == 0) {
        cat("Feature has no values")
    } else if (max(.x) - min(.x) == 0) {
        cat(sprintf("Feature has one value: %.1f", max(.x)))
    } else {
        print(make_unlabeled_feature_distribution(.x))
    }

    cat("\n\n")
}

if (has_label) {
    iwalk(df_x, ~ print_labeled_plot(.x, .y))
} else {
    iwalk(df_x, ~ print_unlabeled_plot(.x, .y))
}
library(tidyverse)
library(ggpubr)

root <- snakemake@params$lib_path
source(file.path(root, "plots.r"))

read_df <- function(path) {
    readr::read_tsv(path, col_types = cols(.default = "d"))
}

has_label <- "truth_y" %in% names(snakemake@input)

pred_y <- read_df(snakemake@input[["predictions"]])
explain_x <- read_df(snakemake@input[["explanations"]])

query_key <- snakemake@params[["query_key"]]
truth_y <- readr::read_tsv(snakemake@input[["truth_y"]],
                           col_types = cols(chrom = "-",
                                            chromStart = "-",
                                            chromEnd = "-",
                                            variant_index = "-",
                                            .default = "d"))
y <- tibble(label = truth_y$label, prob = pred_y$`1`)
cat(sprintf("* N: %i\n", nrow(pred_y)))
cat(sprintf("* Perc. Pos: %f\n\n", sum(y$label)/nrow(y)))
if (has_label) {
    ggplot(y, aes(prob, color = factor(label))) +
        geom_density() +
        xlab("probability") +
        scale_color_discrete(name = "label")
} else {
    ggplot(pred_y, aes(x = `1`)) +
        geom_density() +
        xlab("probability")
}
AllP <- sum(y$label)
AllN <- nrow(y) - AllP

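## Build ROC/precision-recall data manually: sort by predicted probability and
## treat each row as a candidate threshold, accumulating FN/TN counts and
## deriving TP/FP/TPR/FPR/precision from the totals.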
roc <- y %>%
    arrange(prob) %>%
    mutate(thresh_FN = cumsum(label),
           thresh_TN = row_number() - thresh_FN,
           thresh_TP = AllP - thresh_FN,
           thresh_FP = AllN - thresh_TN,
           TPR = thresh_TP / AllP,
           TNR = thresh_TN / AllN,
           FPR = 1 - TNR,
           precision = thresh_TP / (thresh_TP + thresh_FP))
cat("## Calibration\n\n")

nbins <- 10

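## Calibration curve: bin the predictions into equal-width probability bins and
## compare each bin's mean predicted probability to its observed fraction of
## positive labels (the dotted diagonal marks perfect calibration).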
y %>%
    mutate(bin = cut(prob, nbins, labels = FALSE) / nbins) %>%
    group_by(bin) %>%
    summarize(mean_pred = mean(prob), frac_pos = mean(label)) %>%
    ggplot(aes(mean_pred, frac_pos)) +
    geom_point() +
    geom_line() +
    geom_abline(linetype = "dotted", color = "red") +
    xlim(0, 1) +
    ylim(0, 1)

cat("## ROC Curves\n\n")

roc %>%
    arrange(FPR, TPR) %>%
    ggplot(aes(FPR, TPR)) +
    geom_line()

roc %>%
    filter(!is.na(precision)) %>%
    arrange(TPR, desc(precision)) %>%
    ggplot(aes(TPR, precision)) +
    geom_line()
library(tidyverse)
library(ggpubr)

root <- snakemake@params$lib_path
source(file.path(root, "plots.r"))

to_tibble <- function(lst) {
    do.call(tibble, lst)
}

read_model <- function(path) {
    jsonlite::read_json(path, simplifyVector = TRUE, simplifyDataFrame = FALSE)
}

read_predictions <- function(path) {
    readr::read_tsv(path, col_types = cols(.default = "d"))
}

lookup_input_path <- function(mapping, k) {
    pluck(mapping, as.character(as.integer(k)))
}

to_univariate <- function(model) {
    model$univariate %>%
        map(~ list(meta = .x[c("name", "vartype")], df = to_tibble(.x[["df"]])))
}

to_bivariate <- function(model) {
    model$bivariate %>%
        map(~ list(left = .x$left, right = .x$right, df = to_tibble(.x$df)))
}

run_features <- snakemake@params[["features"]]
error_labels <- snakemake@params[["error_labels"]]

mod <- read_model(snakemake@input[["model"]])
test_pred <- read_predictions(snakemake@input[["predictions"]])
train_pred <- read_predictions(snakemake@input[["train_predictions"]])

train_x <- readr::read_tsv(snakemake@input[["train_x"]],
                           col_types = cols(chrom = "-",
                                            chromStart = "-",
                                            chromEnd = "-",
                                            variant_index = "-",
                                            .default = "d"))
train_y <- readr::read_tsv(snakemake@input[["train_y"]],
                           col_types = cols(.default = "d"))

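## Binarize the training predictions at the positive-label prevalence (mean of
## the 0/1 labels) rather than at a fixed 0.5 cutoff.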
threshold <- train_pred %>% pull(label) %>% mean()

alltrain <- bind_cols(train_x, train_pred) %>%
    mutate(pred = prob > threshold)

VCF_input_name <- "VCF_input"

global_df <- to_tibble(mod$global_scores)

univariate <- to_univariate(mod)
bivariate <- to_bivariate(mod)
cat("\n")
ggplot(test_pred, aes(prob, color = factor(label))) +
    geom_density() +
    xlab("probability") +
    scale_color_discrete(name = "label")
nbins <- 10

test_pred %>%
    mutate(bin = cut(prob, nbins, labels = FALSE) / nbins) %>%
    group_by(bin) %>%
    summarize(mean_pred = mean(prob), frac_pos = mean(label)) %>%
    ggplot(aes(mean_pred, frac_pos)) +
    geom_point() +
    geom_line() +
    geom_abline(linetype = "dotted", color = "red") +
    xlim(0, 1) +
    ylim(0, 1)
AllP <- sum(test_pred$label)
AllN <- nrow(test_pred) - AllP

roc <- test_pred %>%
    arrange(prob) %>%
    mutate(thresh_FN = cumsum(label),
           thresh_TN = row_number() - thresh_FN,
           thresh_TP = AllP - thresh_FN,
           thresh_FP = AllN - thresh_TN,
           TPR = thresh_TP / AllP,
           TNR = thresh_TN / AllN,
           FPR = 1 - TNR,
           precision = thresh_TP / (thresh_TP + thresh_FP))
roc %>%
    arrange(FPR, TPR) %>%
    ggplot(aes(FPR, TPR)) +
    geom_line()

roc %>%
    filter(!is.na(precision)) %>%
    arrange(TPR, desc(precision)) %>%
    ggplot(aes(TPR, precision)) +
    geom_line()
ggplot(global_df, aes(score, reorder(variable, score))) +
    geom_col() +
    xlab("Score") +
    ylab(NULL)
tibble(x = "intercept", y = mod$intercept) %>%
    ggplot(aes(x, y)) +
    geom_col() +
    xlab(NULL) +
    ylab("score")
get_truncation <- function(s) {
    run_features[[s]][["visualization"]][["truncate"]]
}

get_split_missing <- function(s) {
    run_features[[s]][["visualization"]][["split_missing"]]
}

get_fill_na <- function(s) {
    run_features[[s]][["fill_na"]]
}

get_plot_type <- function(s) {
    run_features[[s]][["visualization"]][["plot_type"]]
}

truncate_maybe <- function(df, name) {
    t <- get_truncation(name)
    lower <- t[["lower"]]
    upper <- t[["upper"]]
    caption <- if (!is.null(lower) && !is.null(upper)) {
                   sprintf("Truncated from %d to %d", lower, upper)
               } else if (!is.null(lower)) {
                   sprintf("Truncated from %d to -Inf", lower)
               } else if (!is.null(upper)) {
                   sprintf("Truncated from -Inf to %d", upper)
               }
    .df <- if (is.null(lower) && is.null(upper)) {
               df
           } else {
               .lower <- if (is.null(lower)) min(df$value) else lower
               .upper <- if (is.null(upper)) max(df$value) else upper
               filter(df, .lower <= value & value <= .upper)
           }
    list(df = .df, lower = lower, upper = upper, caption = caption)
}

null2alt <- function(default, x) {
    if (is.null(x)) default else x
}

null2na <- function(x) {
    null2alt(NA, x)
}

make_integer_plot <- function(df, name, lower = NULL, upper = NULL,
                              ylab = "Score") {
    fill_cols <- c("score", "stdev")
    .lower <- null2alt(min(df$value), lower)
    .upper <- null2alt(max(df$value), upper)
    .join <- tibble(value = .lower:.upper)
    mutate(df, value = ceiling(value)) %>%
        right_join(.join, by = "value") %>%
        arrange(value) %>%
        fill(all_of(fill_cols), .direction = "downup") %>%
        ggplot(aes(value, score)) +
        geom_col() +
        xlab(name) +
        ylab(ylab) +
        geom_errorbar(aes(ymin = score - stdev, ymax = score + stdev))
}

## TODO use inverse logit here?
make_fraction_plot <- function(df) {
    df %>%
        group_by(value) %>%
        summarize(frac = mean(label),
                  stderr = sqrt(frac * (1 - frac) / n())) %>%
        ggplot(aes(value, frac)) +
        geom_point() +
        geom_errorbar(aes(ymin = frac - stderr, ymax = frac + stderr),
                      width = 0.1)
}

make_integer_fraction_plot <- function(df, name, lower = NULL, upper = NULL) {
    .name <- sym(name)
    df %>%
        mutate(value = ceiling({{ .name }})) %>%
        make_fraction_plot() +
        xlab(name) +
        ylab("Frac(TP)") +
        coord_trans(xlim = c(null2na(lower), null2na(upper)))
}

make_continuous_plot <- function(df, name, ylab = "Score") {
    ggplot(df, aes(value, score)) +
        geom_step(aes(y = score + stdev), color = "red") +
        geom_step(aes(y = score - stdev), color = "red") +
        geom_step() +
        xlab(name) +
        ylab(ylab)
}

make_continuous_fraction_plot <- function(df, name) {
    .name <- sym(name)
    df %>%
        ## TODO average the bin ends to make the axis cleaner
        mutate(value = cut({{ .name }}, 20)) %>%
        make_fraction_plot() +
        xlab(name) +
        ylab("Frac(TP)") +
        theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5))
}

make_categorical_plot <- function(df, name, ylab = "Score") {
    df %>% 
        ggplot(aes(factor(value), score)) +
        geom_col() +
        geom_errorbar(aes(ymin = score - stdev, ymax = score + stdev), width = 0.1) +
        xlab(name) +
        ylab(ylab)
}

make_categorical_fraction_plot <- function(df, name) {
    .name <- sym(name)
    df %>%
        mutate(value = factor({{ .name }})) %>%
        make_fraction_plot() +
        xlab(name) +
        ylab("Frac(TP)")
}

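## Force a list of ggplots onto a common y-axis range so the panels of the
## split plots below remain directly comparable.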
standardize_y_axes <- function(ps) {
    lims <- map(ps, ~ layer_scales(.x)[["y"]]$get_limits()) %>%
        do.call(cbind, .)
    new <- c(min(lims[1, ]), max(lims[2, ]))
    map(ps, ~ .x + ylim(new))
}

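## For features with a sentinel "missing" fill value, plot that value as a
## separate "Missing" bar beside the plot of the non-missing values (clamped
## below at 'bound').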
make_split_plot <- function(df, name, bound, fun) {
    missing_val <- get_fill_na(name)
    missing <- filter(df, value == missing_val) %>%
        mutate(value = "Missing")
    nonmissing <- filter(df, value != missing_val) %>%
        mutate(value = if_else(value < bound, bound, value))
    bar <- ggplot(missing, aes(factor(value), score)) +
        geom_col() +
        geom_errorbar(aes(ymax = score + stdev,
                          ymin = score - stdev),
                      width = 0.2) +
        xlab(NULL)
    step <- fun(nonmissing, NULL) +
        ylab(NULL) +
        theme(axis.text.y = element_blank(),
              axis.ticks.y = element_blank())
    list(bar, step) %>%
        standardize_y_axes() %>%
        ggarrange(plotlist = ., ncol = 2, widths = c(1, 5)) %>%
        annotate_figure(bottom = text_grob(name))
}

make_split_fraction_plot <- function(df, name, bound, fun) {
    .name <- sym(name)
    missing_val <- get_fill_na(name)
    missing <- filter(df, {{ .name }} == missing_val) %>%
        mutate({{ .name }} := "Missing")
    nonmissing <- filter(df, {{ .name }} != missing_val)
    bar <- make_categorical_fraction_plot(missing, name) +
        xlab(NULL)
    step <- fun(nonmissing, name) +
        xlab(NULL) +
        ylab(NULL) +
        theme(axis.text.y = element_blank(),
              axis.ticks.y = element_blank())
    list(bar, step) %>%
        standardize_y_axes() %>%
        ggarrange(plotlist = ., ncol = 2, widths = c(1, 5), align = "h") %>%
        annotate_figure(bottom = text_grob(name))
}

wrap_split_maybe <- function(name, split_f, f) {
    s <- get_split_missing(name)
    if (is.null(s)) f else partial(split_f, fun = f, bound = s)
}

print_uv_plot <- function(vartype, df, name) {
    r <- if (vartype == "continuous") {
             tr <- truncate_maybe(df, name)
             t <- get_plot_type(name)
             fs <- if (t == "step") {
                       list(
                           make_continuous_plot,
                           make_continuous_fraction_plot
                       )
                   } else if (t == "bar") {
                       list(
                           partial(
                               make_integer_plot,
                               lower = tr[["lower"]],
                               upper = tr[["upper"]]
                           ),
                           partial(
                               make_integer_fraction_plot,
                               lower = tr[["lower"]],
                               upper = tr[["upper"]]
                           )
                       )
                   } else {
                       stop(sprintf("wrong type, dummy; got %s", t))
                   }
             ## TODO only continuous plots can be split (for now)
             list(
                 feat_f = wrap_split_maybe(name, make_split_plot, fs[[1]]),
                 frac_f = wrap_split_maybe(name, make_split_fraction_plot, fs[[2]]),
                 df = tr[["df"]],
                 caption = tr[["caption"]]
             )
         } else if (vartype == "categorical") {
             list(
                 feat_f = make_categorical_plot,
                 frac_f = make_categorical_fraction_plot,
                 df = df,
                 caption = NULL
             )
         } else {
             stop(sprintf("wrong plot type, dummy; got %s", vartype))
         }
    p0 <- r$feat_f(r$df, name)
    p1 <- r$frac_f(alltrain, name)
    cat(sprintf("## %s\n", name))
    print(p0)
    cat("\n\n")
    print(p1)
    cat("\n\n")
    if (!is.null(r$caption)) {
        cat(sprintf("%s\n\n", r$caption))
    }
}

walk(univariate, ~ print_uv_plot(.x$meta$vartype, .x$df, .x$meta$name))
cont_cont_plot <- function(df, yvar, left_name, right_name) {
    ## Poor man's 2D step heatmap: draw one rectangle per grid cell of the
    ## interaction term, filled by the chosen column (score or stdev).
    .yvar <- sym(yvar)
    df %>%
        group_by(right_value) %>%
        mutate(left_upper = lead(left_value)) %>%
        ungroup() %>%
        group_by(left_value) %>%
        mutate(right_upper = lead(right_value)) %>%
        ungroup() %>%
        filter(!is.na(left_upper)) %>%
        filter(!is.na(right_upper)) %>%
        ggplot() +
        geom_rect(aes(xmin = left_value,
                      xmax = left_upper,
                      ymin = right_value,
                      ymax = right_upper,
                      fill = {{ .yvar }})) +
        xlab(left_name) +
        ylab(right_name)
}

print_cont_cont_plot <- function(df, left_name, right_name) {
    x_tr <- df %>%
        rename(value = left_value) %>%
        truncate_maybe(left_name)
    y_tr <- df %>%
        rename(value = right_value) %>%
        truncate_maybe(right_name)
    lims <- coord_trans(xlim = c(null2na(x_tr$lower), null2na(x_tr$upper)),
                        ylim = c(null2na(y_tr$lower), null2na(y_tr$upper)))
    if (!is.null(x_tr$caption)) {
        cat(sprintf("%s: %s\n\n", left_name, x_tr$caption))
    }
    if (!is.null(y_tr$caption)) {
        cat(sprintf("%s: %s\n\n", right_name, y_tr$caption))
    }
    cat("### Scores\n\n")
    p0 <- cont_cont_plot(df, "score", left_name, right_name) +
        scale_fill_gradient2(midpoint = 0) +
        lims
    print(p0)
    cat("\n\n")
    cat("### Stdevs\n\n")
    p1 <- cont_cont_plot(df, "stdev", left_name, right_name) +
        scale_fill_gradient() +
        lims
    print(p1)
}

print_cont_cat_plot_inner <- function(df, cat_name, cont_name) {
    cat_val <- as.character(df$c[[1]])
    tr <- truncate_maybe(df, cont_name)
    t <- get_plot_type(cont_name)
    f <- if (t == "step") {
             make_continuous_plot
         } else if (t == "bar") {
             partial(make_integer_plot,
                     lower = tr[["lower"]],
                     upper = tr[["upper"]]
                     )
         } else {
             stop(sprintf("wrong plot type: got %s", t))
         }
    p <- wrap_split_maybe(cont_name, make_split_plot, f)(tr[["df"]], cont_name)
    cat(sprintf("### %s = %s\n\n", cat_name, cat_val))
    cat(sprintf("%s\n\n", tr[["caption"]]))
    print(p)
    cat("\n\n")
}

print_cont_cat_plot <- function(df, cat, cont, cat_name, cont_name) {
    ## this 'all_of' thing is needed to silence a weird warning about
    ## using vectors to select things (I disagree with it, but whatever)
    .df <- rename(df, value = all_of(cont), c = all_of(cat)) %>%
        mutate(lower = score - stdev,
               upper = score + stdev) %>%
        group_split(c) %>%
        walk(~ print_cont_cat_plot_inner(.x, cat_name, cont_name))
}

print_cat_cat_plot <- function(df, left_name, right_name) {
    p <- ggplot(df, aes(factor(left_value),
                         score,
                         fill = factor(right_value)
                         )) +
        geom_col(position = "dodge") +
        geom_errorbar(aes(ymin = score - stdev,
                          ymax = score + stdev),
                      width = 0.1,
                      position = position_dodge(0.9)) +
        xlab(left_name) +
        scale_fill_discrete(name = right_name)
    print(p)
}

print_bv_plot_inner <- function(L, R, df) {
    if (L$type == "continuous" && R$type == "continuous") {
        print_cont_cont_plot(df, L$name, R$name)
    } else if (L$type == "categorical" && R$type == "continuous") {
        print_cont_cat_plot(df, "left_value", "right_value", L$name, R$name)
    } else if (L$type == "continuous" && R$type == "categorical") {
        print_cont_cat_plot(df, "right_value", "left_value", R$name, L$name)
    } else if (L$type == "categorical" && R$type == "categorical") {
        print_cat_cat_plot(df, L$name, R$name)
    } else {
        sprintf("Types are wrong, dummy: %s and/or %s", L$type, R$type)
    }
}

print_bv_plot <- function(L, R, df) {
    cat(sprintf("## %s x %s\n\n", L$name, R$name))

    print_bv_plot_inner(L, R, df)

    cat("\n\n")
}

if (length(bivariate) == 0) {
    cat("None\n\n")
} else {
    walk(bivariate, ~ print_bv_plot(.x$left, .x$right, .x$df))
}
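## Transform a metric in [0, 1] to -log10(1 - metric) so values close to one
## (e.g. 0.99 vs 0.999) remain visually distinguishable.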
logErr <- function(p) {
    -log10(1 - p)
}

make_perf_df <- function(df) {
    df %>%
        group_by(value) %>%
        summarize(precision = sum(label & pred) / sum(pred),
                  recall = sum(label & pred) / sum(label),
                  f1 = 2 * (precision * recall) / (precision + recall)) %>%
        pivot_longer(cols = c(precision, recall, f1),
                     names_to = "metric",
                     values_to = "mvalue") %>%
        mutate(mvalue = logErr(mvalue))
        ## ggplot(aes(value, mvalue, color = metric)) +
        ## labs(x = "Feature Value",
        ##      y = "-log10(metric)")
}

make_perf_plot <- function(df) {
    df %>%
        group_by(value) %>%
        summarize(precision = sum(label & pred) / sum(pred),
                  recall = sum(label & pred) / sum(label),
                  f1 = 2 * (precision * recall) / (precision + recall)) %>%
        pivot_longer(cols = c(precision, recall, f1),
                     names_to = "metric",
                     values_to = "mvalue") %>%
        mutate(mvalue = logErr(mvalue)) %>%
        ggplot(aes(value, mvalue, color = metric)) +
        xlab(NULL) +
        ylab("-log10(metric)")
}

make_integer_perf_plot <- function(df, name, lower = NULL, upper = NULL) {
    .n <- sym(name)
    df %>%
        mutate(value = ceiling({{ .n }})) %>%
        make_perf_plot() +
        geom_point() +
        geom_line() +
        coord_trans(xlim = c(null2na(lower), null2na(upper)))
}

make_continuous_perf_plot <- function(df, name) {
    .n <- sym(name)
    df %>%
        mutate(value = cut({{ .n }}, 20)) %>%
        make_perf_plot() +
        geom_point() +
        theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5))
}

make_categorical_perf_plot <- function(df, name) {
    .n <- sym(name)
    df %>%
        mutate(value = factor({{ .n }})) %>%
        make_perf_plot() +
        geom_col(position = "dodge", aes(fill = metric))
}

make_perf_split_plot <- function(df, name, bound, fun) {
    .n <- sym(name)
    missing_val <- get_fill_na(name)
    missing <- filter(df, {{ .n }} == missing_val) %>%
        mutate({{ .n }} := "Missing")
    nonmissing <- filter(df, {{ .n }} != missing_val) %>%
        mutate({{ .n }} := if_else({{ .n }} < bound, bound, {{ .n }}))
    bar <- make_categorical_perf_plot(missing, name) +
        xlab(NULL)
    step <- fun(nonmissing, name) +
        ylab(NULL) +
        xlab(NULL) +
        theme(axis.text.y = element_blank(),
              axis.ticks.y = element_blank())
    list(bar, step) %>%
        standardize_y_axes() %>%
        ggarrange(plotlist = ., ncol = 2, widths = c(1, 5),
                  common.legend = TRUE, legend = "right", align = "h") %>%
        annotate_figure(bottom = text_grob(name))
}

print_perf_profile_plot <- function(vartype, name) {
    r <- if (vartype == "continuous") {
             tr <- get_truncation(name)
             t <- get_plot_type(name)
             f <- if (t == "step") {
                      make_continuous_perf_plot
                  } else if (t == "bar") {
                      partial(
                          make_integer_perf_plot,
                          lower = tr[["lower"]],
                          upper = tr[["upper"]]
                      )
                  } else {
                      stop(sprintf("wrong type, dummy; got %s", t))
                  }
             list(
                 feat_f = wrap_split_maybe(name, make_perf_split_plot, f),
                 caption = tr[["caption"]]
             )
         } else if (vartype == "categorical") {
             list(
                 feat_f = make_categorical_perf_plot,
                 caption = NULL
             )
         } else {
             stop(sprintf("wrong plot type, dummy; got %s", vartype))
         }
    cat(sprintf("## %s\n", name))
    print(r$feat_f(alltrain, name))
    cat("\n\n")
    if (!is.null(r$caption)) {
        cat(sprintf("%s\n\n", r$caption))
    }
}

walk(univariate, ~ print_perf_profile_plot(.x$meta$vartype, .x$meta$name))
Maintainers: public
URL: https://github.com/ndwarshuis/stratomod
Name: stratomod
Version: v8.0.4
License: MIT License