Snakemake template for building reusable and scalable machine learning pipelines with mikropml


Snakemake is a workflow manager that enables massively parallel and reproducible analyses. It is a good fit whenever you can break an analysis down into discrete steps, each with its own input and output files.
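For example, a minimal rule declares the files a step consumes and produces, and the script that turns one into the other. The rule name and paths below are hypothetical, not part of this workflow:

rule summarize:
    input:
        "data/raw.csv"
    output:
        "results/summary.csv"
    script:
        "scripts/summarize.R"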

mikropml is an R package for supervised machine learning pipelines. We provide this example workflow as a template for running mikropml with Snakemake; we hope you then customize the code to meet the needs of your particular ML task.
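For context, a single mikropml run outside of Snakemake looks roughly like this minimal sketch, which uses the otu_mini_bin example dataset bundled with the package:

library(mikropml)

# train and evaluate one model on the package's built-in example dataset
results <- run_ml(
  otu_mini_bin,
  method = "glmnet",      # regularized logistic regression
  outcome_colname = "dx", # column to predict
  seed = 2019
)
results$performance       # cross-validation and test performance (e.g. AUROC)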

For more details on these tools, see the Snakemake tutorial and read the mikropml docs.

The Workflow

The Snakefile contains rules that define the output files we want and how to make them. Snakemake automatically builds a directed acyclic graph (DAG) of jobs to figure out each rule's dependencies and the order to run them in. This workflow preprocesses the example dataset, calls mikropml::run_ml() for each seed and ML method set in the config file, combines the results files, plots performance (cross-validation and test AUROCs, hyperparameter AUROCs from cross-validation, and benchmark performance), and renders a simple R Markdown report as a GitHub-flavored markdown file (see example here).
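The seeds and ML methods are read from the workflow's config file. A hypothetical sketch of the relevant entries is shown below; the key names are illustrative, so consult config/config.yaml in the repository for the actual schema:

outcome_colname: dx
ml_methods:
  - glmnet
  - rf
kfold: 5
nseeds: 100
find_feature_importance: true
ncores: 4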

[Figure: rule graph of the full workflow]

The DAG shows how calls to run_ml can run in parallel if Snakemake is allowed to run more than one job at a time. If we use 100 seeds and 4 ML methods, Snakemake would call run_ml 400 times. Here's a small example DAG if we were to use only 2 seeds and 1 ML method:

[Figure: example DAG with 2 seeds and 1 ML method]
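To take advantage of that parallelism, allow Snakemake to run more than one job at a time, e.g. with the --cores flag (an illustrative invocation, not the only way to run the workflow):

snakemake --cores 4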

Usage

Full usage instructions for this workflow are available in the Snakemake workflow catalog. Snakemake recommends using snakedeploy to use this workflow as a module in your own project.
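As a hedged sketch, deploying this workflow with snakedeploy might look like the following; check the catalog entry for the exact, up-to-date command:

snakedeploy deploy-workflow https://github.com/SchlossLab/mikropml-snakemake-workflow my-analysis --tag v1.3.0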

Alternatively, you can download this repo and modify the code directly to suit your needs. See the instructions here.
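If you go that route, the basic steps are roughly as follows, assuming conda is available to provide the workflow's software environments:

git clone https://github.com/SchlossLab/mikropml-snakemake-workflow
cd mikropml-snakemake-workflow
snakemake --cores 4 --use-conda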

Help & Contributing

If you come across a bug, open an issue and include a minimal reproducible example.

If you have questions, create a new post in Discussions.

If you’d like to contribute, see our guidelines here.

Code of Conduct

Please note that the mikropml-snakemake-workflow is released with a Contributor Code of Conduct. By contributing to this project, you agree to abide by its terms.


Code Snippets

script:
    "../scripts/combine_results.R"

script:
    "../scripts/combine_hp_perf.R"

script:
    "../scripts/mutate_benchmark.R"

shell:
    """
    for f in {input.figs}; do
        cp $f {params.outdir}
    done
    """

script:
    "../scripts/report.Rmd"

script:
    "../scripts/preproc.R"

script:
    "../scripts/train_ml.R"

script:
    "../scripts/find_feature_importance.R"

script:
    "../scripts/calc_model_sensspec.R"

From line 14 of rules/plot.smk:
script:
    "../scripts/plot_performance.R"

From line 34 of rules/plot.smk:
script:
    "../scripts/plot_feature_importance.R"

From line 46 of rules/plot.smk:
script:
    "../scripts/make_blank_plot.R"

From line 63 of rules/plot.smk:
script:
    "../scripts/plot_hp_perf.R"

From line 81 of rules/plot.smk:
script:
    "../scripts/plot_benchmarks.R"

From line 94 of rules/plot.smk:
script:
    "../scripts/plot_roc_curves.R"

shell:
    """
    snakemake --{wildcards.cmd} --configfile {params.config_path} 2> {log} > {output.dot}
    """

From line 122 of rules/plot.smk:
shell:
    """
    cat {input.dot} | dot -T png 2> {log} > {output.png}
    """

From scripts/calc_model_sensspec.R:
schtools::log_snakemake()
library(tidyverse)

model <- read_rds(snakemake@input[["model"]])
test_dat <- read_csv(snakemake@input[["test"]])
outcome_colname <- snakemake@params[["outcome_colname"]]
mikropml::calc_model_sensspec(
  model,
  test_dat,
  outcome_colname
) %>%
  bind_cols(schtools::get_wildcards_tbl()) %>%
  write_csv(snakemake@output[["csv"]])

From scripts/combine_hp_perf.R:
schtools::log_snakemake()

models <- lapply(snakemake@input[["rds"]], function(x) readRDS(x))
hp_perf <- mikropml::combine_hp_performance(models)
hp_perf$method <- snakemake@wildcards[["method"]]
saveRDS(hp_perf, file = snakemake@output[["rds"]])

From scripts/combine_results.R:
schtools::log_snakemake()
library(dplyr)

snakemake@input[["csv"]] %>%
  purrr::map_dfr(readr::read_csv) %>%
  readr::write_csv(snakemake@output[["csv"]])

From scripts/find_feature_importance.R:
schtools::log_snakemake()
library(mikropml)
library(dplyr)
library(readr)
doFuture::registerDoFuture()
future::plan(future::multicore, workers = snakemake@threads)
message(paste("# workers: ", foreach::getDoParWorkers()))

model <- readRDS(snakemake@input[["model"]])
outcome_colname <- snakemake@params[["outcome_colname"]]
train_dat <- model$trainingData
names(train_dat)[names(train_dat) == ".outcome"] <- outcome_colname
test_dat <- read_csv(snakemake@input[["test"]])
method <- snakemake@params[["method"]]
seed <- as.numeric(snakemake@params[["seed"]])

outcome_type <- get_outcome_type(c(
  train_dat %>% pull(outcome_colname),
  test_dat %>% pull(outcome_colname)
))
class_probs <- outcome_type != "continuous"
perf_metric_function <- get_perf_metric_fn(outcome_type)
perf_metric_name <- get_perf_metric_name(outcome_type)

if (!is.na(seed)) {
  set.seed(seed)
}
feat_imp <- mikropml::get_feature_importance(
  trained_model = model,
  test_data = test_dat,
  outcome_colname = outcome_colname,
  perf_metric_function = perf_metric_function,
  perf_metric_name = perf_metric_name,
  class_probs = class_probs,
  method = method,
  seed = seed
)

wildcards <- schtools::get_wildcards_tbl()

readr::write_csv(
  feat_imp %>%
    inner_join(wildcards, by = c("method", "seed")),
  snakemake@output[["feat"]]
)

From scripts/make_blank_plot.R:
schtools::log_snakemake()
library(ggplot2)
message("making a blank plot")
ggsave(
  filename = snakemake@output[["plot"]],
  plot = ggplot() +
    theme_void(),
  height = 0.1, width = 0.1,
  device = "png"
)

From scripts/mutate_benchmark.R:
schtools::log_snakemake()
library(tidyverse)

wildcards <- schtools::get_wildcards_tbl()

read_tsv(snakemake@input[["tsv"]]) %>%
  bind_cols(wildcards) %>%
  write_csv(snakemake@output[["csv"]])

From scripts/plot_benchmarks.R:
schtools::log_snakemake()
library(tidyverse)

dat <- read_csv(snakemake@input[["csv"]],
  col_types = cols(
    s = col_double(),
    `h:m:s` = col_time(format = "%H:%M:%S"),
    max_rss = col_double(),
    max_vms = col_double(),
    max_uss = col_double(),
    max_pss = col_double(),
    io_in = col_double(),
    io_out = col_double(),
    mean_load = col_double(),
    cpu_time = col_double(),
    method = col_character(),
    seed = col_double()
  )
) %>%
  mutate(
    runtime_mins = s / 60,
    memory_gb = max_rss / 1024
  ) %>%
  select(method, runtime_mins, memory_gb) %>%
  pivot_longer(-method, names_to = "metric") %>%
  mutate(value = round(value, 2)) %>%
  group_by(method)

bench_plot <- dat %>%
  ggplot(aes(method, value)) +
  geom_boxplot() +
  facet_wrap(metric ~ ., scales = "free") +
  theme_classic() +
  labs(y = "", x = "") +
  coord_flip()
ggsave(snakemake@output[["plot"]], plot = bench_plot)

From scripts/plot_feature_importance.R:
schtools::log_snakemake()
library(dplyr)
library(ggplot2)
library(schtools)

feat_df <- readr::read_csv(snakemake@input[["csv"]])
top_n <- as.numeric(snakemake@params[["top_n"]])

top_feats <- feat_df %>%
  group_by(method, names) %>%
  summarize(median_diff = median(perf_metric_diff)) %>%
  slice_max(order_by = median_diff, n = top_n)

feat_plot <- feat_df %>%
  right_join(top_feats, by = c("method", "names")) %>%
  mutate(features = factor(names, levels = unique(top_feats$names))) %>%
  ggplot(aes(x = perf_metric_diff, y = features, color = method)) +
  geom_boxplot() +
  facet_wrap(~method) +
  theme_sovacool()

ggsave(
  filename = snakemake@output[["plot"]],
  plot = feat_plot,
  device = "png"
)

From scripts/plot_hp_perf.R:
schtools::log_snakemake()

hp_perf <- readRDS(snakemake@input[["rds"]])
hp_plot_list <- lapply(hp_perf$params, function(param) {
  mikropml::plot_hp_performance(hp_perf$dat, !!rlang::sym(param), !!rlang::sym(hp_perf$metric)) + ggplot2::theme_classic() + ggplot2::scale_color_brewer(palette = "Dark2") + ggplot2::labs(title = unique(hp_perf$method))
})
hp_plot <- cowplot::plot_grid(plotlist = hp_plot_list)
ggplot2::ggsave(snakemake@output[["plot"]], plot = hp_plot)

From scripts/plot_performance.R:
schtools::log_snakemake()
library(tidyverse)

perf_plot <- snakemake@input[["csv"]] %>%
  read_csv() %>%
  mikropml::plot_model_performance() +
  theme_classic() +
  scale_color_brewer(palette = "Dark2") +
  coord_flip()
ggsave(snakemake@output[["plot"]], plot = perf_plot)

From scripts/plot_roc_curves.R:
schtools::log_snakemake()
library(patchwork)
library(tidyverse)


dat <- read_csv(snakemake@input[["csv"]])

calc_mean_perf <- function(sensspec_dat,
                           group_var = specificity,
                           sum_var = sensitivity,
                           custom_group_vars = NULL) {
  specificity <- sensitivity <- sd <- NULL
  dat_round <- sensspec_dat %>%
    dplyr::mutate({{ group_var }} := round({{ group_var }}, 2))
  if (!is.null(custom_group_vars)) {
    dat_grouped <- dat_round %>%
      dplyr::group_by({{ group_var }}, !!rlang::sym(custom_group_vars))
  } else {
    dat_grouped <- dat_round %>%
      dplyr::group_by({{ group_var }})
  }
  return(
    dat_grouped %>%
      dplyr::summarise(
        mean = mean({{ sum_var }}),
        sd = stats::sd({{ sum_var }})
      ) %>%
      dplyr::mutate(
        upper = mean + sd,
        lower = mean - sd,
        upper = dplyr::case_when(
          upper > 1 ~ 1,
          TRUE ~ upper
        ),
        lower = dplyr::case_when(
          lower < 0 ~ 0,
          TRUE ~ lower
        )
      ) %>%
      dplyr::rename(
        "mean_{{ sum_var }}" := mean,
        "sd_{{ sum_var }}" := sd
      )
  )
}

calc_mean_roc <- function(sensspec_dat, custom_group_vars = NULL) {
  specificity <- sensitivity <- NULL
  return(calc_mean_perf(sensspec_dat,
    group_var = specificity,
    sum_var = sensitivity,
    custom_group_vars = custom_group_vars
  ))
}

calc_mean_prc <- function(sensspec_dat, custom_group_vars = NULL) {
  sensitivity <- recall <- precision <- NULL
  return(calc_mean_perf(
    sensspec_dat %>%
      dplyr::rename(recall = sensitivity),
    group_var = recall,
    sum_var = precision,
    custom_group_vars = custom_group_vars
  ))
}

shared_ggprotos <- function(colorvar) {
  return(list(
    ggplot2::geom_ribbon(aes(fill = {{ colorvar }}), alpha = 0.5),
    ggplot2::geom_line(aes(color = {{ colorvar }})),
    ggplot2::coord_equal(),
    ggplot2::scale_y_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)),
    ggplot2::theme_bw(),
    ggplot2::theme(legend.title = ggplot2::element_blank())
  ))
}

plot_mean_roc <- function(dat) {
  specificity <- mean_sensitivity <- lower <- upper <- NULL
  dat %>%
    ggplot2::ggplot(ggplot2::aes(
      x = specificity, y = mean_sensitivity,
      ymin = lower, ymax = upper
    )) +
    shared_ggprotos(colorvar = method) +
    ggplot2::geom_abline(
      intercept = 1, slope = 1,
      linetype = "dashed", color = "grey50"
    ) +
    ggplot2::scale_x_reverse(expand = c(0, 0), limits = c(1.01, -0.01)) +
    ggplot2::labs(x = "Specificity", y = "Mean Sensitivity")
}

plot_mean_prc <- function(dat, baseline_precision = NULL) {
  recall <- mean_precision <- lower <- upper <- NULL
  prc_plot <- dat %>%
    ggplot2::ggplot(ggplot2::aes(
      x = recall, y = mean_precision,
      ymin = lower, ymax = upper
    )) +
    shared_ggprotos(colorvar = method) +
    ggplot2::scale_x_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)) +
    ggplot2::labs(x = "Recall", y = "Mean Precision")
  if (!is.null(baseline_precision)) {
    prc_plot <- prc_plot +
      ggplot2::geom_hline(
        yintercept = baseline_precision,
        linetype = "dashed", color = "grey50"
      )
  }
  return(prc_plot)
}
p <- (dat %>%
  calc_mean_roc(custom_group_vars = "method") %>%
  plot_mean_roc()) +
  (dat %>%
    calc_mean_prc(custom_group_vars = "method") %>%
    plot_mean_prc() +
    theme(legend.position = "none"))

ggsave(
  filename = snakemake@output[["plot"]],
  plot = p,
  device = "png",
  height = 4,
  width = 6
)

From scripts/preproc.R:
schtools::log_snakemake()
library(mikropml)

doFuture::registerDoFuture()
future::plan(future::multicore, workers = snakemake@threads)

data_raw <- readr::read_csv(snakemake@input[["csv"]])
data_processed <- preprocess_data(data_raw, outcome_colname = snakemake@params[["outcome_colname"]])

saveRDS(data_processed, file = snakemake@output[["rds"]])

From scripts/report.Rmd (selected chunks):
schtools::set_knitr_opts()
library(knitr)
include_graphics(snakemake@input[['rulegraph']])
include_graphics(snakemake@input[['perf_plot']])
include_graphics(snakemake@input[['roc_plot']])
include_graphics(snakemake@input[['hp_plot']])
if (isTRUE(snakemake@params[['find_feature_importance']])) {
    cat("## Feature Importance")
}
include_graphics(snakemake@input[['feat_plot']])
include_graphics(snakemake@input[['bench_plot']])

From scripts/train_ml.R:
schtools::log_snakemake()
library(dplyr)
doFuture::registerDoFuture()
future::plan(future::multicore, workers = snakemake@threads)

method <- snakemake@params[["method"]]
seed <- as.numeric(snakemake@params[["seed"]])
hyperparams <- snakemake@params[["hyperparams"]][[method]]
data_processed <- readRDS(snakemake@input[["rds"]])$dat_transformed

ml_results <- mikropml::run_ml(
  dataset = data_processed,
  method = method,
  outcome_colname = snakemake@params[["outcome_colname"]],
  find_feature_importance = FALSE,
  kfold = as.numeric(snakemake@params[["kfold"]]),
  seed = seed,
  hyperparameters = hyperparams
)

wildcards <- schtools::get_wildcards_tbl()

readr::write_csv(
  ml_results$performance %>%
    inner_join(wildcards, by = c("method", "seed")),
  snakemake@output[["perf"]]
)
readr::write_csv(ml_results$test_data, snakemake@output[["test"]])
saveRDS(ml_results$trained_model, file = snakemake@output[["model"]])

script:
    "scripts/report.Rmd"

shell:
    """
    zip -r {output} {input} 2> {log}
    """
URL: https://github.com/SchlossLab/mikropml-snakemake-workflow
Name: mikropml-snakemake-workflow
Version: v1.3.0
License: MIT License