Reproduce some tables and graphs from Angrist & Krueger (1991)



How to compile

The project is set up so that snakemake handles the installation of the required dependencies into a local virtual environment. The following external dependencies must be installed manually:

  • A TeX distribution with pdflatex and latexmk available on the path

  • The snakemake workflow management system

A TeX distribution can be installed using your preferred method (TeX Live for Linux, MiKTeX for Windows, and MacTeX for macOS are good default choices, installable via apt/rpm, Scoop/Chocolatey, and Homebrew, respectively). It is recommended to install snakemake in its own separate conda virtual environment (e.g. conda create -c conda-forge -c bioconda -n snakemake snakemake ).

The steps to build the project are described in its Snakefile. Once snakemake is installed, the project can be built from scratch by running the snakemake command in its root directory:

 cd /path/to/project-for-pp4rs
 conda activate snakemake
 snakemake --cores all --use-conda

assuming that snakemake is available in the conda environment named snakemake. --cores all sets the number of parallel jobs equal to the number of your logical CPU cores. If you wish to run N jobs in parallel, replace it with --cores N .

Code Snippets

shell:
    "quarto render {input.qmd} -P figure_path:{input.figure} 2> {log} && \
     rm -rf {params.output_dir}/* && \
     mv -f src/presentation/presentation.html src/presentation/presentation_files {params.output_dir}"
From line 25 of main/Snakefile
script:
    "src/figures/birth_year_education_interactive.py"
From line 40 of main/Snakefile
shell:
    "latexmk -pdf -interaction=nonstopmode -jobname={params.output_name} {input.tex} 2> {log}"
From line 53 of main/Snakefile
script:
    "src/tables/regression_table.R"
From line 77 of main/Snakefile
script:
    "src/estimation/estimate_model.R"
From line 99 of main/Snakefile
script:
    "src/figures/birth_year_education.py"
From line 116 of main/Snakefile
script:
    "src/figures/birth_year_education.py"
From line 138 of main/Snakefile
script:
    "src/data/prepare_data.py"
From line 149 of main/Snakefile
shell:
    "bash {input.script} {params.url} {output.file} 2> {log}"
From line 162 of main/Snakefile
shell:
    "snakemake --filegraph | dot -Tpdf > build_graphs/filegraph.pdf"
shell:
    "snakemake --rulegraph | dot -Tpdf > build_graphs/rulegraph.pdf"
shell:
    "snakemake --dag | dot -Tpdf > build_graphs/dag.pdf"
import sys

import pandas as pd
import numpy as np


def read_data(path):
    """Reads data from path and returns a pandas DataFrame.
    Also renames columns to be more descriptive.

    Args:
        path (str): path to data

    Returns:
        pandas.DataFrame
    """

    colnames = [f"v{i+1}" for i in range(27)]
    data = pd.read_csv(
        path,
        delim_whitespace=True,
        header=None,
        names=colnames
    )
    print(f"Read {len(data)} rows")

    rename_dict = {
        "v1": "age",
        "v2": "ageq",
        "v4": "education",
        "v5": "enocent",
        "v6": "esocent",
        "v9": "log_weekly_wage",
        "v10": "married",
        "v11": "midatl",
        "v12": "mt",
        "v13": "neweng",
        "v16": "census",
        "v18": "quarter_of_birth",
        "v19": "race",
        "v20": "smsa",
        "v21": "soatl",
        "v24": "wnocent",
        "v25": "wsocent",
        "v27": "year_of_birth"
    }
    data = data.rename(columns=rename_dict)

    return data


def prepare_data(data):
    """Prepares data for analysis.

    Args:
        data (pandas.DataFrame): data to prepare

    Returns:
        pandas.DataFrame
    """

    data = data.copy()

    data["cohort"] = np.where(
        (40 <= data["year_of_birth"]) & (data["year_of_birth"] <= 49),
        "40-49",
        np.where(
            (30 <= data["year_of_birth"]) & (data["year_of_birth"] <= 39),
            "30-39",
            "20-29"
        )
    )
    data.loc[data["census"] == 80, "ageq"] = data.loc[data["census"] == 80, "ageq"] - 1900
    data["ageq_squared"] = data["ageq"] ** 2
    data["year_of_birth_within_decade"] = [int(str(year)[-1]) for year in data["year_of_birth"]]

    return data


def main():
    """Main function."""

    with open(snakemake.log[0], "w") as logfile:
        sys.stderr = sys.stdout = logfile

        data = read_data(snakemake.input["file"])
        data = prepare_data(data)
        data.to_csv(snakemake.output["file"], index=False)
        print(f"Exported data to {snakemake.output['file']}")


if __name__ == "__main__":
    main()
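As a quick sanity check, the nested `np.where` cohort binning in `prepare_data` can be exercised on a toy frame (a standalone sketch, not part of the workflow; years of birth are two-digit, e.g. 34 for 1934):

```python
import numpy as np
import pandas as pd

# Toy check of the nested np.where cohort binning used in prepare_data().
df = pd.DataFrame({"year_of_birth": [25, 34, 47]})
df["cohort"] = np.where(
    (40 <= df["year_of_birth"]) & (df["year_of_birth"] <= 49),
    "40-49",
    np.where(
        (30 <= df["year_of_birth"]) & (df["year_of_birth"] <= 39),
        "30-39",
        "20-29"
    )
)
print(df["cohort"].tolist())  # ['20-29', '30-39', '40-49']
```

An equivalent, arguably clearer alternative would be `pd.cut` with explicit bin edges and labels.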
library(fixest)
library(tibble)
library(readr)
library(dplyr)
library(purrr)
library(jsonlite)
library(stringr)


load_and_filter_data <- function(data_path, cohort_limits) {
  df <- read_csv(data_path) %>%
    filter(year_of_birth >= cohort_limits[1] & year_of_birth <= cohort_limits[2])
  return(df)
}


estimate_model <- function(df, form) {
    model <- feols(form, df)
    model$data <- NULL
    model$model <- NULL

    return(model)
}


load_formula <- function(specs, iv=FALSE) {
    dep_var <- specs$dep_var
    indep_vars <- unlist(specs$indep_vars)

    if (!is.null(specs$fixed_effects)) {
        fixed_effects_part <- unlist(specs$fixed_effects)
    } else {
        fixed_effects_part <- "0"
    }


    if (iv) {
        instrumented_vars <- unlist(specs$instrumental$instrumented_vars)
        instruments <- unlist(specs$instrumental$instruments)
        indep_vars <- setdiff(indep_vars, instrumented_vars)
        if (length(indep_vars) == 0) {
            indep_vars <- "1"
        }

        iv_part <- paste(
            paste(instrumented_vars, collapse = " | "),
            "~",
            paste(instruments, collapse = " + ")
        )
    }

    formula_str <- paste(
        dep_var, "~",
        paste(indep_vars, collapse = " + "), "|",
        fixed_effects_part
    )
    if (iv) {
        formula_str <- paste(formula_str, "|", iv_part)
    }

    return(as.formula(formula_str))
}


main <- function() {

    file.create(snakemake@log[[1]])
    logfile <- file(snakemake@log[[1]], "wt")
    for (stream in c("output", "message")) {
        sink(file = logfile, type = stream)
    }

    if (snakemake@wildcards[["model_type"]] == "iv") {
        iv <- TRUE
    } else if (snakemake@wildcards[["model_type"]] == "ols") {
        iv <- FALSE
    } else {
        stop("model_type must be either 'iv' or 'ols'")
    }

    specs <- read_json(snakemake@input[["model_spec"]])

    limits <- as.integer(c(
        snakemake@wildcards[["from_"]],
        snakemake@wildcards[["to"]]
    ))
    df <- load_and_filter_data(snakemake@input[["file"]], limits)
    form <- load_formula(specs, iv=iv)

    model <- estimate_model(df, form)
    model$name <- specs$name
    model$type <- snakemake@wildcards[["model_type"]]
    saveRDS(model, snakemake@output[["file"]])

    for (stream in c("output", "message")) {
        sink(type = stream)
    }
    close(logfile)

}


main()
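`load_formula` assembles a fixest-style formula in up to three parts: `dep ~ regressors | fixed effects | instrumented ~ instruments`. A standalone Python sketch of the same string assembly (the helper name and example variables are illustrative, not part of the project):

```python
def build_fixest_formula(dep_var, indep_vars, fixed_effects=None,
                         instrumented=None, instruments=None):
    """Assemble a fixest-style formula string:
    dep ~ regressors | fixed effects [| instrumented ~ instruments]."""
    fe_part = " + ".join(fixed_effects) if fixed_effects else "0"
    if instrumented:
        # Instrumented regressors are moved out of the exogenous part,
        # mirroring the setdiff() step in load_formula().
        indep_vars = [v for v in indep_vars if v not in instrumented] or ["1"]
        iv_part = " | " + " | ".join(instrumented) + " ~ " + " + ".join(instruments)
    else:
        iv_part = ""
    return f"{dep_var} ~ {' + '.join(indep_vars)} | {fe_part}{iv_part}"

print(build_fixest_formula(
    "log_weekly_wage", ["education", "race"],
    fixed_effects=["year_of_birth"],
    instrumented=["education"], instruments=["quarter_of_birth"]
))
# log_weekly_wage ~ race | year_of_birth | education ~ quarter_of_birth
```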
import sys

import pandas as pd
import altair as alt
from birth_year_education import aggregate_data


def create_line_plot(data):
    """Creates a line plot of education versus quarter of birth.

    Args:
        data (pandas.DataFrame): data to plot

    Returns:
        tuple: (layered line/point chart, brush selection)
    """

    brush = alt.selection_interval()

    chart_line = alt.Chart(data).mark_line(
        color="black"
    ).encode(
        x=alt.X("year_quarter_of_birth:Q", title="Year of Birth"),
        y=alt.Y("education:Q", title="Education", scale=alt.Scale(zero=False))
    ).properties(
        width=400,
        height=400
    )

    chart_points = alt.Chart(data).mark_point(
        filled=True,
        size=80
    ).encode(
        x=alt.X("year_quarter_of_birth:Q", title="Year of Birth"),
        y=alt.Y("education:Q", title="Education"),
        color=alt.condition(
            brush, 
            alt.Color("quarter_of_birth:O", title="Quarter of birth"),
            alt.value('lightgray')
        )
    ).properties(
        width=400,
        height=400
    ).add_selection(
        brush
    )

    return (chart_line + chart_points), brush


def create_avg_bars(data):
    """Creates a bar plot of education versus quarter of birth.

    Args:
        data (pandas.DataFrame): data to plot

    Returns:
        alt.Chart: bar plot
    """
    chart = alt.Chart(data).mark_bar().encode(
        x=alt.X("quarter_of_birth:O", title="Quarter of birth"),
        y=alt.Y("mean(education):Q", title="Education", scale=alt.Scale(zero=False)),
        color=alt.Color("quarter_of_birth:O", title="Quarter of birth"),
    ).properties(
        width=200,
        height=400
    )

    return chart


def create_combined_plot(data):
    """Creates a combined plot of education versus year/quarter of birth.

    Args:
        data (pandas.DataFrame): data to plot

    Returns:
        alt.Chart: combined plot
    """
    chart_line, brush = create_line_plot(data)
    chart_bars = create_avg_bars(data).transform_filter(
        brush
    )

    return chart_line | chart_bars


def main(input_data: str, output_path: str, cohort: tuple[int, int] = (20, 49)):
    data = pd.read_csv(input_data)
    aggregated_data = aggregate_data(data, cohort)
    chart = create_combined_plot(aggregated_data)
    chart.save(output_path)


if __name__ == "__main__":

    with open(snakemake.log[0], "w") as logfile:
        sys.stderr = sys.stdout = logfile

        main(
            input_data=snakemake.input["file"],
            output_path=snakemake.output["json"],
            cohort=snakemake.params["cohort"]
        )
import sys

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def aggregate_data(data, cohort_limits=None):
    """Filters data to cohorts and aggregates data by birth quarter.

    Args:
        data (pandas.DataFrame): data to aggregate
        cohort_limits (tuple[int, int], optional): lower and upper year-of-birth limits

    Returns:
        pandas.DataFrame
    """
    if cohort_limits:
        data = data[data["year_of_birth"].between(*cohort_limits)]
    data = data \
        .groupby(["year_of_birth", "quarter_of_birth"]) \
        .aggregate({
            "education": "mean",
            "census": "count"
        }) \
        .rename(columns={"census": "num_obs"}) \
        .reset_index()

    data["year_quarter_of_birth"] = data["year_of_birth"] + \
        data["quarter_of_birth"] / 4 - 0.25

    return data


def create_line_plot(data):
    """Creates a line plot of education versus quarter of birth.

    Args:
        data (pandas.DataFrame): data to plot

    Returns:
        tuple: (matplotlib figure, axes)
    """
    fig, ax = plt.subplots()
    sns.lineplot(
        data=data,
        x="year_quarter_of_birth",
        y="education",
        color="black",
        ax=ax
    )
    sns.scatterplot(
        data=data,
        x="year_quarter_of_birth",
        y="education",
        hue="quarter_of_birth",
        alpha=1,
        s=80,
        ax=ax
    )

    ax.set_xlabel("Year of Birth")
    ax.set_ylabel("Years of completed education")
    ax.set_title("Years of education and season of birth")

    return fig, ax


def create_bar_plot(data, cohorts):
    """Creates bar plots of the schooling differential by quarter of birth.

    Args:
        data (pandas.DataFrame): data to plot
        cohorts (list[tuple[int, int]]): cohort year ranges, one panel per cohort
    """
    data = data.copy()
    data["education_ma5"] = data["education"].rolling(window=5, center=True).mean()
    data["education_diff_ma5"] = data["education"] - data["education_ma5"]

    fig, ax = plt.subplots(len(cohorts), 1)
    for i, cohort in enumerate(cohorts):
        data_subset = data[data["year_of_birth"].between(*cohort)]
        sns.barplot(
            data=data_subset,
            x="year_of_birth",
            y="education_diff_ma5",
            hue="quarter_of_birth",
            ax=ax[i]
        )
        ax[i].set_ylabel("Schooling differential")
        ax[i].set_xlabel("")
        if i > 0:
            ax[i].get_legend().remove()
        else:
            ax[i].get_legend().set_title('Quarter of birth')

    ax[-1].set_xlabel("Year of Birth")
    fig.suptitle("Season of birth and years of schooling")

    return fig, ax


def save_plot(fig, path, width, height, dpi):
    """Saves a figure to disk at the given size and resolution.

    Args:
        fig (matplotlib.figure.Figure): figure to save
        path (str): path to save plot
        width (float): figure width in inches
        height (float): figure height in inches
        dpi (int): output resolution
    """
    fig.set_size_inches(width, height)
    fig.savefig(path, dpi=dpi)


def lineplot(input_data: str, output_path: str,
             cohort: tuple[int, int] = (20, 49),
             width: int = 6, height: int = 4, dpi: int = 300):
    """Create a line plot of schooling attainment by birth quarter."""
    data = pd.read_csv(input_data)
    aggregated_data = aggregate_data(data, cohort)
    fig, ax = create_line_plot(aggregated_data)
    save_plot(fig, output_path, width, height, dpi)


def barplot(input_data: str, output_path: str,
            cohorts: list[str] = ["30-39", "40-49"],
            width: int = 6, height: int = 8, dpi: int = 300):
    """Create a barplot of schooling differential by birth quarter."""
    data = pd.read_csv(input_data)
    cohort_tuples = []
    for cohort in cohorts:
        cohort_tuples.append(tuple(map(int, cohort.split("-"))))
    aggregated_data = aggregate_data(data, cohort_limits=None)
    fig, ax = create_bar_plot(aggregated_data, cohort_tuples)
    save_plot(fig, output_path, width, height, dpi)


if __name__ == "__main__":

    with open(snakemake.log[0], "w") as logfile:
        sys.stderr = sys.stdout = logfile


        if snakemake.params["plot_type"] == "lineplot":
            cohort = (
                int(snakemake.wildcards["from_"]),
                int(snakemake.wildcards["to"])
            )
            lineplot(
                input_data=snakemake.input["file"],
                output_path=snakemake.output["file"],
                cohort=cohort,
                width=snakemake.params["width"],
                height=snakemake.params["height"],
                dpi=snakemake.params["dpi"]
            )

        elif snakemake.params["plot_type"] == "barplot":
            barplot(
                input_data=snakemake.input["file"],
                output_path=snakemake.output["file"],
                cohorts=snakemake.params["cohorts"],
                width=snakemake.params["width"],
                height=snakemake.params["height"],
                dpi=snakemake.params["dpi"]
            )

        else:
            raise ValueError(f"Unknown plot type: {snakemake.params['plot_type']}")
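`create_bar_plot` detrends average education with a centered 5-period moving average (`rolling(window=5, center=True)`), so the plotted "schooling differential" is each quarter's deviation from its local trend. A toy illustration of that transform on a linear series:

```python
import pandas as pd

# Centered 5-period moving average, as in create_bar_plot():
# the differential is the deviation of each value from its local trend.
s = pd.Series([10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0])
ma5 = s.rolling(window=5, center=True).mean()
diff = s - ma5
print(ma5.tolist())   # [nan, nan, 12.0, 13.0, 14.0, nan, nan]
print(diff.tolist())  # [nan, nan, 0.0, 0.0, 0.0, nan, nan]
```

The first and last two entries are NaN because a full five-observation window is unavailable at the edges (the default `min_periods` equals the window size).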
library(fixest)
library(modelsummary)
library(stringr)
library(purrr)
library(tibble)
library(dplyr)


create_extra_row_tibble <- function(models) {

    birth_dummies <- rep("Yes", length(models))
    region_dummies <- map_chr(
        models,
        function(model) ifelse("soatl" %in% names(coef(model)), "Yes", "No")
    )
    extra_rows_list <- list(
        birth_dummies = birth_dummies,
        region_dummies = region_dummies
    )
    extra_rows <- extra_rows_list %>%
        as_tibble() %>%
        t() %>%
        as_tibble() %>% 
        mutate(term = c("9 Year-of-birth dummies", "8 Region-of-residence dummies")) %>%
        select(term, everything())

    return(extra_rows)

}


main <- function() {

    file.create(snakemake@log[[1]])
    logfile <- file(snakemake@log[[1]], "wt")
    for (stream in c("output", "message")) {
        sink(file = logfile, type = stream)
    }

    paths <- snakemake@input[["models"]]
    models <- map(paths, readRDS)
    names(models) <- map_chr(models, function(model) paste(str_to_upper(model$type), model$name))

    num_models <- length(models)
    alphabetical_order <- order(names(models))
    flip_order <- rep(c(num_models / 2, 0), times = num_models / 2) + rep(seq(num_models / 2), each = 2)

    models <- models[alphabetical_order][flip_order]

    extra_rows <- create_extra_row_tibble(models)

    coef_names <- c(
        "education" = "Years of education",
        "fit_education" = "Years of education",
        "race" = "Race (1 = black)",
        "smsa" = "SMSA (1 = center city)",
        "married" = "Married (1 = married)",
        "ageq" = "Age",
        "ageq_squared" = "Age-squared"
    )

    options(modelsummary_format_numeric_latex = "mathmode")
    modelsummary(
        models,
        output = snakemake@output[["table"]],
        gof_omit = ".*",
        coef_map = coef_names,
        add_rows = extra_rows
    )

    for (stream in c("output", "message")) {
        sink(type = stream)
    }
    close(logfile)

}


main()
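The `flip_order` index trick above interleaves the alphabetically sorted models (all "IV ..." names before all "OLS ..." names) so that the OLS and IV columns for each specification end up adjacent in the table. A standalone Python re-implementation of the same 1-based index arithmetic (illustrative, not part of the project):

```python
# Mirror of R's rep(c(n/2, 0), times = n/2) + rep(seq(n/2), each = 2),
# producing 1-based indices into the alphabetically sorted model list.
names = ["IV A", "IV B", "OLS A", "OLS B"]  # alphabetical order
n = len(names)
half = n // 2
flip_order = [half * ((i + 1) % 2) + (i // 2 + 1) for i in range(n)]
print(flip_order)                          # [3, 1, 4, 2]
print([names[i - 1] for i in flip_order])  # ['OLS A', 'IV A', 'OLS B', 'IV B']
```

Odd positions pull from the OLS half, even positions from the IV half, pairing each specification's two estimates.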


Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/pp4rs/snakemake-demo
Name: snakemake-demo
Version: 1

Copyright: Public Domain
License: MIT License
