Hard to Measure Well: Can Feasible Policies Reduce Methane Emissions?


Authors: Karl Dunkle Werner ORCID logo and Wenfeng Qiu

Steps to replicate (TL;DR)

  1. Read this README

  2. Make sure you have an appropriate OS (Linux or WSL2) and the necessary computing resources (see below)

  3. Unzip the replication files.

  4. If the data is saved somewhere outside the project folder, mount a copy inside the project folder. (Useful for development only)

  5. Install Conda and Snakemake (see below)

  6. Run Snakemake

  7. Check results

Putting it all together:

# 3. Unzip
mkdir methane_replication # or whatever you want
cd methane_replication
unzip path/to/replication_public.zip -d .
unzip path/to/replication_drillinginfo.zip -d .
# 4. OPTIONAL (development only)
# If the data and scratch folders are stored somewhere *outside* the
# project folder, bind-mount copies inside the project folder.
# You may need to change these paths to fit your situation.
data_drive="$HOME/Dropbox/data/methane_abatement"
scratch_drive="$HOME/scratch/methane_abatement"
project_dir="$(pwd)"
mkdir -p "$scratch_drive" "$project_dir/data" "$project_dir/scratch"
sudo mount --bind "$data_drive" "$project_dir/data"
sudo mount --bind "$scratch_drive" "$project_dir/scratch"
# 5. Install Conda and Snakemake
# If conda is not already installed, follow instructions here:
# https://docs.conda.io/en/latest/miniconda.html
conda env create --name snakemake --file code/envs/install_snakemake.yml
conda activate snakemake
snakemake --version
singularity --version
# Should show versions, not an error
# 6. Run Snakemake to create all outputs
# (this takes about a day with 4 CPU)
/usr/bin/time -v snakemake
# snakemake --dry-run to see what will be run
# 7. Check results (optional and slow)
# Check everything into git, rerun snakemake, and verify results are the same.
git init
git add .
git commit -m "Replication run 1"
snakemake --delete-all-output
rm -r scratch/*
rm -r .snakemake/conda
snakemake --use-conda --use-singularity --singularity-args='--cleanenv'
# Results should be binary-identical if everything worked correctly
# (except software_cites_r.bib, which has some manual edits)
git diff

Setup

Operating system

This code uses Singularity. You don't have to install it yourself, but you do have to be on an operating system where it can be installed. Good options are any recent version of Linux or Windows WSL2 (but not WSL1).

On macOS, or on Windows outside WSL2, things are more difficult. One approach is to install Vagrant, use Vagrant to create a virtual machine, and run everything inside that virtual machine. Good luck.

For more detail, see Singularity's installation docs (only the pre-install requirements; conda will install Singularity for you).

Software

This project uses Snakemake (v6.8.0) and Conda (v4.10.3) to manage dependencies.

  • To get started, first install Conda (mini or full-sized).

  • Then use Conda to install Snakemake and Singularity from the file install_snakemake.yml (in the replication zipfile).

In a terminal:

conda env create --name snakemake --file code/envs/install_snakemake.yml
conda activate snakemake

Run all other commands in that activated environment. If you close the terminal window, you need to re-run conda activate snakemake before running the rest of the commands. These downloads can be large.

What does Snakemake do?

Snakemake uses rules to generate outputs and manages the code environment to make it all work.

In particular, we're following a pattern Snakemake calls an "ad-hoc combination of Conda package management with containers".

Snakemake uses Singularity (an alternative to Docker) to run code in a virtual environment, and uses conda to install packages. All of this is handled transparently as the rules are run.

It can be useful to run snakemake --dry-run to see the planned jobs.
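
For example, a dry run might look like this (a sketch; the extra flags mirror the verification run in the TL;DR above and are only needed if you invoke Snakemake that way):

snakemake --dry-run
snakemake --dry-run --use-conda --use-singularity --singularity-args='--cleanenv'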

Snakemake keeps track of what needs to run and what doesn't. If something goes wrong midway through, Snakemake will see that some outputs are up to date and others aren't, and will only re-run the jobs that need it.

The Snakefile is set up to retry failing jobs once, to avoid builds that break because of transient problems (e.g. "Error creating thread: Resource temporarily unavailable"). If you would rather not restart failed jobs, remove the line workflow.restart_times = 1 from the Snakefile. Note that Snakemake will still stop after a job fails twice (it will not run other jobs).

Files and data

Accessing data

We need to make sure the code can access the right files. There are two ways to do this: the straightforward way and the way Karl does it.

Recommended file access

Straightforward approach: Unzip the replication files, either interactively or with the commands below.

mkdir methane_replication # or whatever you want
cd methane_replication
unzip path/to/replication_public.zip -d .
unzip path/to/replication_drillinginfo.zip -d .
Alternative file access

Less straightforward, arguably better for development

  • Store the data and scratch folders somewhere else (e.g. data in Dropbox).

  • Create your own bind mounts to point to the data and scratch folders. (See an example in code/bind_mount_folders.sh, and the sketch below.)

For people familiar with Singularity: Note that $SINGULARITY_BIND doesn't work, because it's not used until the Singularity container is running, so Snakemake thinks files are missing.

For people familiar with symlinks: Using symlinks (in place of bind mounts) does not work here, because Singularity will not follow them.
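
If you go this route, a minimal sketch (assuming the same layout as the TL;DR above; code/bind_mount_folders.sh is the authoritative version) is:

mkdir -p data scratch
sudo mount --bind "$HOME/Dropbox/data/methane_abatement" "$(pwd)/data"
sudo mount --bind "$HOME/scratch/methane_abatement" "$(pwd)/scratch"
# ... run snakemake ...
sudo umount "$(pwd)/data" "$(pwd)/scratch"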

File structure

All files in output/tex_fragments, data/generated, and scratch/ are auto-generated and can safely be deleted. All other files in data/ should not be deleted. Some files in graphics/ are auto-generated, but the ones that are in the replication zipfile are not. data/ and scratch/ are ignored by .gitignore.
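
For example, a clean-slate reset looks roughly like this (a sketch based on the list above; leave everything else in data/ alone):

rm -rf output/tex_fragments data/generated scratch/*

(snakemake --delete-all-output, as used in the TL;DR, is the other route.)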

Other

  • The PDF outputs are built with Latexmk and LuaLaTeX.

    • For size reasons, LuaLaTeX is not included in the set of software managed by conda. The paper job, which runs latexmk, might fail if LuaLaTeX is not installed on your computer. All the outputs up to that point will still be present.

    • The tex files use some fonts that are widely distributed, but may not be installed by default.

  • Note that the code depends on moodymudskipper/safejoin, which is a different package from the safejoin package on CRAN. moodymudskipper/safejoin will be renamed. (A manual-install sketch follows this list.)

    • In case the original author deletes the repository, a copy is here.
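
If you ever need to install moodymudskipper/safejoin by hand (e.g. for development outside the Snakemake-managed environments; the replication run itself should not need this), a minimal sketch, assuming the remotes package is available, is:

Rscript -e 'remotes::install_github("moodymudskipper/safejoin")'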

Computing resources for a full run

In addition to the software above, parts of this project require significant amounts of memory and disk space. Most parts also benefit from having multiple processors available. (The slow parts parallelize well, so speed should increase almost linearly with processor count.)

The tasks that require significant memory are noted in the Snakemake file (see the mem_mb fields). The highest requirement for any task is 10 GB, though most are far lower. (These could be overstatements; we haven't tried that hard to find the minimum memory requirements for each operation.) The programs also use about 80 GB of storage in scratch/ in addition to the ~10 GB of input data and ~8 GB output data.
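
If you want to cap what Snakemake schedules at once, something like this should work (a sketch; 4 cores and a 10 GB memory budget match the figures above and the timing below):

snakemake --cores 4 --resources mem_mb=10000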

Running the whole thing takes 23 hours on a computer with 4 CPUs. According to /usr/bin/time -v, it uses 22:45 of wall time and 81:45 of user time. Maximum resident set size is (allegedly) 3.62 GiB (this seems low).

Data sources and availability

The data in this study come from a variety of sources, with the sources in bold providing the central contribution.

  • Scientific studies (except as noted, all from the published papers and supplementary information)

    • Alvarez et al. (2018)

    • Duren et al. (2019)

    • Frankenberg et al. (2016)

      • Includes data received by email from the authors
    • Lyon et al. (2016)

    • Omara et al. (2018)

    • Zavala-Araiza et al. (2018)

  • US Agencies

    • BEA

    • EIA

    • EPA

    • St. Louis Federal Reserve

  • Data providers

    • SNL: prices at trading hubs

    • Enverus (formerly Drillinginfo): Well production and characteristics

All datasets are included in the replication_public.zip file, except the Enverus data. I believe my Enverus data subset can be shared with people who have access to the Enverus entity production headers and entity production monthly datasets.

Development docs

These notes are modestly outdated, and aren't useful for replication.

Other installation instructions

Installing Stan

  1. Download and extract CmdStan

  2. Add these lines to a file named local in the CmdStan make/ directory. (Create local if it doesn't already exist)

O_STANC=3
STANCFLAGS+= --O --warn-pedantic
STAN_THREADS=true
STAN_CPP_OPTIMS=true
STANC3_VERSION=2.27.0 # change this version to match the downloaded cmdstan version
  3. Set the user environment variable CMDSTAN to that folder (e.g. ~/.R/cmdstan-2.27.0)

  4. Windows only: mingw32-make install-tbb (even if make is installed)

  5. Follow prompts, including adding the TBB dir to your path (Windows only)

  6. Run make build and make examples/bernoulli/bernoulli (see install instructions, and the sketch after this list)

  7. Installation tries to download stanc (because compilation is a hassle), but sometimes I've had to download it manually from https://github.com/stan-dev/stanc3/releases
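
Putting the CMDSTAN variable and the build steps together, a minimal sketch (assuming CmdStan was extracted to ~/.R/cmdstan-2.27.0, as in the example path above):

export CMDSTAN="$HOME/.R/cmdstan-2.27.0"  # or set it in your shell profile
cd "$CMDSTAN"
make build
make examples/bernoulli/bernoulli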

Windows-specific instructions

  • After installing conda, use code/snakemake_environment_windows.yml to create the snakemake environment (this will error if you already have one named snakemake)

    • conda env create -f code/snakemake_environment_windows.yml
  • Install cyipopt manually.

    • Download and extract ipopt

    • Download and extract cyipopt

    • Copy the ipopt folders into cyipopt

    • Activate snakemake environment

    • python ../cyipopt/setup.py install (assuming you saved the cyipopt directory next to the methane_abatement directory)

Running on Windows
  • Activate the newly created snakemake environment, and do not set --use-conda when running Snakemake.

  • There will be some warnings.

Overleaf

Connect to Overleaf by Git. See details here.

git remote add overleaf https://git.overleaf.com/5d6e86df6820580001f6bdfa
git checkout master
git pull overleaf master --allow-unrelated-histories

Code Snippets

source(here::here("code/shared_functions.r"))
memory_limit(snakemake@resources[["mem_mb"]])
suppressWarnings(loadNamespace("lubridate")) # avoid https://github.com/tidyverse/lubridate/issues/965
options(warn = 2, mc.cores=1)

df_gdp <- readxl::read_excel(snakemake@input$gdp_file, range="A6:W907")
years <- 1997:2017
stopifnot(names(df_gdp) == c("Line", "...2", as.character(years)))
df_gdp %<>% dplyr::rename(line = Line, item = ...2) %>%
  dplyr::mutate(line = as.integer(line), item = trimws(item))


assert_item <- function(df, item) {
  stopifnot(setequal(df$item, item))
  df
}

df_deprec <- readxl::read_excel(snakemake@input$deprec_file, range="A6:BW85")
deprec_years <- 1947:2019
stopifnot(names(df_deprec) == c("Line", "...2", as.character(deprec_years)))
df_deprec %<>% dplyr::rename(line = Line, item = ...2) %>%
  dplyr::filter(!is.na(line)) %>%
  dplyr::mutate(line = as.integer(line), item = trimws(item)) %>%
  dplyr::filter(line == 6) %>%
  dplyr::select(-line) %>%
  data.table::transpose(make.names=TRUE) %>%
  dplyr::mutate(year = !!deprec_years) %>%
  dplyr::rename(deprec_current_cost = "Oil and gas extraction")


prices = load_price_index(snakemake@input$price_index_file, base_year=2019) %>%
  dplyr::mutate(year = lubridate::year(date)) %>%
  dplyr::group_by(year) %>%
  dplyr::summarize(price_index = mean(price_index), .groups="drop")

oil_gas <- dplyr::filter(df_gdp, line %in% 56:57) %>%
  assert_item(c("Value added", "Compensation of employees")) %>%
  dplyr::select(-line) %>%
  data.table::transpose(make.names=TRUE) %>%
  dplyr::mutate(year = !!years) %>%
  dplyr::mutate_if(is.character, as.numeric) %>%
  safejoin::safe_inner_join(df_deprec, by="year", check="UVBVTL") %>%
  safejoin::safe_inner_join(prices, by="year", check="UVBVL") %>%
  dplyr::rename_all(make_better_names) %>%
  dplyr::mutate(
    net_value_nominal_bn = value_added - compensation_of_employees - deprec_current_cost,
    net_value_real_bn = net_value_nominal_bn / price_index,
  ) %>%
  dplyr::as_tibble()


stopifnot(!anyNA(oil_gas))

# 2017 is the latest easily available. 2013 is arbitrary, but 5 years seems nice.
mean_net_val <- dplyr::filter(oil_gas, dplyr::between(year, 2013, 2017))$net_value_real_bn %>% mean() %>%
  signif(2)
to_write <- c(
  "% Def: value added - employee compensation - current-cost depreciation, expressed in $2019 billions",
  "% Data from BEA",
  paste0(mean_net_val, "%") # "%" to avoid extra space in latex
)

writeLines(to_write, snakemake@output$net_value)
From code/compile_stan.R:
options(warn = 2)
source(here::here("code/shared_functions.r"))
source(here::here("code/stan_helper_functions.r"))
check_cmdstan() # not the same as cmdstanr::check_cmdstan_toolchain

m = cmdstanr::cmdstan_model(
  snakemake@input[[1]],
  include_paths=here::here("code/stan_models"),
  quiet=FALSE,
  cpp_options=list(PRECOMPILED_HEADERS = "false")
)
import sys
import itertools
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
import pandas as pd
import pyarrow
import pyarrow.parquet as pq
import logging
from functools import partial

if sys.version_info < (3, 7):
    raise AssertionError("Need python version 3.7 or higher")

# Create a logger to output any messages we might have...
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


def dtypes_to_schema(dtype_dict):
    schema = []
    type_conversion = {
        "Int64": pyarrow.int64(),
        "Int32": pyarrow.int32(),
        "int64": pyarrow.int64(),
        "int32": pyarrow.int32(),
        "int": pyarrow.int32(),
        "float64": pyarrow.float64(),
        "float32": pyarrow.float32(),
        "float": pyarrow.float32(),
        "str": pyarrow.string(),
        "O": pyarrow.string(),
        # Note that the arguments to pyarrow.dictionary changed between v0.13.0
        # and v0.14.1
        "category": pyarrow.dictionary(
            pyarrow.int32(), pyarrow.string(), ordered=False
        ),
        "date32": pyarrow.date32(),  # Note: date32 isn't a python type
    }
    for field_name, dtype in dtype_dict.items():
        type = type_conversion[dtype]
        schema.append(pyarrow.field(field_name, type, nullable=True))
    return pyarrow.schema(schema)


def extension_int_to_float(df, exclude=[]):
    """
    You can delete this function once this issue is resolved:
    https://issues.apache.org/jira/browse/ARROW-5379
    """
    extension_int_types = {
        pd.Int8Dtype(),
        pd.Int16Dtype(),
        pd.Int32Dtype(),
        pd.Int64Dtype(),
    }
    new_types = {
        col: "float64"
        for col in df.columns
        if df[col].dtypes in extension_int_types and col not in exclude
    }
    return df.astype(new_types)


def fix_monthly_dates(df):
    df = coerce_to_date(df, ["date"], drop_bad=True)
    # Note: these are sometimes NA
    df["month"] = df["date"].dt.month.astype("Int32")
    df["year"] = df["date"].dt.year.astype("Int32")
    del df["date"]
    return df


def read_csv(filepath_or_buffer, *args, **kwargs):
    """Thin wrapper around pd.read_csv"""
    logger.info(f"  Reading {filepath_or_buffer}")
    return pd.read_csv(filepath_or_buffer, *args, **kwargs)


def convert_monthly(csv_file, out_dir):
    csv_file = Path(csv_file)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Note: we're using python 3.7+, so dicts keep their order
    pd_dtypes = {
        "entity_id": "Int32",
        "api": "str",
        "api_list": "str",
        "date": "str",
        "oil": "str",  # will convert to float later
        "gas": "str",  # will convert to float later
        "water": "str",  # will convert to float later
        "well_count": "str",  # will convert to Int later
        "days": "str",  # will convert to Int later
        "daily_avg_oil": "str",  # will convert to float later
        "daily_avg_gas": "str",  # will convert to float later
        "daily_avg_water": "str",  # will convert to float later
        "reservoir": "str",
        "well_name": "str",
        "well_number": "str",
        "operator_alias": "str",
        "production_type": "str",
        "production_status": "str",
        "entity_type": "str",
    }
    # Need to read these as str, then convert to float, because of errors.
    float_cols = [
        "oil",
        "gas",
        "water",
        "daily_avg_oil",
        "daily_avg_gas",
        "daily_avg_water",
    ]
    int_cols = ["well_count", "days"]
    output_dtypes = pd_dtypes.copy()
    del output_dtypes["date"]
    output_dtypes["year"] = "int32"
    output_dtypes["month"] = "int32"
    for c in float_cols:
        output_dtypes[c] = "float"
    for c in int_cols:
        output_dtypes[c] = "int"
    pq_schema = dtypes_to_schema(output_dtypes)

    csv_iter = read_csv(
        csv_file,
        dtype=pd_dtypes,
        chunksize=3_000_000,
        index_col=False,
        names=list(pd_dtypes.keys()),
        on_bad_lines="warn",
        header=0,
        low_memory=True,
        na_values={"(N/A)"},
        keep_default_na=True,
    )
    partition_cols = ["year"]
    for chunk in csv_iter:
        chunk = (
            chunk.pipe(fix_monthly_dates)
            .pipe(coerce_to_float, cols=float_cols)
            .pipe(coerce_to_integer, cols=int_cols)
            .pipe(extension_int_to_float)
            .pipe(fill_na_for_partitioning, cols=partition_cols)
        )
        table = pyarrow.Table.from_pandas(chunk, preserve_index=False, schema=pq_schema)
        pq.write_to_dataset(
            table, root_path=str(out_dir), partition_cols=partition_cols, version="2.0"
        )
    logger.info(f"  Done with monthly production {csv_file}")


def fill_na_for_partitioning(df, cols):
    """
    Replace NaN or null values with non-NA placeholders.

    These placeholders vary by column type: "NA" for string or category,
    -9 for numeric.

    The goal here is to prevent write_to_dataset from dropping groups with
    missing-valued partition columns. This function can be cut out when this
    pyarrow issue is fixed:
    https://issues.apache.org/jira/projects/ARROW/issues/ARROW-7345
    """
    orig_order = df.columns
    assert orig_order.nunique() == df.shape[1]

    to_fill = df[cols]
    fill_str = to_fill.select_dtypes(include=["object", "category"]).fillna("NA")
    fill_num = to_fill.select_dtypes(include="number").fillna(-9)
    unfilled = to_fill.columns.difference(fill_str.columns).difference(fill_num.columns)
    if len(unfilled) > 0:
        raise NotImplementedError(f"Can't fill columns: {unfilled.tolist()}")
    # Everything else:
    remainder = df.drop(cols, axis=1)
    out = pd.concat([fill_str, fill_num, remainder], axis=1)[orig_order.tolist()]
    # Reset columns to original ordering:
    return out


def coerce_cols(df, cols, target_type, drop_bad=True, **kwargs):
    """Coerce `cols` to `target_type`, optionally dropping rows that fail.

    A very small number of rows fail to parse because they have things like
    "WELL" in the oil quantity column. Just drop these.
    But because of this, we have to read all of these columns as str, then convert.

    args: df a dataframe
    cols: columns, currently string, that will be coerced
    target_type: what type to coerce to?
    drop_bad: should we drop the whole row if the coercion fails? (Default True)
    kwargs: arguments passed on to pd.to_datetime or pd.to_numeric.
    """

    rename_dict = {c: c + "_str" for c in cols}
    if df.columns.isin(rename_dict.values()).any():
        raise ValueError("temp rename column already exists")
    df = df.rename(columns=rename_dict)

    if target_type == "datetime":
        conversion_fn = pd.to_datetime
    else:
        conversion_fn = partial(pd.to_numeric, downcast=target_type)

    for new_col, old_col in rename_dict.items():
        try:
            # Happy case first
            df[new_col] = conversion_fn(df[old_col], errors="raise", **kwargs)
        except ValueError:
            df[new_col] = conversion_fn(df[old_col], errors="coerce", **kwargs)
            newly_missing = df[new_col].isna() & df[old_col].notna()
            newly_missing_count = newly_missing.sum()
            logger.info(
                f"  {newly_missing_count} values failed to parse as {target_type}."
                + f" Here are a few: "
                + ", ".join(df.loc[newly_missing, old_col].iloc[:6])
            )
            if drop_bad:
                # Drop rows that failed to parse (very few of these)
                df = df.loc[~newly_missing]
        finally:
            del df[old_col]
    return df


coerce_to_date = partial(coerce_cols, target_type="datetime", format="%Y-%m-%d")
coerce_to_float = partial(coerce_cols, target_type="float")
coerce_to_integer = partial(coerce_cols, target_type="integer")


def convert_headers(csv_file, out_dir):
    csv_file = Path(csv_file)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    pd_dtypes = {
        "api": "str",  # API/UWI
        "operator_alias_legacy": "str",  # Operator Alias (Legacy)
        "operator_company_name": "str",  # Operator Company Name
        "operator_reported": "str",  # Operator (Reported)
        "operator_ticker": "str",  # Operator Ticker
        "well_name": "str",  # Well/Lease Name
        "well_number": "str",  # Well Number
        "entity_type": "str",  # Entity Type
        "county": "str",  # County/Parish
        "di_basin": "str",  # DI Basin
        "di_play": "str",  # DI Play
        "di_subplay": "str",  # DI Subplay
        "reservoir": "str",  # Reservoir
        "production_type": "str",  # Production Type
        "producing_status": "str",  # Producing Status
        "drill_type": "str",  # Drill Type
        "first_prod_date": "str",  # First Prod Date
        "last_prod_date": "str",  # Last Prod Date
        "cum_gas": "float",  # Cum Gas
        "cum_oil": "float",  # Cum Oil
        "cum_boe": "float",  # Cum BOE
        "cum_water": "float",  # Cum Water
        "cum_mmcfge": "float",  # Cum MMCFGE
        "cum_bcfge": "float",  # Cum BCFGE
        "daily_gas": "float",  # Daily Gas
        "daily_oil": "float",  # Daily Oil
        "first_month_oil": "float",  # First Month Oil
        "first_month_gas": "float",  # First Month Gas
        "first_month_water": "float",  # First Month Water
        "first_6_oil": "float",  # First 6 Oil
        "first_6_gas": "float",  # First 6 Gas
        "first_6_boe": "float",  # First 6 BOE
        "first_6_water": "float",  # First 6 Water
        "first_12_oil": "float",  # First 12 Oil
        "first_12_gas": "float",  # First 12 Gas
        "first_12_boe": "float",  # First 12 BOE
        "first_12_mmcfge": "float",  # First 12 MMCFGE
        "first_12_water": "float",  # First 12 Water
        "first_24_oil": "float",  # First 24 Oil
        "first_24_gas": "float",  # First 24 Gas
        "first_24_boe": "float",  # First 24 BOE
        "first_24_mmcfge": "float",  # First 24 MMCFGE
        "first_24_water": "float",  # First 24 Water
        "first_60_oil": "float",  # First 60 Oil
        "first_60_gas": "float",  # First 60 Gas
        "first_60_boe": "float",  # First 60 BOE
        "first_60_water": "float",  # First 60 Water
        "first_60_mmcfge": "float",  # First 60 MMCFGE
        "prac_ip_oil_daily": "float",  # Prac IP Oil Daily
        "prac_ip_gas_daily": "float",  # Prac IP Gas Daily
        "prac_ip_cfged": "float",  # Prac IP CFGED
        "prac_ip_boe": "float",  # Prac IP BOE
        "latest_oil": "float",  # Latest Oil
        "latest_gas": "float",  # Latest Gas
        "latest_water": "float",  # Latest Water
        "prior_12_oil": "float",  # Prior 12 Oil
        "prior_12_gas": "float",  # Prior 12 Gas
        "prior_12_water": "float",  # Prior 12 Water
        "last_test_date": "str",  # Last Test Date
        "last_flow_pressure": "float",  # Last Flow Pressure
        "last_whsip": "float",  # Last WHSIP
        "2nd_month_gor": "float",  # 2nd Month GOR
        "latest_gor": "float",  # Latest GOR
        "cum_gor": "float",  # Cum GOR
        "last_12_yield": "float",  # Last 12 Yield
        "2nd_month_yield": "float",  # 2nd Month Yield
        "latest_yield": "float",  # Latest Yield
        "peak_gas": "float",  # Peak Gas
        "peak_gas_month_no.": "Int32",  # Peak Gas Month No.
        "peak_oil": "float",  # Peak Oil
        "peak_oil_month_no.": "Int32",  # Peak Oil Month No.
        "peak_boe": "float",  # Peak BOE
        "peak_boe_month_no.": "Int32",  # Peak BOE Month No.
        "peak_mmcfge": "float",  # Peak MMCFGE
        "peak_mmcfge_month_no.": "Int32",  # Peak MMCFGE Month No.
        "upper_perforation": "float",  # Upper Perforation
        "lower_perforation": "float",  # Lower Perforation
        "gas_gravity": "float",  # Gas Gravity
        "oil_gravity": "float",  # Oil Gravity
        "completion_date": "str",  # Completion Date
        "well_count": "Int32",  # Well Count
        "max_active_wells": "Int32",  # Max Active Wells
        "months_produced": "Int32",  # Months Produced
        "gas_gatherer": "str",  # Gas Gatherer
        "oil_gatherer": "str",  # Oil Gatherer
        "lease_number": "str",  # Lease Number
        "spud_date": "str",  # Spud Date
        "measured_depth_td": "float",  # Measured Depth (TD)
        "true_vertical_depth": "float",  # True Vertical Depth
        "gross_perforated_interval": "float",  # Gross Perforated Interval
        "field": "str",  # Field
        "state": "str",  # State
        "district": "str",  # District
        "aapg_geologic_province": "str",  # AAPG Geologic Province
        "country": "str",  # Country
        "section": "str",  # Section
        "township": "str",  # Township
        "range": "str",  # Range
        "abstract": "str",  # Abstract
        "block": "str",  # Block
        "survey": "str",  # Survey
        "ocs_area": "str",  # OCS Area
        "pgc_area": "str",  # PGC Area
        "surface_latitude_wgs84": "float",  # Surface Latitude (WGS84)
        "surface_longitude_wgs84": "float",  # Surface Longitude (WGS84)
        "last_12_oil": "float",  # Last 12 Oil
        "last_12_gas": "float",  # Last 12 Gas
        "last_12_water": "float",  # Last 12 Water
        "entity_id": "int32",  # Entity ID
    }
    date_cols = [
        "first_prod_date",
        "last_prod_date",
        "last_test_date",
        "completion_date",
        "spud_date",
    ]

    output_dtypes = pd_dtypes.copy()
    for d in date_cols:
        output_dtypes[d] = "date32"
    output_dtypes["first_prod_year"] = "int32"
    pq_schema = dtypes_to_schema(output_dtypes)
    csv_iter = read_csv(
        csv_file,
        dtype=pd_dtypes,
        chunksize=1_000_000,
        index_col=False,
        names=list(pd_dtypes.keys()),
        on_bad_lines="warn",
        parse_dates=date_cols,
        header=0,
        low_memory=True,
    )
    partition_cols = ["first_prod_year"]
    for chunk in csv_iter:
        chunk["first_prod_year"] = chunk["first_prod_date"].dt.year
        chunk = chunk.pipe(extension_int_to_float, exclude=partition_cols).pipe(
            fill_na_for_partitioning, cols=partition_cols
        )
        table = pyarrow.Table.from_pandas(chunk, preserve_index=False, schema=pq_schema)
        pq.write_to_dataset(
            table, root_path=str(out_dir), partition_cols=partition_cols, version="2.0"
        )
    logger.info(f"  Done with headers: {csv_file}")


def make_uniform_bins(x, num):
    import numpy as np

    # Bins are evenly spaced in the range of unsigned 64bit ints
    bins = np.linspace(0, np.iinfo("uint64").max - 1024, num=num, dtype="uint64")
    # Hash (deterministic) of the values is a uint64
    hashes = pd.util.hash_pandas_object(x, index=False)
    assert hashes.dtypes == "uint64"
    return np.digitize(hashes, bins)


def regroup_parquet_fragments(directory, recursive=True):
    """
    Read all the parquet files in a directory and write a single file.
    If there are inner directories, `regroup_parquet_fragments` will be called
    recursively on those.

    Why:
    Because I'm calling write_to_dataset repeatedly for multiple chunks of data,
    I sometimes get multiple fragments for the same partition. That's a pain,
    so this function fixes that. After running it, there will be only one file
    (or zero if the directory was initially empty), named `file.parquet`.
    """
    assert directory.is_dir()
    pq_files = []
    dirs = []
    other = []
    for p in directory.iterdir():
        if p.is_file() and p.suffix == ".parquet":
            pq_files.append(p)
        elif p.is_dir():
            dirs.append(p)
        else:
            other.append(p)
    if len(other):
        raise ValueError("Unexpected files found: " + ", ".join(other))
    if recursive:
        [regroup_parquet_fragments(d, recursive=recursive) for d in dirs]

    outfile = directory.joinpath("file.parquet")
    assert not outfile.exists()
    if len(pq_files) == 0:
        return
    elif len(pq_files) == 1:
        pq_files[0].rename(outfile)
        return
    else:
        data = pq.ParquetDataset(pq_files, memory_map=True).read(
            use_pandas_metadata=True
        )
        pq.write_table(data, outfile, version="2.0")
        [f.unlink() for f in pq_files]


def main(snakemake):
    # Fill with partial so we can easily map with one argument.
    monthly = partial(convert_monthly, out_dir=Path(snakemake.output["monthly"]))
    headers = partial(convert_headers, out_dir=Path(snakemake.output["headers"]))

    with ProcessPoolExecutor(max_workers=snakemake.threads) as ex:
        ex.map(monthly, snakemake.input["monthly"])
        ex.map(headers, snakemake.input["headers"])
    regroup_parquet_fragments(Path(snakemake.output["headers"]))
    regroup_parquet_fragments(Path(snakemake.output["monthly"]))


if __name__ == "__main__":
    main(snakemake)
source(here::here("code/shared_functions.r"))

cite_attached_packages(snakemake@output[[1]], c(
  "base", # R base
  "arrow",
  "brms",
  "broom",
  "cmdstanr",
  "curl",
  "data.table",
  "digest",
  "dplyr",
  "forcats",
  "fs",
  "furrr",
  "ggplot2",
  "future",
  "glue",
  "here",
  "igraph",
  "jsonlite",
  "lubridate",
  "magrittr",
  "matrixStats",
  "nngeo",
  "posterior",
  "processx",
  "purrr",
  "RColorBrewer",
  "readxl",
  "rlang",
  "safejoin",
  "sf",
  "stringr",
  "tibble",
  "tidyr",
  "tidyselect",
  "units",
  "unix"
))
source(here::here("code/shared_functions.r"))

INPUT_CRS <- sf::st_crs(4326) # https://epsg.io/4326
WORKING_CRS <- sf::st_crs(6350) # https://epsg.io/6350


find_well_pads <- function(df, well_pad_number_start, verbose=FALSE) {
  if (verbose) {
    message("  ", paste0(sort(unique(df$aapg_geologic_province)), collapse=", "))
  }
  df_radius <- df %>%
    # Drop different entity_id at the exact same location
    # (distinct.sf already keeps geometry column by default, but for some
    # reason gives an error)
    dplyr::distinct(aapg_geologic_province, geometry) %>%
    dplyr::mutate(
      buffer_radius_m = dplyr::case_when(
        aapg_geologic_province == "SAN JOAQUIN BASIN" ~ units::as_units(20, "m"),
        TRUE ~ units::as_units(50, "m")
      ))
  well_pad_geometry <- df_radius %>%
    sf::st_geometry() %>%
    sf::st_buffer(dist=df_radius$buffer_radius_m) %>%
    st_union_intersection()
  new_well_pad_ids <- seq.int(
    from = well_pad_number_start,
    length.out = length(well_pad_geometry)
  )
  well_pad_info <- well_pad_geometry %>%
    sf::st_centroid() %>%
    sf::st_transform(crs=INPUT_CRS) %>%
    sf::st_coordinates() %>%
    dplyr::as_tibble() %>%
    dplyr::transmute(
      well_pad_lon = .data$X,
      well_pad_lat = .data$Y,
      well_pad_id  = new_well_pad_ids,
    )
  well_pad_df <- sf::st_sf(well_pad_info, geometry = well_pad_geometry)

  out <- dplyr::select(df, entity_id) %>%
    sf::st_join(well_pad_df, left=FALSE) %>%
    sf::st_drop_geometry() %>%
    dplyr::as_tibble()
  out
}


assign_well_pads <- function(header_df, verbose=TRUE) {
  stopifnot(!anyDuplicated(header_df$entity_id))
  # My definition of a well pad, based on Alvarez et al. (2018) SI:
  # - a group of wells that are close together
  # - if the DI record is a lease (irrelevant for CA, NM, and post-1999 CO),
  #   assume it is its own well pad; don't group it with other records
  # - if more than two wells are nearby, take the union. (See the function
  #   st_union_intersection.)
  # NB: It would be nice to use operator, but there are a bunch of wells with
  # different operators reported at the same location.
  # Steps:
  # 1. Make into an sf dataframe
  # 2. Transform to a useful projection
  # 3. Buffer all points 50m, except the San Joaquin basin, where the buffer is 20m.
  # 4. Union the buffers and assign IDs
  # 5. Take the intersection of the original points with the buffers
  # 6. Do the above separately for each state or basin, then re-aggregate.
  # Note that the well-pad definitions are sensitive to the projection and
  # assumed well spacing (eg. 50m) for a small number of wells.

  # For records that are marked as "LEASE" (as opposed to "COM", "DRIP POINT",
  # "SWD", "UNIT", or "WELL"), don't do the geographic work, just label them
  # their own singleton well pad.
  singletons <- header_df %>%
    dplyr::group_by(aapg_geologic_province, operator_company_name) %>%
    dplyr::filter(dplyr::n() == 1 | entity_type == "LEASE") %>%
    dplyr::ungroup() %>%
    dplyr::select(entity_id, surface_longitude_wgs84, surface_latitude_wgs84) %>%
    dplyr::mutate(well_pad_id = dplyr::row_number()) %>%
    dplyr::rename(
      well_pad_lon = surface_longitude_wgs84,
      well_pad_lat = surface_latitude_wgs84,
    )

  # For records that aren't LEASE or singleton, look for nearby wells within
  # groups defined by basin and operator_company_name. Note that there are a
  # small number of overlaps, where a well falls within a well pad with
  # different basin or operator_company_name. (This is true even if you only use
  # basin.)
  non_singleton <- header_df %>%
    dplyr::anti_join(singletons, by="entity_id") %>%
    sf::st_as_sf(
      crs=INPUT_CRS,
      coords=c("surface_longitude_wgs84", "surface_latitude_wgs84")
    ) %>%
    sf::st_transform(crs=WORKING_CRS)

  # well_pads is the geometry of the well pads (a POLYGON buffered around the
  # points of each group of wells)
  non_singleton_lst <- non_singleton %>%
    # Split by basin so we can run different regions in parallel
    # Well pads will not span different subgroups of the group_by, so be careful
    # before adding grouping variables.
    dplyr::group_by(aapg_geologic_province, operator_company_name) %>%
    dplyr::group_split()
  # We then also need to make sure we're not assigning duplicate well_pad_id
  # values in the different parallel processes, so calculate disjoint sets of
  # possible_well_pad_ids
  n_groups <- length(non_singleton_lst)
  subgroup_row_counts <- purrr::map_int(non_singleton_lst, nrow)
  well_pad_id_bounds <- c(nrow(singletons), subgroup_row_counts) %>% cumsum()
  well_pad_id_starts <- dplyr::lag(well_pad_id_bounds)[-1] + 1
  # possible_well_pad_ids isn't used, just checked
  possible_well_pad_ids <- purrr::map2(
      well_pad_id_starts,
      well_pad_id_bounds[-1],
      ~seq.int(from=.x, to=.y, by=1)
    ) %>%
    unlist() %>%
    c(singletons$well_pad_id)

  stopifnot(
    all(subgroup_row_counts > 0),
    anyDuplicated(possible_well_pad_ids) == 0,
    length(non_singleton_lst) == length(well_pad_id_starts)
  )

  # Actually do the well pad creation, and re-add the singletons
  entity_pad_crosswalk <- furrr::future_map2_dfr(
      non_singleton_lst, well_pad_id_starts, find_well_pads,
      verbose=verbose,
      # we're not doing any RNG, but suppress warnings
      .options=furrr::furrr_options(seed=TRUE)
    ) %>%
    dplyr::bind_rows(singletons)
  # Note: singletons got assigned well_pad_id values 1:nrow(singletons), but it
  # doesn't matter that they're bound at the end here.

  out <- dplyr::inner_join(header_df, entity_pad_crosswalk, by="entity_id")
  stopifnot(
    anyDuplicated(header_df$entity_id) == 0,
    anyDuplicated(entity_pad_crosswalk$entity_id) == 0,
    setequal(header_df$entity_id, out$entity_id)
  )
  if (verbose) {
    n_wells <- nrow(header_df)
    n_unmatched <- n_wells - nrow(out)
    well_pad_well_counts <- dplyr::count(out, well_pad_id)
    n_wells_at_multi_well_pads <- dplyr::filter(well_pad_well_counts, n > 1) %>%
      dplyr::pull("n") %>%
      sum_() %>%
      fill_na(0)
    n_wells_at_single_well_pads <- dplyr::filter(well_pad_well_counts, n == 1) %>%
      dplyr::pull("n") %>%
      sum_() %>%
      fill_na(0)
    glue_message(
      "Of {n_wells} total wells, {n_wells_at_multi_well_pads} were matched to ",
      "multi-well pads, {n_wells_at_single_well_pads} were matched to single-",
      "well pads, and {n_unmatched} were unmatched for data quality reasons."
    )
  }
  out
}


parse_year_range <- function(year_range) {
  # Use a regex to pull out the years
  # look for things like "1990-2018"
  stopifnot(length(year_range) == 1)
  years <- stringr::str_match(year_range, "^(\\d{4})-(\\d{4})$")[1, 2:3] %>%
    as.integer() %>%
    unclass()
  stopifnot(length(years) == 2, !anyNA(years))
  years
}


create_well_pad_crosswalk <- function(header_dir, output_file, year_range) {
  year_range %<>% parse_year_range()
  stopifnot(dir.exists(header_dir), length(output_file) == 1)
  header_df <- arrow::open_dataset(header_dir) %>%
    dplyr::filter(
      first_prod_year >= year_range[1],
      first_prod_year <= year_range[2],
      !is.na(aapg_geologic_province),
      aapg_geologic_province != "(N/A)",
      !is.na(surface_latitude_wgs84),
      !is.na(surface_longitude_wgs84),
    ) %>%
    dplyr::select(
      entity_id, aapg_geologic_province,
      surface_latitude_wgs84, surface_longitude_wgs84,
      # NA operator_company_name is allowed (currently coded as "(N/A)")
      operator_company_name, entity_type
    ) %>%
    dplyr::collect() %>%
    dplyr::distinct()  # distinct for multiple wells at the same entity and location

  well_pad_crosswalk <- assign_well_pads(header_df, verbose=FALSE)
  arrow::write_parquet(well_pad_crosswalk, output_file)
}


test_well_pad_creation <- function() {
  header_df <- tibble::tibble(
    entity_id = 1:4,
    aapg_geologic_province = "SAN JOAQUIN BASIN",
    entity_type = "WELL",
    operator_company_name="AAAAA",
    # First two points with identical location, third point about 9m away,
    # and fourth point 4.3km away
    surface_longitude_wgs84 = c(-119.7514, -119.7514, -119.7515, -119.8),
    surface_latitude_wgs84 = 35.48126,
  )

  expected_results <- header_df %>%
    dplyr::mutate(well_pad_id = c(1, 1, 1, 2))
  actual_results <- assign_well_pads(header_df, verbose=FALSE) %>%
    dplyr::select(-well_pad_lon, -well_pad_lat) # don't test that part
  compare_results <- all.equal(actual_results, expected_results)
  if (!isTRUE(compare_results)) {
    print(tibble::tibble(
      expected_id = expected_results$well_pad_id,
      actual_id = actual_results$well_pad_id,
    ))
    print(actual_results)
    stop(compare_results)
  }
}


if (!exists("snakemake")) {
  stop("This script is meant to be run with snakemake.")
}

# Set up resources (mem limit doesn't work on MacOS)
memory_limit(snakemake@resources[['mem_mb']])
future::plan("multicore", workers = snakemake@threads)

test_well_pad_creation()

create_well_pad_crosswalk(
  header_dir = snakemake@input$headers,
  output_file = snakemake@output[[1]],
  year_range = snakemake@wildcards$year_range
)
source(here::here("code/shared_functions.r"))

options(scipen=5, SOURCE_DATE_EPOCH=0)
set.seed(6350) # CRS as seed? Sure.

# Note on variable names:
# mcfd is mcf per day (thousands of cubic feet per day), but is reported monthly
# bbld is barrels per day, but is measured monthly.
# There are other variables that report the average on days when production was
# happening, but I don't use those (maybe I should?).

LON_LAT_CRS <- sf::st_crs(4326) # https://epsg.io/4326
OUTPUT_CRS <- sf::st_crs(6350) # https://epsg.io/6350 (CONUS Albers)
MAX_ACCEPTABLE_MATCH_DIST_METERS <- 500
# These are the source types we'll keep. There are others, like landfills and
# dairies, that we don't care about. This is a list of oil and gas types.
# I wasn't sure about gas compressors, but they seem to be far from wells.
SOURCE_TYPES <- c(
  # "gas compressor",
  # "gas distribution line",
  # "gas LNG station",
  # "gas processing plant",
  # "gas storage facility",
  "oil/gas compressor",
  "oil/gas drill rig",
  "oil/gas gathering line",
  "oil/gas possible plugged well",
  "oil/gas pumpjack",
  "oil/gas stack",
  "oil/gas tank",
  "oil/gas unknown infrastucture",
  "oil/gas waste lagoon"
)

read_jpl_plumes <- function(input_files) {
  df <- data.table::fread(input_files[["duren_2019_plumes"]], data.table=FALSE) %>%
    dplyr::as_tibble() %>%
    dplyr::rename_all(make_better_names) %>%
    dplyr::filter(source_type_best_estimate %in% !!SOURCE_TYPES) %>%
    dplyr::mutate(
      datetime_of_detection = lubridate::mdy_hm(paste(date_of_detection, time_of_detection_utc), tz="UTC"),
      date_of_detection = lubridate::mdy(date_of_detection),
      flight_name = stringr::str_match(candidate_identifier, "^(ang\\d{8}t\\d{6})")[, 2L, drop=TRUE]
    ) %>%
    dplyr::rename(
      emiss_kg_hr = qplume_kg_hr_plume_emissions,
      emiss_se_kg_hr = sigma_qplume_kg_hr_uncertainty_for_plume_emissions,
    ) %>%
    dplyr::select(-sectors_ipcc, -time_of_detection_utc) %>%
    ensure_id_vars(source_identifier, datetime_of_detection) %>%
    sf::st_as_sf(
      crs=LON_LAT_CRS,
      coords=c("plume_longitude_deg", "plume_latitude_deg"),
      agr=c(
        source_identifier = "identity", candidate_identifier = "identity",
        date_of_detection = "constant", source_type_best_estimate = "constant",
        emiss_kg_hr = "constant", emiss_se_kg_hr = "constant",
        datetime_of_detection = "constant", flight_name = "constant"
      )
    ) %>%
    sf::st_transform(OUTPUT_CRS)
  stopifnot(!anyNA(df$flight_name))
  df
}

read_jpl_sites <- function(input_files) {
  # Sites are different than plumes because they revisited
  # This excel file is translated from a PDF in the Duren et al. (2019)
  # supplementary materials.
  df <- readxl::read_excel(input_files[["duren_2019_sites"]]) %>%
    dplyr::rename_all(make_better_names) %>%
    dplyr::filter(source_type %in% !!SOURCE_TYPES) %>%
    dplyr::rename(
      emiss_kg_hr = qsource_kg_hr,
      persistence_frac = source_persistence_f,
    ) %>%
    # Run this rename twice because the sigma character gets assigned a different name on windows.
    rename_cols(c("emiss_se_kg_hr" = "x_q_kg_hr"), strict=FALSE) %>%
    rename_cols(c("emiss_se_kg_hr" = "x_u_f073_q_kg_hr"), strict=FALSE) %>%
    dplyr::select(-ipcc_sector, -confidence_in_persistence) %>%
    ensure_id_vars(source_identifier) %>%
    sf::st_as_sf(
      crs=LON_LAT_CRS,
      coords=c("source_longitude_deg", "source_latitude_deg"),
      agr=c(
        source_identifier = "identity", source_type = "constant",
        n_overflights = "constant", persistence_frac = "constant",
        emiss_kg_hr = "constant", emiss_se_kg_hr = "constant"
      )
    ) %>%
    sf::st_transform(OUTPUT_CRS)
  df
}

read_headers <- function(years, states) {
  header_dir <- here::here("data/generated/production/well_headers/")
  stopifnot(dir.exists(header_dir), !anyNA(years), !anyNA(states))
  date_min <- lubridate::make_date(min(years), 1, 1)
  date_max <- lubridate::make_date(max(years) + 1, 1, 1)

  well_headers <- arrow::open_dataset(header_dir) %>%
    dplyr::filter(
      state %in% !!states,
      # !(is.na(first_60_oil) & is.na(first_60_gas)),
      # !is.na(completion_date)
    ) %>%
    dplyr::select(
      county, state, production_type, drill_type, aapg_geologic_province,
      surface_latitude_wgs84, surface_longitude_wgs84, first_prod_date, last_prod_date,
      first_60_oil, first_60_gas, completion_date, spud_date, months_produced, entity_id,
    ) %>%
    dplyr::collect() %>%
    dplyr::filter(
      !is.na(surface_latitude_wgs84),
      !is.na(surface_longitude_wgs84),
      !is.na(first_prod_date),
      !is.na(last_prod_date),
      # Drop wells that started producing after the end of observation or ended
      # before the beginning (just for efficiency of not doing spatial operations
      # on a bunch of irrelevant wells)
      first_prod_date < !!date_max,
      last_prod_date > !!date_min,
      !production_type %in% c("WATER", "STEAMFLOOD"),
    ) %>%
    dplyr::rename(basin = aapg_geologic_province) %>%
    dplyr::mutate(production_type = dplyr::case_when(
      production_type %in% c("GAS", "OIL") ~ production_type,
      production_type == "GAS STORE" ~ "GAS",
      production_type == "OIL (CYCLIC STEAM)" ~ "OIL",
      TRUE ~ NA_character_
    )) %>%
    ensure_id_vars(entity_id) %>%
    sf::st_as_sf(
      crs=LON_LAT_CRS,
      coords=c("surface_longitude_wgs84", "surface_latitude_wgs84"),
      agr=c(
        entity_id = "identity", county = "constant", production_type = "constant",
        drill_type = "constant", basin = "constant", months_produced = "constant",
        surface_latitude_wgs84 = "constant", surface_longitude_wgs84 = "constant",
        first_prod_date = "constant", last_prod_date = "constant",
        first_60_oil = "constant", first_60_gas = "constant",
        completion_date = "constant", spud_date = "constant"
      )
    ) %>%
    sf::st_transform(OUTPUT_CRS)
  well_headers
}

match_wells_to_flight_day <- function(well_headers, flight_paths) {
  wells_flown_over <- well_headers %>%
    ensure_id_vars(entity_id) %>%
    # Only consider wells that are in the flight paths *and* are producing when
    # the plane flies over. This join generates duplicate well rows because
    # there are multiple flights over the same areas. Keep rows that had
    # any overpass while active, then drop duplicates.
    # (doing it this way because distinct doesn't work quite the same on sf objects)
    sf::st_join(dplyr::select(flight_paths, flight_date), join=sf::st_intersects, left=FALSE) %>%
    dplyr::filter(flight_date >= first_prod_date, flight_date <= last_prod_date) %>%
    sf::st_drop_geometry() %>%
    dplyr::distinct(entity_id, flight_date)

  dplyr::inner_join(well_headers, wells_flown_over, by="entity_id")
}

st_nn <- function(...){
  # Suppress nngeo's narration
  suppressMessages(nngeo::st_nn(...))
}

load_flight_paths <- function(plume_measurements) {
  # Flight paths:
  flight_df <- data.table::fread(
      here::here("data/studies/duren_etal_2019/AVIRIS-NG Flight Lines - AVIRIS-NG Flight Lines.csv"),
      data.table=FALSE
    ) %>%
    dplyr::as_tibble() %>%
    dplyr::rename(flight_name = Name) %>%
    dplyr::rename_all(make_better_names) %>%
    dplyr::mutate(
      flight_date = lubridate::make_date(.data$year, .data$month, .data$day),
      utc_hour = dplyr::if_else(dplyr::between(utc_hour, 0, 23), utc_hour, NA_integer_),
      utc_minute = dplyr::if_else(dplyr::between(utc_minute, 0, 59), utc_minute, NA_integer_),
      flight_datetime_utc = lubridate::make_datetime(
        .data$year, .data$month, .data$day, .data$utc_hour, .data$utc_minute
      ),
    ) %>%
    dplyr::semi_join(plume_measurements, by="flight_name") %>% # keep matching flights
    dplyr::select(flight_name, flight_date, flight_datetime_utc, number_of_lines, investigator,
      dplyr::starts_with("lat"), dplyr::starts_with("lon")
    ) %>%
    dplyr::distinct() %>%
    ensure_id_vars(flight_name)
  stopifnot(setequal(flight_df$flight_name, plume_measurements$flight_name))

  flight_polygons <- flight_df %>%
    dplyr::select(flight_name, dplyr::starts_with("lon"), dplyr::starts_with("lat")) %>%
    # This is terrible, but it pairs up the lonN/latN corner columns by index.
    tidyr::pivot_longer(dplyr::starts_with("lon"), names_to="idx1", names_prefix="lon", values_to="lon") %>%
    tidyr::pivot_longer(dplyr::starts_with("lat"), names_to="idx2", names_prefix="lat", values_to="lat") %>%
    dplyr::filter(idx1 == idx2) %>%
    dplyr::filter(!is.na(lat), !is.na(lon)) %>%
    sf::st_as_sf(coords = c("lon", "lat"), crs = LON_LAT_CRS) %>%
    dplyr::group_by(flight_name) %>%
    dplyr::summarize(
      geometry = sf::st_cast(sf::st_combine(geometry), "POLYGON"),
      .groups="drop"
    ) %>%
    sf::st_transform(OUTPUT_CRS)

  # Join back the variables (flight_polygons has to be first to get the s3 dispatch right)
  out <- dplyr::inner_join(flight_polygons,
    dplyr::select(flight_df, flight_name, flight_date, flight_datetime_utc, number_of_lines, investigator),
    by="flight_name"
  )
  if ("date_of_detection" %in% colnames(plume_measurements)) {
    # flight dates are detection dates:
    stopifnot(setequal(out$flight_date, plume_measurements$date_of_detection))
  }
  stopifnot(
    # detected plumes are inside flight paths:
    all(sf::st_intersects(plume_measurements, sf::st_union(out), sparse=FALSE))
  )
  out
}

load_monthly_prod <- function(wells_by_flight_day) {
  if (inherits(wells_by_flight_day, "sf")) {
    wells_by_flight_day %<>% sf::st_drop_geometry()
  }
  prod_dir <- here::here("data/generated/production/monthly_production/")
  stopifnot(
    dir.exists(prod_dir),
    is_id(wells_by_flight_day, entity_id, flight_date)
  )
  well_to_match <- wells_by_flight_day %>%
    dplyr::transmute(
      flight_date = flight_date,
      year = lubridate::year(flight_date),
      month = lubridate::month(flight_date),
      entity_id = entity_id
  )
  # These are just for speed, so we don't read anything we're positive we don't want
  desired_entity_id <- unique(well_to_match$entity_id)
  desired_year <- unique(well_to_match$year)
  desired_month <- unique(well_to_match$month)

  # Only diff from the stored schema is year, which currently registers as string
  # To select another column, add it to this schema.
  schema <- arrow::schema(
    year = arrow::int32(), month = arrow::int32(), entity_id = arrow::int32(),
    daily_avg_gas = arrow::float(), daily_avg_oil = arrow::float(), well_count = arrow::int32()
  )
  prod <- arrow::open_dataset(prod_dir, schema=schema) %>%
    dplyr::filter(
      daily_avg_gas > 0,
      year %in% !!desired_year,
      month %in% !!desired_month,
      entity_id %in% !!desired_entity_id,
    ) %>%
    dplyr::collect() %>%
    # 1:m join (U), no column conflicts (C)
    safejoin::safe_inner_join(well_to_match, by=c("entity_id", "year", "month"), check="U C")
  prod
}

load_well_records <- function(flight_paths, states, nat_gas_price) {
  years <- flight_paths$flight_date %>% lubridate::year() %>% unique()
  # First, find wells that were flown over and had production start before the
  # flight date and end after the flight date.
  # wells_by_flight_day has one row for each well each day it was flown over.
  wells_by_flight_day <- read_headers(years, states) %>%
    match_wells_to_flight_day(flight_paths)

  # Note: If the well was flown over in different months, this is the average
  # across months.
  monthly_prod_when_flown_over <- load_monthly_prod(wells_by_flight_day) %>%
    dplyr::group_by(entity_id) %>%
    dplyr::summarize(
      oil_avg_bbld = mean_(daily_avg_oil),
      gas_avg_mcfd = mean_(daily_avg_gas),
      .groups="drop"
    )
  # Recall wells_by_flight_day has one row for each well each day it was flown over.
  # Drop down to one row per well.
  # Do both well age and price here because it's easier than handling prices later
  well_age_and_price <- wells_by_flight_day %>%
    sf::st_drop_geometry() %>%
    dplyr::mutate(age_yr = as.numeric(flight_date - first_prod_date) / 365.25) %>%
    match_full_state_names() %>%
    dplyr::rename(state = state_full) %>%
    harmonize_basin_name() %>%
    match_commodity_prices(nat_gas_price) %>%
    dplyr::group_by(entity_id) %>%
    dplyr::summarize(
      age_yr = mean(age_yr),
      gas_price_per_mcf = mean(gas_price_per_mcf),
      gas_frac_methane = mean(gas_frac_methane),
      .groups="drop"
    )
  stopifnot(noNAs(well_age_and_price))

  wells <- wells_by_flight_day %>%
    dplyr::group_by(entity_id) %>%
    dplyr::select(-flight_date) %>%
    dplyr::distinct(.keep_all=TRUE) %>% # could speed this up here by not doing a spatial distinct
    ensure_id_vars(entity_id) %>%
    dplyr::inner_join(well_age_and_price, by="entity_id") %>% # 1:1 join
    harmonize_basin_name(group_small_CA_basins=TRUE)
  dplyr::inner_join(wells, monthly_prod_when_flown_over, by="entity_id") # 1:1 join
}

match_with_wells <- function(observed_sites, wells) {
  sites_matched <- sf::st_join(
      observed_sites, wells,
      join=st_nn, k=1, maxdist=MAX_ACCEPTABLE_MATCH_DIST_METERS,
      left=FALSE,
      progress=FALSE
    ) %>%
    # drop doubly-matched sites.
    dplyr::group_by(entity_id) %>%
    dplyr::filter(dplyr::n() == 1) %>%
    dplyr::ungroup() %>%
    ensure_id_vars(entity_id)
  if (nrow(sites_matched) != nrow(observed_sites)) {
    warning(
      nrow(observed_sites) - nrow(sites_matched),
      " leaks didn't have a matching well"
    )
  }
  sites_matched
}

make_pairs_plots <- function(jpl_sites_matched) {
  if (!rlang::is_installed("GGally")) {
    stop("GGally package required for this function")
  }
  pseudo_log <- scales::pseudo_log_trans() # basically asinh

  plt_matched <- jpl_sites_matched %>%
    dplyr::mutate_at(dplyr::vars(oil_avg_bbld, gas_avg_mcfd, emiss_kg_hr, emiss_se_kg_hr), list(log=pseudo_log$transform)) %>%
    sf::st_drop_geometry() %>%
    GGally::ggpairs(
      ggplot2::aes(color=production_type, alpha=0.7),
      columns = c("emiss_kg_hr_log", "oil_avg_bbld_log", "gas_avg_mcfd_log", "persistence_frac"),
      progress=FALSE
    ) +
    ggplot2::theme_bw()
  save_plot(plt_matched, here::here("graphics/pairs_plot_matched_wells_jpl_di.pdf"), scale=3)

}

make_plots <- function(wells_all, jpl_sites_matched, ground_studies) {
  # GAS_SCALE <- ggplot2::scale_x_continuous(trans="pseudo_log", breaks=c(0, 10, 100, 1000, 10000, 100000))
  # EMISS_SCALE <- ggplot2::scale_y_continuous(trans="pseudo_log", breaks=c(0, 10, 30, 100, 1000))

  make_plot_qq <- function(df, add_line=TRUE, ...) {
    df <- dplyr::filter(df, emiss_kg_hr > 0)

    params <- as.list(MASS::fitdistr(df$emiss_kg_hr, "log-normal")$estimate)
    # The `...` here is just so I can pass in a color aesthetic in the combined case
    plt <- ggplot2::ggplot(df, ggplot2::aes(sample=emiss_kg_hr, ...)) +
      ggplot2::geom_qq(distribution=stats::qlnorm, dparams=params) +
      ggplot2::scale_x_continuous(
        trans="pseudo_log",
        limits=c(-0.2, 3000),
        breaks=c(0, 10, 30, 100, 1000)
      ) +
      ggplot2::scale_y_continuous(
        trans="pseudo_log",
        limits=c(-0.2, 3000),
        breaks=c(0, 10, 30, 100, 1000)
      ) +
      ggplot2::theme_bw() +
      ggplot2::labs(
        x="Theoretical distribution (kg/hr)",
        y="Observed distribution (kg/hr)"
      )
    if (add_line) {
      plt <- plt + ggplot2::geom_qq_line(distribution=stats::qlnorm, dparams=params)
    }
    plt
  }
  make_plot_emiss_prod_point_density <- function(df, adjust=0.1) {
    min_emiss <- min_(dplyr::filter(df, emiss_kg_hr > 0)$emiss_kg_hr)
    plt <- df %>%
      dplyr::mutate(emiss_kg_hr_filled =
        dplyr::if_else(is.na(emiss_kg_hr), min_emiss - (0.1 * min_emiss), emiss_kg_hr)) %>%
      dplyr::filter(gas_avg_mcfd > 0) %>%
      ggplot2::ggplot(ggplot2::aes(x=gas_avg_mcfd, y=emiss_kg_hr_filled)) +
      ggpointdensity::geom_pointdensity(adjust=adjust, alpha=0.8) +
      # options are "magma", "plasma", "viridis", or "cividis"
      ggplot2::scale_color_viridis_c(option="inferno")+
      ggplot2::geom_hline(yintercept=min_emiss) +
      ggplot2::scale_x_continuous(
        trans="pseudo_log",
        breaks=c(0, 10^(1:6)),
        limits=c(0, 1.5e6)
      ) +
      ggplot2::scale_y_continuous(trans="pseudo_log", breaks=c(0, 10, 30, 100, 1000)) +
      ggplot2::theme_bw() +
      ggplot2::theme(legend.position="none") +
      ggplot2::labs(
        x="Average gas production (mcf/mo)",
        y="Measured emissions (kg/hr)"#,
        # color="Well\ncount"
      )
    plt
  }
  #
  # plt_ecdf_jpl <- ggplot2::ggplot(jpl_sites_matched, ggplot2::aes(x=emiss_kg_hr)) +
  #   ggplot2::stat_ecdf(geom="step") +
  #   ggplot2::scale_x_log10() +
  #   ggplot2::theme_bw()

  plt_obs_count_jpl <- (make_plot_emiss_prod_point_density(wells_all) +
    ggplot2::labs(
      title="Most wells have no JPL-observed emissisons"
    )) %>%
    save_plot(here::here("graphics/observation_count_jpl_flights.pdf"))

  plt_qq_jpl <- (make_plot_qq(jpl_sites_matched) +
    ggplot2::labs(
      title="Observed JPL measures are distributed log-normal"
    ))
  save_plot(plt_qq_jpl, here::here("graphics/jpl_flights_qq_plot.pdf"))

  ground_studies <- dplyr::rename(ground_studies,
    gas_avg_mcfd = gas_production_mcfd
  )
  plt_obs_count_ground <- (make_plot_emiss_prod_point_density(ground_studies, adjust=0.05) +
    ggplot2::labs(
      title="Ground measures have many fewer zeros"
    )) %>%
    save_plot(here::here("graphics/observation_count_ground_studies.pdf"))

  plt_qq_ground <- (make_plot_qq(ground_studies) +
    ggplot2::labs(
      title="Ground-based measurements"
    )) %>%
    save_plot(here::here("graphics/ground_studies_qq_plot.pdf"))

  to_plot_combined_qq <- dplyr::bind_rows(
      dplyr::mutate(ground_studies, src = "Ground studies"),
      sf::st_drop_geometry(dplyr::mutate(jpl_sites_matched, src = "JPL flights")),
    )
  plot_qq_combined <- make_plot_qq(to_plot_combined_qq, add_line = TRUE, color=src) +
    # ggplot2::facet_grid(rows="src") +
    ggplot2::labs(
      title="Comparison QQ plots"
    )
  save_plot(plot_qq_combined, here::here("graphics/combined_ground_jpl_qq_plot.pdf"))

}

aggregate_to_well_pads <- function(wells_all, well_pad_mapping_file) {
  stopifnot(
    anyDuplicated(wells_all$entity_id) == 0,
    length(well_pad_mapping_file) == 1
  )
  well_pad_mapping <- arrow::read_parquet(
    well_pad_mapping_file,
    col_select=c("entity_id", "well_pad_id", "well_pad_lon", "well_pad_lat")
  )
  # We drop study-specific columns here, but might want to bring them back.
  # n_overflights, persistence_frac, ...
  # NOTE: we're doing an inner join with well_pad_id here. That means wells that
  # aren't part of the well pad mapping will not be part of the output.
  well_pad_df <- wells_all %>%
    # Check that the join is unique by right side (V), types match (T),
    # columns don't collide (C), and all rows in left side are matched (m)
    # Note: not currently true that all LHS are matched -- m instead of M warns.
    safejoin::safe_inner_join(well_pad_mapping, by="entity_id", check="V m T C") %>%
    dplyr::group_by(well_pad_id) %>%
    dplyr::summarize(
      # These aggregating functions are defined in shared_functions.r
      emiss_kg_hr = mean_(emiss_kg_hr),
      emiss_se_kg_hr = mean_(emiss_se_kg_hr),
      production_type = Mode(production_type),
      drill_type = Mode(drill_type),
      first_60_oil = sum_(first_60_oil),
      first_60_gas = sum_(first_60_gas),
      oil_avg_bbld = sum_(oil_avg_bbld),
      gas_avg_mcfd = sum_(gas_avg_mcfd),
      months_produced = mean_(months_produced),
      county = Mode(county),
      # Like Lyon et al. 2016, we'll define pad age by age of the most recently
      # drilled well.
      age_yr = min(age_yr),
      gas_price_per_mcf = mean(gas_price_per_mcf), # probably unique by well pad
      # already unique by well_pad_id; the particular aggregation doesn't matter
      well_pad_lon = min(well_pad_lon),
      well_pad_lat = min(well_pad_lat),
      basin = min(basin),
      gas_frac_methane = min(gas_frac_methane),
      .groups="drop"
    )

  stopifnot(!anyNA(well_pad_df$well_pad_id))

  # NOTES on well pads:
  # - There's still a vast number of well pads with no measurements
  # (currently 112 with measurements and 17006 without. This could be improved
  # slightly by thinking about deduplication, but it's minimal.)
  well_pad_df
}

load_ground_studies <- function(input_files) {
  # Drop studies that can't be used for our application. See the notes in
  # notes/measurement_paper_notes.csv
  studies_that_dont_work_for_us <- c("Rella et al.")

  df_alvarez <- readxl::read_excel(input_files[["alvarez_2018"]], sheet="inputs_meas_sites") %>%
    dplyr::rename_all(make_better_names) %>%
    dplyr::select(basin, study, methane_emissions_kgh, gas_production_mcfd) %>%
    dplyr::rename(emiss_kg_hr = methane_emissions_kgh) %>%
    dplyr::mutate(
      study = dplyr::case_when(
        study == "Omara et al." ~ "Omara et al. (2016)",
        study == "Robertson et al." ~ "Robertson et al. (2017)",
        TRUE ~ study
      ),
      basin = dplyr::case_when(
        basin == "SWPA" ~ "SW Pennsylvania",
        basin == "Weld County" ~ "Denver Julesburg",
        TRUE ~ basin
      )
    )

  df_omara_2018 <- data.table::fread(input_files[["omara_2018"]], data.table=FALSE) %>%
    dplyr::as_tibble() %>%
    dplyr::transmute(
      basin = dplyr::case_when(
        site == "Uinta Basin (Uintah County, UT)" ~ "Uinta", # match with Alvarez et al.
        site == "Denver Julesburg Basin (Weld County, CO)" ~ "Denver Julesburg",
        site == "NE PA (Bradford, Susquehanna, Wyoming, Sullivan Counties)" ~ "NE Pennsylvania",
        TRUE ~ site
      ),
      study = "Omara et al. (2018)",
      gas_production_mcfd = tot_gas_prod_mcfd,
      emiss_kg_hr = avg_emissions_kg_per_h,
    )

  df_zavala_araiza_2018 <- readxl::read_excel(input_files[["zavala_araiza_2018"]]) %>%
    dplyr::rename_all(make_better_names) %>%
    dplyr::filter(method == "tracer flux") %>% # different sampling strategies
    dplyr::transmute(
      emiss_kg_hr = ch4_emission_rate_kgh,
      gas_production_mcfd = gas_production_mcfd,
      study = "Zavala-Araiza et al. (2018)",
      basin = "Alberta",
      # other variables that could be interesting:
      # ch4_lb, ch4_ub, oil_production_bbld, age_yr, wells_per_site,
      # reported_emissions_kgh, gas_composition_c1_percent
    )
  out <- dplyr::bind_rows(df_alvarez, df_omara_2018, df_zavala_araiza_2018)
  stopifnot(all(studies_that_dont_work_for_us %in% out$study)) # guard against typos
  out <- dplyr::filter(out, !study %in% !!studies_that_dont_work_for_us)
  out
}

compare_ground_studies_with_jpl <- function(wells_all, ground_studies) {
  # Table:
  # - N non-zero
  # - N zero
  # - mean, median leak rate
  # - correlation with size, among non-zero leaks
  # - age?
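  # Treat emissions below 5 kg/hr as effectively zero, mirroring the censoring
  # threshold used elsewhere in this file (see standardize_columns_for_plotting).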
  wells_num_zero <- dplyr::filter(wells_all, is.na(emiss_kg_hr) | emiss_kg_hr < 5) %>% nrow()
  wells_all_stats <- dplyr::filter(wells_all, emiss_kg_hr >= 5, !is.na(gas_avg_mcfd)) %>%
    dplyr::summarize(
      source = "JPL flights",
      n_positive = dplyr::n(),
      mean_if_positive = mean_(emiss_kg_hr),
      corr_with_size = stats::cor(emiss_kg_hr, gas_avg_mcfd),
      .groups="drop"
    ) %>%
    dplyr::mutate(n_zero = !!wells_num_zero)
  ground_studies_num_zero <- dplyr::filter(ground_studies, is.na(emiss_kg_hr) | emiss_kg_hr == 0) %>% nrow()
  ground_studies_stats <- dplyr::filter(ground_studies, emiss_kg_hr > 0) %>%
    dplyr::summarize(
      source = "Ground studies",
      n_positive = dplyr::n(),
      mean_if_positive = mean_(emiss_kg_hr),
      # corr_with_size = stats::cor(emiss_kg_hr, gas_avg_mcfd)
      corr_with_size = 0,
      .groups="drop"
    ) %>%
    dplyr::mutate(n_zero = !!ground_studies_num_zero)
  tab <- cbind(t(wells_all_stats), t(ground_studies_stats))
  return(tab)
}

read_lyon_etal_2016 <- function(input_files) {
  locations <- readxl::read_excel(input_files[["lyon_etal_2016_locations"]]) %>%
    dplyr::as_tibble() %>%
    dplyr::rename_all(make_better_names) %>%
    ensure_id_vars(pad_id) %>%
    dplyr::select(-basin) %>%
    dplyr::mutate(longitude = as.double(stringr::str_trim(longitude))) # excel issues

  df <- readxl::read_excel(input_files[["lyon_etal_2016_measures"]], sheet=1) %>%
    dplyr::as_tibble() %>%
    dplyr::rename_all(make_better_names) %>%
    dplyr::select(-video_id) %>%
    ensure_id_vars(pad_id) %>%
    dplyr::left_join(locations, by="pad_id") %>%
    # Rename to harmonize with other datasets
    dplyr::rename(
      gas_avg_mcfd = gas_production_mcf_pad_day,
      oil_avg_bbld = oil_production_bbl_pad_day,
      detect_emiss = emissions_detected_0_no_1_yes,
    ) %>%
    dplyr::mutate(
      # This is approx the same as our definition
      age_yr = well_age_months_since_initial_production_of_newest_well / 12,
      flight_date = as.Date("2014-08-01"),
      state = basin_to_state(basin),
    )

  df
}

basin_to_state <- function(basin) {
  stopifnot(is.character(basin))
  # These aren't perfect: for basins that cross state lines, we pick the state
  # containing the majority of the basin. (If that worries you, the bigger
  # concern is the use of state citygate prices in the first place.)
  conversions = c(
    "Bakken" = "North Dakota",
    "Barnett" = "Texas",
    "EagleFord" = "Texas",
    "Fayetteville" = "Arkansas",
    "Marcellus" = "Pennsylvania",
    "PowderRiver" = "Wyoming",
    "Uintah" = "Utah",
    "San Juan" = "New Mexico",
    "San Joaquin" = "California"
  )
  missing_basins <- setdiff(unique(basin), names(conversions))
  if (length(missing_basins) > 0) {
    stop("Missing states for these basins: ", paste(missing_basins, collapse=", "))
  }
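  # Returns a named character vector, e.g. basin_to_state("Bakken") gives c(Bakken = "North Dakota").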
  conversions[basin]
}

read_frankenberg_etal_2016 <- function(input_files) {
  measures <- readxl::read_excel(input_files[["frankenberg_etal_2016_measures"]]) %>%
    dplyr::transmute(
      # Do the rename and select in one step. Drop "Thorpe ID" and
      # "Rank of size of flux for 245 sources (Frankenberg et al., 2016)."
      plume_id = `Thorpe ID`,
      latitude = Latitude,
      longitude = Longitude,
      source_type = `Source designation (Thorpe, performed after Frankenberg et al., 2016, total 273 sources with some repeat observations)`,
      emiss_kg_hr = `Estimated flux for 245 sources (Frankenberg et al., 2016). Units kg CH4/hr`
    ) %>%
    dplyr::slice(-1) %>% # omit spacer row
    dplyr::filter(
      not_na(emiss_kg_hr), # drop sources with a detected plume but no quantified flux
      source_type %in% c("Tank", "Tanks", "Unknown", "Unknown facility",
        "Unknown infrastructure", "Well completion", "Wellpad infrastructure"
      ),
      emiss_kg_hr == 0 | emiss_kg_hr > 0.001 # one *very* small value seems like a mistake
    )
  # Source types in the data:
  # source_type                n
  # Coal mine vent shaft       1
  # Gas processing plant      10
  # Natural                    3
  # Pipeline                   3
  # Tank                      61
  # Tanks                      1
  # Unknown                   10
  # Unknown facility          17
  # Unknown infrastructure     2
  # Well completion            1
  # Wellpad infrastructure   135

  sources <- readxl::read_excel(input_files[["frankenberg_etal_2016_sources"]]) %>%
    dplyr::rename_all(make_better_names)
  multiple_divider_row <- which(sources$x == "Multiple Overpasses")
  total_rows <- nrow(sources)
  stopifnot(identical(multiple_divider_row, 179L), total_rows > 179)
  sources <- dplyr::slice(sources, -multiple_divider_row) %>%
    dplyr::select(longitude, latitude, file) %>%
    dplyr::mutate(
      flight_name =
        stringr::str_match(.data$file, "^(ang\\d{8}t\\d{6})_ch4_v1e_img$")[, 2, drop=TRUE],
      flight_date = lubridate::ymd(
        stringr::str_match(.data$file, "^ang(\\d{8})t\\d{6}_ch4_v1e_img$")[, 2, drop=TRUE]
      )
    ) %>%
    ensure_id_vars(longitude, latitude)
  stopifnot(
    !anyNA(sources$flight_name),
    nrow(dplyr::anti_join(measures, sources, by=c("longitude", "latitude"))) == 0
  )
  # For these two, longitude and latitude match exactly because they're from the
  # same underyling measurements. Some sources were detected but don't have
  # quantified measurements.
  out <- dplyr::left_join(measures, sources, by=c("longitude", "latitude")) %>%
    # Drop 4 obs that don't have associated flight paths in the online data.
    # (They're just east of the main flight zone; not sure if they were recorded the same way.)
    dplyr::filter(! flight_name %in% c("ang20150421t160633", "ang20150423t150648")) %>%
    dplyr::select(-file) %>%
    sf::st_as_sf(
      crs=LON_LAT_CRS,
      coords=c("longitude", "latitude"),
      agr=c(
        plume_id = "identity", source_type = "constant", emiss_kg_hr = "aggregate",
        flight_name = "constant", flight_date = "constant"
      )
    ) %>%
    sf::st_transform(OUTPUT_CRS) %>%
    # Add all-NA variables to make it easier to share some functions
    dplyr::mutate(
      emiss_se_kg_hr = NA_real_,
    )
  out
}

write_datasets <- function(data_lst, output_lst) {
  data_names <- names(data_lst)
  stopifnot(
    length(data_names) == length(data_lst),
    all(data_names %in% names(output_lst))
  )
  for (nm in data_names) {
    df <- data_lst[[nm]]
    if (inherits(df, "sf")) {
      df <- geometry_to_lonlat(df)
    }
    arrow::write_parquet(df, output_lst[[nm]])
  }
  NULL
}

standardize_columns_for_plotting <- function(df_list, censor_threshold=5) {
  # NOTE: in the output of emiss_kg_hr, zero means emissions could have been
  # detected and quantified but were not. NA means emissions could not have been
  # quantified.
  df_list$jpl_wells_all %<>%
    dplyr::transmute(
      emiss_kg_hr = dplyr::if_else(is.na(emiss_kg_hr), 0, emiss_kg_hr),
      detect_emiss = emiss_kg_hr > 0,
      src = "California",
    )
  df_list$four_corners_all_wells %<>%
    dplyr::transmute(
      emiss_kg_hr = dplyr::if_else(is.na(emiss_kg_hr), 0, emiss_kg_hr),
      detect_emiss = emiss_kg_hr > 0,
      src = "Four Corners",
    )
  df_list$lyon %<>%
    dplyr::transmute(
      emiss_kg_hr = NA_real_,
      detect_emiss = detect_emiss == 1,
      src = "Lyon et al.",
    )
  df_list$ground_studies_censored_5kgh <- df_list$ground_studies %>%
    dplyr::transmute(
      emiss_kg_hr = dplyr::if_else(.data$emiss_kg_hr > 5, .data$emiss_kg_hr, 0),
      detect_emiss = .data$emiss_kg_hr > 0,
      src = "Ground studies censored at 5 kg/hr",
    )
  df_list$ground_studies_censored_10kgh <- df_list$ground_studies %>%
    dplyr::transmute(
      emiss_kg_hr = dplyr::if_else(.data$emiss_kg_hr > 10, .data$emiss_kg_hr, 0),
      detect_emiss = .data$emiss_kg_hr > 0,
      src = "Ground studies censored at 10 kg/hr",
    )
  df_list$ground_studies %<>%
    # dplyr::filter(emiss_kg_hr > 0) %>%
    dplyr::transmute(
      emiss_kg_hr = .data$emiss_kg_hr,
      detect_emiss = .data$emiss_kg_hr > 0,
      src = "Ground studies",
    )

  df_list
}

filter_within_distance <- function(x, y, max_dist=units::as_units(1000, "m")) {
  # Keep elements of x that are within max_dist (default 1000 m) of any element of y
  y_buffer <- sf::st_geometry(y) %>%
    sf::st_union() %>%
    sf::st_buffer(dist=max_dist)
  stopifnot(length(y_buffer) == 1)
  sf::st_intersection(x, y_buffer)
}

data_to_check_matches <- function(snakemake) {
  outfile <- snakemake@output[["data_to_check_matches"]] %||% stop("Need outfile")
  # lyon <- read_lyon_etal_2016(snakemake@input)
  well_pad_mapping <- snakemake@input[["well_pad_crosswalk"]] %>%
    arrow::read_parquet(col_select=c("entity_id", "well_pad_id")) %>%
    ensure_id_vars(entity_id)

  jpl_sites <- read_jpl_sites(snakemake@input) %>%
    dplyr::rename(source_id = source_identifier)
  wells_ca <- read_headers(2016:2017, "CA") %>%
    filter_within_distance(jpl_sites) %>%
    dplyr::inner_join(well_pad_mapping, by="entity_id")

  four_corners_df <- read_frankenberg_etal_2016(snakemake@input) %>%
    dplyr::rename(source_id = plume_id)

  wells_nm_co <- read_headers(2015, c("CO", "NM")) %>%
    filter_within_distance(four_corners_df) %>%
    dplyr::inner_join(well_pad_mapping, by="entity_id")

  df_list <- readRDS(snakemake@output[["cleaned_matched_obs"]])
  jpl_matches <- df_list$jpl_wells_all %>%
    dplyr::filter(!is.na(emiss_kg_hr)) %>%
    dplyr::select(source_identifier, entity_id) %>%
    dplyr::rename(source_id = source_identifier)
  four_corners_matches <- df_list$four_corners_all_wells %>%
    dplyr::filter(!is.na(emiss_kg_hr)) %>%
    dplyr::select(plume_id, entity_id) %>%
    dplyr::rename(source_id = plume_id)
  stopifnot(
    anyDuplicated(jpl_matches$source_id) == 0,
    anyDuplicated(jpl_matches$entity_id) == 0,
    anyDuplicated(four_corners_matches$source_id) == 0,
    anyDuplicated(four_corners_matches$entity_id) == 0
  )
  jpl_sites %<>% dplyr::left_join(jpl_matches, by="source_id")
  four_corners_df %<>% dplyr::left_join(four_corners_matches, by="source_id")
  out <- list(
    ca    = list(measures = jpl_sites,       wells = wells_ca),
    co_nm = list(measures = four_corners_df, wells = wells_nm_co)
  )
  saveRDS(out, outfile)
  invisible(out)
}

match_full_state_names <- function(df) {
  state_names <- tibble::tibble(state_abb = state.abb, state_full = state.name)
  safejoin::safe_inner_join(state_names, df,
    # 1:m merge, require a match of all rows in df
    by=c("state_abb"="state"), check="B C U L T N"
  )
}

match_commodity_prices <- function(df, price_file) {
  stopifnot("basin" %in% names(df), length(price_file) == 1)
  prices <- arrow::read_parquet(price_file, col_select=c("date", "basin", "price_real")) %>%
    dplyr::rename(gas_price_per_mcf = price_real)
  out <- dplyr::mutate(df,
      date_monthly = first_of_month(flight_date),
    ) %>%
    safejoin::safe_inner_join(
      prices,
      # Check that the join is m:1, all rows of left have a match, and there are
      # no conflicts in columns or types.
      by=c("basin"="basin", "date_monthly"="date"), check="B C V M L T"
    ) %>%
    dplyr::select(-date_monthly) %>%
    dplyr::mutate(
      gas_frac_methane = 0.95, # APPROX!
    )
  out
}

write_match_percent <- function(count_total, count_match, outfile) {
  stopifnot(
    count_total > 100, count_total < 10000, count_match > 0,
    count_match <= count_total, length(outfile) == 1
  )
  count_unmatch <- count_total - count_match
  pct_drop <- signif(100 * count_unmatch / count_total, 2)
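  # The trailing "%" in the written fragment is a TeX comment character: it should
  # swallow the newline that writeLines adds, so \input-ing the fragment doesn't
  # introduce a stray space after the percent sign.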
  writeLines(paste0(pct_drop, "\\%%"), outfile)
  invisible(NULL)
}

match_jpl_california <- function(input_files) {
  # Note:
  # Some wells are flown over multiple times. Ideally, we would use that
  # information to estimate leak probabilities for these wells. However, for the
  # short run, we're going to count wells that weren't leaking on their first
  # overpass as not leaking, because we definitely want to avoid overstating the
  # leak rate (as using ever-leak status would do)
  #
  # The steps we follow are:
  # - Match each flight polygon with wells.
  # - For each matched well, only keep one randomly drawn flight
  # - Within groups of flight ID, match wells to detected plumes.
  # - Report counts of all plumes, matched plumes, and plumes that would have
  #   matched if we weren't dropping the later overpasses.
  plumes <- read_jpl_plumes(input_files) %>%
    dplyr::filter(!is.na(emiss_kg_hr), !is.na(emiss_se_kg_hr))
  flight_paths <- load_flight_paths(plumes)

  # Load the well covariates.
  # Note: this could be improved a bit, because for repeat flyovers we want the
  # specific draw we have, but load_well_records returns the average for the
  # well across all visits.
  well_info <- load_well_records(flight_paths, "CA", input_files$nat_gas_prices)


  # This is like match_wells_to_flight_day, but different because we pick out
  # individual flights instead of flight days (some days have multiple flights)
  # All flights are in CA
  wells_flight_match <- read_headers(
      years=unique(lubridate::year(flight_paths$flight_date)),
      states="CA"
    ) %>%
    # Keep only wells that we kept in load_well_records (dropping gas == 0)
    dplyr::filter(.data$entity_id %in% well_info$entity_id) %>%
    # Only consider wells that are in the flight paths *and* are producing when
    # the plane flies over. This join generates duplicate well rows because
    # there are multiple flights over the same areas. Keep rows that had
    # any overpass while active, then drop duplicates.
    # (doing it this way because distinct doesn't work quite the same on sf objects)
    sf::st_join(
      dplyr::select(flight_paths, flight_date, flight_name),
      join=sf::st_intersects,
      left=FALSE
    ) %>%
    dplyr::filter(flight_date >= first_prod_date, flight_date <= last_prod_date) %>%
    dplyr::select(entity_id, flight_date, flight_name) %>%
    ensure_id_vars(entity_id, flight_date, flight_name)

  flight_names <- unique(wells_flight_match$flight_name)
  stopifnot(noNAs(wells_flight_match), length(flight_names) > 10)

  wells_rand_overpass <- wells_flight_match %>%
    dplyr::group_by(entity_id) %>%
    dplyr::slice_sample(n = 1) %>%
    dplyr::ungroup()

  .match_one_overpass <- function(flight_name, plumes, wells) {
    plumes %<>% dplyr::filter(.data$flight_name == !!flight_name) %>%
      dplyr::select(-flight_name)
    wells  %<>% dplyr::filter(.data$flight_name == !!flight_name)
    if (nrow(plumes) == 0 || nrow(wells) == 0) {
      return(NULL)
    }
    matched <- sf::st_join(plumes, wells,
        join=st_nn, k=1, progress=FALSE, left=FALSE
      ) %>%
      sf::st_drop_geometry() %>%
      # Now, if there are multiple plumes observed and matched to a well
      # _from the same flight_, then average them.
      dplyr::group_by(entity_id) %>%
      dplyr::summarize(
        emiss_kg_hr = mean_(emiss_kg_hr),
        emiss_se_kg_hr = mean_(emiss_se_kg_hr),
        .groups="drop"
      )
    all_wells_flown_over <- wells %>%
      sf::st_drop_geometry() %>%
      safejoin::safe_left_join(matched, by="entity_id", check="B C U V N L T")
    all_wells_flown_over
  }
  # Loop over flight_name and match wells and plumes for each flight.
  # Note: there's absolutely a better way to do this.
  all_wells_flown_over <- purrr::map_dfr(
    flight_names, .match_one_overpass,
    plumes=plumes, wells=wells_rand_overpass
  )
  stopifnot(
    anyDuplicated(all_wells_flown_over$entity_id) == 0,
    all(all_wells_flown_over$entity_id %in% well_info$entity_id)
  )
  all_wells_flown_over %<>% dplyr::inner_join(well_info, by="entity_id")
  matched_wells <- dplyr::filter(all_wells_flown_over, !is.na(emiss_kg_hr))

  observed_well_pads <- aggregate_to_well_pads(all_wells_flown_over, input_files$well_pad_crosswalk)

  could_have_been_matches <- sf::st_join(
    plumes, wells_flight_match,
    join=st_nn, k=1,
    maxdist=MAX_ACCEPTABLE_MATCH_DIST_METERS,
    left=FALSE,
    progress=FALSE
  )

  list(
    observed_well_pads = observed_well_pads,
    matched_wells = matched_wells,
    count_total = nrow(plumes),
    count_matched = sum(!is.na(all_wells_flown_over$emiss_kg_hr)),
    count_would_have_matched = nrow(could_have_been_matches)
  )
}


if (!exists("snakemake")) {
  snakemake <- SnakemakePlaceholder(
    input = list(
      well_pad_crosswalk = "data/generated/production/well_pad_crosswalk_1970-2018.parquet",
      headers = glue::glue("data/generated/production/well_headers/first_prod_year={year}/file.parquet", year=1990:2018),
      prod = glue::glue("data/generated/production/monthly_production/year={year}/file.parquet", year=1990:2018),
      alvarez_2018 = "data/studies/alvarez_etal_2018/aar7204_Database_S1.xlsx",
      omara_2018 = "data/studies/omara_etal_2018/Omara_etal_SI_tables.csv",
      duren_2019_plumes = "data/studies/duren_etal_2019/Plume_list_20191031.csv",
      duren_2019_sites = "data/studies/duren_etal_2019/41586_2019_1720_MOESM3_ESM.xlsx",
      lyon_etal_2016_locations = "data/studies/lyon_etal_2016/es6b00705_si_005.xlsx",
      lyon_etal_2016_measures = "data/studies/lyon_etal_2016/es6b00705_si_004.xlsx",
      frankenberg_etal_2016_sources = "data/studies/frankenberg_etal_2016/AVNG_sources_all2.xlsx",
      frankenberg_etal_2016_measures = "data/studies/frankenberg_etal_2016/FourCorners_AV_NG_detections_Werner.xlsx",
      zavala_araiza_2018 = "data/studies/zavala-araiza_etal_2018/elementa-6-284-s1.xlsx",
      nat_gas_prices = "data/generated/nat_gas_prices_by_basin.parquet"
    ),
    output = list(
      plot_obs_count_jpl = "graphics/observation_count_jpl_flights.pdf",
      plot_jpl_flights_qq = "graphics/jpl_flights_qq_plot.pdf",
      lyon_etal_2016 = "data/generated/methane_measures/lyon_etal_2016.parquet",
      data_to_check_matches = "data/generated/methane_measures/data_to_check_matches.rds",
      cleaned_matched_obs = "data/generated/methane_measures/matched_wells_all.rds",
      aviris_match_fraction_dropped = "output/tex_fragments/aviris_match_fraction_dropped.tex"
    ),
    threads = 4,
    resources = list(mem_mb = 7000),
    rule = ""
  )
}

main <- function(snakemake, make_extra_plots=FALSE) {
  ground_studies <- load_ground_studies(snakemake@input)
  lyon <- read_lyon_etal_2016(snakemake@input) %>% match_commodity_prices(snakemake@input$nat_gas_prices)
  four_corners_df <- read_frankenberg_etal_2016(snakemake@input)

  jpl_ca_list <- match_jpl_california(snakemake@input)
  jpl_wells_all <- jpl_ca_list$observed_well_pads
  jpl_sites_matched <- jpl_ca_list$matched_wells

  # Keep track of the total number of sites and how many were matched.
  # (Note that read_jpl_sites already drops some non-O&G sites)
  # Note: these numbers are used to determine geographic match quality, so I'm
  # including counts that would have matched but for the random-overpass filter
  # (see details in match_jpl_california)
  count_aviris_total <- jpl_ca_list$count_total
  count_aviris_match <- jpl_ca_list$count_would_have_matched


  wells_nm_co <- load_flight_paths(four_corners_df) %>%
    load_well_records(c("CO", "NM"), snakemake@input$nat_gas_prices)

  four_corners_matched <- match_with_wells(four_corners_df, wells_nm_co)
  # Keep track of the total number of sites and how many were matched.
  count_aviris_total <- count_aviris_total + nrow(four_corners_df)
  count_aviris_match <- count_aviris_match + nrow(four_corners_matched)

  four_corners_vars_to_keep <- c(setdiff(colnames(four_corners_df), "geometry"), "entity_id")
  four_corners_all_wells <- four_corners_matched %>%
    sf::st_drop_geometry() %>%
    dplyr::select(!!four_corners_vars_to_keep) %>%
    dplyr::right_join(wells_nm_co, by="entity_id") %>%
    aggregate_to_well_pads(snakemake@input[["well_pad_crosswalk"]])

  df_list <- list(
    jpl_wells_all = jpl_wells_all,
    four_corners_all_wells = four_corners_all_wells,
    lyon = lyon,
    ground_studies = ground_studies
  )
  saveRDS(df_list, snakemake@output[["cleaned_matched_obs"]])

  # NOTE: not using these plots.
  if (make_extra_plots) {
    # Make plots of JPL measures and ground studies.
    make_plots(jpl_wells_all, jpl_sites_matched, ground_studies)
    # Make plots of distributions across studies.
  }

  write_match_percent(count_aviris_total, count_aviris_match,
    snakemake@output$aviris_match_fraction_dropped
  )

  # Create output data:
  # To add to this list, make sure the list names match names in snakemake@output
  # list(lyon_etal_2016 = read_lyon_etal_2016(snakemake@input)) %>%
    # write_datasets(snakemake@output)
}

arrow::set_cpu_count(snakemake@threads)
memory_limit(snakemake@resources[["mem_mb"]])


main(snakemake)
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"

# conda:
import numpy as np
import pandas as pd
import pyarrow.parquet as pq

import scipy.optimize  # least_squares, used in calc_common_cost_param
import scipy.sparse
import scipy.stats
import pyprojroot
import ipopt  # installed from conda as cyipopt, except on windows


# standard lib
from pathlib import Path
from collections import namedtuple
from functools import partial
import json
import datetime
import pickle
import logging
import re
import warnings
import contextlib
import io
import sys

# The math details (FOCs etc.) are in outcomes_analysis_derivative_functions.py.
# The logic that uses those FOCs is in this file.
from outcomes_analysis_helpers import (
    ProblemSetup,
    dwl_per_well,
    calc_well_audit_prob_uniform,
    prob_leak_with_policy,
    is_probability,
    read_constants,
    abatement_cost_per_pad,
    AuditInfo,
    DataParam,
    OutcomesOneDraw,
)

const = read_constants()
TAU_LEVELS = const["TAU_LEVELS"]
T_LEVELS = const["T_LEVELS"]
AUDIT_COST_TO_COMPARE_FIXED_VS_OPTIM_DWL = const[
    "AUDIT_COST_TO_COMPARE_FIXED_VS_OPTIM_DWL"
]
SOCIAL_COST_METHANE_PER_KG = const["SOCIAL_COST_METHANE_PER_KG"]
MODEL_NAMES = const["MODEL_NAMES"]
METHANE_GWP = const["METHANE_GWP"]


def identity(x):
    return x


def _helper_common_cost_param(b, Y, X):
    resid = Y - X @ np.atleast_2d(b).T
    return resid[:, 0]


def calc_common_cost_param(prob_leak, e_size_expect, price_H):
    """
    Use least squares to solve for the A and α coefficients of the cost function
    for this MCMC draw.
    (Doing this here because I'm having a hard time getting it to work in Stan)
    """
    N = len(prob_leak)
    assert N == len(e_size_expect) == len(price_H) > 1
    e_p = e_size_expect * price_H
    Y = np.expand_dims(np.log(e_p), axis=1)
    X = np.column_stack((np.ones_like(prob_leak), np.log(prob_leak)))
    # These are the max acceptable values (with some fudge factor) to make the
    # cost functions work out.
    # Here we estimate with constrained least squares.
    max_acceptable_A_log = np.log(np.min(e_p) * 0.999)
    max_acceptable_α = -1.001
    soln = scipy.optimize.least_squares(
        fun=_helper_common_cost_param,
        x0=np.array((max_acceptable_A_log, max_acceptable_α)),
        bounds=(
            (-np.inf, -np.inf),  # lower bounds
            (max_acceptable_A_log, max_acceptable_α),  # upper bounds
        ),
        args=(Y, X),
        method="dogbox",
    )
    coefs = soln.x
    A = np.exp(coefs[0])
    α = coefs[1]
    assert A <= np.min(e_p)
    assert α <= max_acceptable_α
    coefs_mat = np.full((N, 2), np.array([[A, α]]))
    return coefs_mat


def read_generated_iter(input, varnames):
    """Read the parquet files from the Stan generated quantities
    Provides a generator over MCMC draws to operate on each individually

    param: input is snakemake.input or other dictionary-like with file paths.
    varnames are the variable names to read. Variables are stored one per file.
    """
    assert len(varnames) >= 1
    pq_files = [pq.ParquetFile(input[v], memory_map=True) for v in varnames]
    ncol = pq_files[0].metadata.num_columns
    for draw_id in range(ncol):
        draw_id = str(draw_id + 1)  # parquet columns use R's 1-based indices, so shift from Python's 0-based
        draws = [
            f.read(draw_id).column(0).to_pandas(zero_copy_only=True).to_numpy()
            for f in pq_files
        ]
        yield draws


def read_generated_col(draw_id, input, varnames):
    """Read the parquet files from the Stan generated quantities
    Provides a generator over MCMC draws to operate on each individually

    param: input is snakemake.input or other dictionary-like with file paths
    varnames is the variable name to read. Variables are stored one per file.
    """
    assert len(varnames) >= 1
    draw_id = str(draw_id + 1)  # parquet columns use R's 1-based indices, so shift from Python's 0-based
    pq_files = [pq.ParquetFile(input[v], memory_map=True) for v in varnames]
    draws = [
        f.read(draw_id).column(0).to_pandas(zero_copy_only=True).to_numpy()
        for f in pq_files
    ]
    return draws


def read_sdata(json_file):
    """Read the data that was input to Stan"""

    with open(json_file, "rt") as f:
        sdata = json.load(f)
    df = pd.DataFrame(sdata["X"], columns=sdata["X_varnames"])
    # Note that "Y", "noise", "price" are not columns in X.
    # See distribution_model_data_prep.R
    for c in {"Y", "noise", "price"}:
        df[c] = sdata[c]
    return df


def run_ipopt(data_param, audit_info, r_guess):
    N = len(data_param.e_size_expect)
    if audit_info.audit_frac > 0:
        budget_const = [N * audit_info.audit_frac]
    else:
        budget_const = []
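    # With audit_frac > 0, the single constraint (defined in ProblemSetup) is
    # capped at N * audit_frac, i.e. the expected number of audits.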
    num_constr = len(budget_const)  # (variable bounds don't count)
    r_bound_low = np.zeros_like(r_guess)
    r_bound_high = np.ones_like(r_guess)
    prob = ipopt.problem(
        n=len(r_guess),  # N variables to optimize
        m=num_constr,
        problem_obj=ProblemSetup(data_param, audit_info),
        lb=r_bound_low,
        ub=r_bound_high,
        cu=budget_const,  # no lower bound
    )
    # print_level 0 still prints some stuff, but at least we'll save a little
    # time printing progress reports into the void.
    prob.addOption("print_level", 0)
    prob.addOption(
        "acceptable_iter", 30
    )  # How long to stay at an okay but not great solution? (Higher than default of 15)
    with contextlib.redirect_stdout(None):
        r, info = prob.solve(r_guess)
    # Pull the shadow value out
    assert len(info["mult_g"]) == num_constr
    if num_constr == 1:
        λ = -info["mult_g"][0]
    elif num_constr == 0:
        λ = -audit_info.audit_cost
    else:
        raise ValueError("unexpected number of shadow values")
    successful_solve = info["status"] == 0
    if not successful_solve:
        if info["status"] == 1:
            warnings.warn(
                "Solving to desired tolerances failed, but reached 'acceptable' tolerances",
                RuntimeWarning,
            )
        else:
            raise RuntimeError(f"Solving failed, status message {info['status_msg']}")
    return r, λ


def calc_outcomes_once(audit_info, data_param, r_guess=None):
    """
    Calculate the outcomes of one policy, for one MCMC draw.

    Output:
    Instance of OutcomesOneDraw named tuple, and r

    """
    N = len(data_param.e_size_expect)
    if r_guess is None:
        if audit_info.audit_rule == "uniform":
            r_guess = np.array([audit_info.audit_frac])
        elif audit_info.audit_rule == "target_e" and audit_info.detect_threshold > 0:
            # 2 probs per well -- see math writeup.
            # (Code should also accept shape (2,))
            r_guess = np.full((2 * N,), audit_info.audit_frac)
        else:
            r_guess = np.full((N,), audit_info.audit_frac)

    # Use a bunch of if-else here (rather than the cleaner dict of fns) to
    # get better tracebacks.
    if audit_info.audit_rule == "none":
        r = np.zeros_like(data_param.e_size_expect)
        λ = 0.0
        expected_fee_charged = audit_info.τT * r
    elif audit_info.audit_rule == "remote":
        r = np.zeros_like(data_param.e_size_expect)
        λ = 0.0
        # Wells with e < audit_info.detect_threshold face no price
        expected_fee_charged = audit_info.τT * (
            data_param.e_size_draw >= audit_info.detect_threshold
        )
    elif audit_info.audit_rule == "uniform" and audit_info.audit_frac > 0:
        r, λ = calc_well_audit_prob_uniform(data_param, audit_info, r_guess)
        expected_fee_charged = audit_info.τT * r
    elif audit_info.audit_rule == "target_e" and audit_info.detect_threshold > 0:
        r_big_r_small, λ = run_ipopt(data_param, audit_info, r_guess)
        # The well knows their size, so the effective fee per kg is now r_big if
        # the e is big and r_small if the e is small. This is different than the
        # probability-mix we had in the regulator's problem.
        # See the math notes for more.
        # r_big_r_small has shape (2 * N,), with r for big-e wells and r for
        # small-e wells, but appended together into one long vector because
        # that's how ipopt needs things to be. Use np.where to pick.
        r = np.where(
            data_param.e_size_draw > audit_info.detect_threshold,
            r_big_r_small[:N],
            r_big_r_small[N:],
        )
        expected_fee_charged = audit_info.τT * r
        # In small tests, it's faster to use the uniform than to start from the last point
    else:
        r, λ = run_ipopt(data_param, audit_info, r_guess)
        expected_fee_charged = audit_info.τT * r

    # Need to have the right shape here, so special-case target_e_high
    if audit_info.audit_rule == "target_e" and audit_info.detect_threshold > 0:
        r_guess_next = r_big_r_small
    else:
        r_guess_next = r

    # Score the r chosen above, now using e_size_draw instead of e_size_expect
    # Note that dwl here will include the audit cost when audit_cost > 0, but
    # not when audit_frac > 0, so we add it in.
    # (Easier to do it here, before the calc_relative_outcomes function)
    dwl = dwl_per_well(expected_fee_charged, data_param)
    if audit_info.audit_frac > 0 and audit_info.audit_cost == 0:
        additional_audit_cost_mean = (
            audit_info.audit_frac * AUDIT_COST_TO_COMPARE_FIXED_VS_OPTIM_DWL
        )
    else:
        additional_audit_cost_mean = 0.0
    additional_audit_cost_tot = additional_audit_cost_mean * N
    new_prob_leak = prob_leak_with_policy(
        expected_fee_charged,
        data_param.cost_coef,
        data_param.e_size_draw,
        data_param.price_H,
    )
    bau_prob_leak = prob_leak_with_policy(
        np.zeros_like(expected_fee_charged),
        data_param.cost_coef,
        data_param.e_size_draw,
        data_param.price_H,
    )
    emis = data_param.e_size_draw * data_param.time_H * new_prob_leak
    # Note: this is a little redundant to calculate BAU values here and
    # separately in calc_all_outcomes_per_draw, but I can live with it. I want
    # it available here, before aggregating across wells, but I don't want to
    # change the rest of the code that works with aggregated data.

    emis_bau = data_param.e_size_draw * data_param.time_H * bau_prob_leak
    # Note: "expected_fee_charged" is the well's expectation of the fee they'll
    # face per unit of leak rate (i.e. tau * T * r_i = total fee / e_i). The
    # units are dollars-hours per kg. That works for thinking about incentives,
    # but we also want to make comparisons with the social cost, so provide a
    # number that's just dollars per kilogram (using the true kg of emissions,
    # according to the model).
    # Total expected fee: tau * T * r_i * e_i * (1 - q_i)
    # Total expected emissions: e_i * (1 - q_i) * H
    # Expected fee per unit of emissions: tau * T * r_i / H
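    # i.e. [tau * T * r_i * e_i * (1 - q_i)] / [e_i * (1 - q_i) * H] = tau * T * r_i / H,
    # which is expected_fee_charged / time_H, computed below.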
    expected_fee_per_kg = expected_fee_charged / data_param.time_H
    fee_quantiles = np.quantile(expected_fee_per_kg, (0.5, 0.1, 0.9))
    # Note that tot_cost_per_kg is the expected cost to the operator of having
    # one more kg emissions: the expected fee plus the commodity value lost.
    # It does *not* include other costs, like abatement or audits.
    gas_price_per_kg_ch4 = data_param.price_H / data_param.time_H

    tot_cost_per_kg = expected_fee_per_kg + gas_price_per_kg_ch4

    # Now, we also want to compare to the gas price, so convert back from kg CH4
    # to mcf natural gas (see notes in distribution_model_data_prep.r and
    # match_jpl_measurements.R)
    methane_kg_per_mcf = 18.8916  # Note: this is for pure CH4, not nat gas
    approx_ch4_fraction = 0.95
    gas_price_per_mcf_gas = (
        gas_price_per_kg_ch4 * methane_kg_per_mcf * approx_ch4_fraction
    )

    # direct_private_cost_per_pad answers the question "how much do the well
    # pad's costs go up per period H?" This includes fees and abatement costs.
    # It does notinclude changes in revenue (from additional gas captured).
    direct_private_cost_per_pad = (
        abatement_cost_per_pad(new_prob_leak, data_param)
        - abatement_cost_per_pad(bau_prob_leak, data_param)
        + expected_fee_per_kg * emis
    )
    # net private cost = abatement cost + expected fee - additional revenue
    gas_prod_mcf_per_H = data_param.gas_avg_mcfd / 24.0 * data_param.time_H
    # We want the fractional change (fraction of price, rather than $) because
    # we'll use this number in an elasticity calculation later.
    net_private_cost_per_mcf_pct_price = (
        100
        * (direct_private_cost_per_pad - gas_price_per_kg_ch4 * (emis_bau - emis))
        / gas_prod_mcf_per_H
        / gas_price_per_mcf_gas
    )
    net_private_cost_per_mcf_pct_price_weighted = np.average(
        net_private_cost_per_mcf_pct_price, weights=data_param.gas_avg_mcfd
    )

    outcomes = OutcomesOneDraw(
        dwl_mean=np.mean(dwl) + additional_audit_cost_mean,
        dwl_tot=np.sum(dwl) + additional_audit_cost_tot,
        emis_mean=np.mean(emis),
        emis_tot=np.sum(emis),
        tot_cost_per_kg_mean=np.mean(tot_cost_per_kg),
        fee_per_kg_mean=np.mean(expected_fee_per_kg),
        fee_per_kg_med=fee_quantiles[0],
        fee_per_kg_p10=fee_quantiles[1],
        fee_per_kg_p90=fee_quantiles[2],
        net_private_cost_per_mcf_pct_price=net_private_cost_per_mcf_pct_price_weighted,
        shadow_price=λ,
        audit_rule=audit_info.audit_rule,
        audit_frac=audit_info.audit_frac,
        τT=audit_info.τT,
        detect_threshold=audit_info.detect_threshold,
        audit_cost=audit_info.audit_cost,
    )
    # returning r here is only useful to use as r_guess for the next iter
    # (and in testing)
    return outcomes, r_guess_next


def check_outcomes(df):
    """Run some tests on the outcomes. Return the original df"""

    # 1. Uniform should have uniform fee
    df_uni = df.query("audit_rule == 'uniform'")
    assert (df_uni.fee_per_kg_mean == df_uni.fee_per_kg_med).all()

    return df


def get_input_files(snakemake):
    """
    All models generate leak_size_expect.parquet, stan_data.json,
    and leak_size_draw.parquet.
    The code also depends on either prob_leak.parquet or cost_param_A.parquet
    and cost_param_alpha.parquet in the same folder.
    Read those in based on the model name.
    """
    input_dir = Path(snakemake.input["stan_data_json"]).resolve().parent
    assert input_dir.is_dir()
    model_name = snakemake.wildcards["model_name"]
    input_files = {
        "stan_data_json": input_dir / "stan_data.json",
        "leak_size_expect": input_dir / "leak_size_expect.parquet",
        "leak_size_draw": input_dir / "leak_size_draw.parquet",
        "prob_size_above_threshold": input_dir / "prob_size_above_threshold.parquet",
    }
    if model_name in MODEL_NAMES["cost_coef_models"]:
        input_files["cost_param_A"] = input_dir / "cost_param_A.parquet"
        input_files["cost_param_alpha"] = input_dir / "cost_param_alpha.parquet"
    else:
        input_files["prob_leak"] = input_dir / "prob_leak.parquet"
    return input_files


def calc_all_outcomes_per_draw(
    draw_id, input_files, audit_info, price_H, time_H, gas_avg_mcfd, r_guess=None
):
    # Read and unpack common values from parquet files.
    data_to_read = (
        "leak_size_expect",
        "leak_size_draw",
        "prob_size_above_threshold",
    )
    data_list = read_generated_col(draw_id, input_files, data_to_read)
    e_size_expect = data_list[0]
    e_size_draw = data_list[1]
    prob_size_above_threshold = data_list[2]

    if "cost_param_alpha" in input_files.keys():
        cost_coef = np.column_stack(
            read_generated_col(
                draw_id, input_files, ("cost_param_A", "cost_param_alpha")
            )
        )
    else:
        cost_coef = calc_common_cost_param(
            prob_leak=read_generated_col(draw_id, input_files, ["prob_leak"])[0],  # unwrap the single array from the returned list
            e_size_expect=e_size_expect,
            price_H=price_H,
        )
    data_param = DataParam(
        price_H=price_H,
        e_size_expect=e_size_expect,
        e_size_draw=e_size_draw,
        cost_coef=cost_coef,
        time_H=time_H,
        prob_is_large=prob_size_above_threshold,
        gas_avg_mcfd=gas_avg_mcfd,
    )

    try:
        outcomes_policy, _ = calc_outcomes_once(audit_info, data_param, r_guess)
    except RuntimeError:
        raise RuntimeError(f"Failed to converge for draw {draw_id}, audit {audit_info}")
    # Also do BAU and optimal. These are faster, since they're analytical
    audit_bau = AuditInfo(
        audit_rule="none", audit_frac=0.0, τT=0.0, detect_threshold=0.0, audit_cost=0.0
    )
    audit_optimal = AuditInfo(
        audit_rule="remote",
        audit_frac=0.0,
        τT=SOCIAL_COST_METHANE_PER_KG * time_H,
        detect_threshold=0.0,
        audit_cost=0.0,
    )

    outcome_bau, _ = calc_outcomes_once(audit_bau, data_param)
    outcome_optimal, _ = calc_outcomes_once(audit_optimal, data_param)
    return outcomes_policy, outcome_bau, outcome_optimal


def calc_all_outcomes_all_draws(snakemake):
    """
    Calculate all outcomes for all MCMC draws.
    (this takes a while, depending on the number of cases considered)
    """
    input_files = get_input_files(snakemake)
    sdata = read_sdata(input_files["stan_data_json"])
    price = sdata["price"]  # gas_price_per_kg_ch4
    gas_avg_mcfd = np.sinh(sdata["asinhgas_avg_mcfd"])
    time_H = parse_period_wildcard(snakemake.wildcards["time_period"])
    price_H_dollar_hr_per_kg = price * time_H
    audit_info = parse_audit_info(snakemake.wildcards)
    logging.info(audit_info)
    num_draws = 4000
    # Here I designed the loop to allow you to provide last round's r as a guess
    # for this round's r, thinking that would allow ipopt to converge faster.
    # It's actually slower, so we'll just provide r_guess = None for every iter.
    r_guess = None
    outcome_policy_list = []
    outcome_bau_list = []
    outcome_optimal_list = []

    for i in range(num_draws):
        outcomes_policy, outcome_bau, outcome_optimal = calc_all_outcomes_per_draw(
            i,
            input_files=input_files,
            audit_info=audit_info,
            price_H=price_H_dollar_hr_per_kg,
            time_H=time_H,
            gas_avg_mcfd=gas_avg_mcfd,
            r_guess=r_guess,
        )
        outcome_policy_list.append(outcomes_policy)
        outcome_bau_list.append(outcome_bau)
        outcome_optimal_list.append(outcome_optimal)

    df_policy = pd.DataFrame.from_records(
        outcome_policy_list, columns=OutcomesOneDraw._fields
    )
    df_bau = pd.DataFrame.from_records(
        outcome_bau_list, columns=OutcomesOneDraw._fields
    )
    df_optimal = pd.DataFrame.from_records(
        outcome_optimal_list, columns=OutcomesOneDraw._fields
    )
    df_policy["draw_id"] = range(num_draws)
    df_bau["draw_id"] = range(num_draws)
    df_optimal["draw_id"] = range(num_draws)

    df_all = calc_relative_outcomes(
        df_policy, df_bau, df_optimal, time_H, gas_avg_mcfd
    ).pipe(check_outcomes)
    return df_all


def calc_relative_outcomes(df_policy, df_bau, df_optimal, time_H, gas_avg_mcfd):
    """
    Calculate outcomes relative to the first best outcome for
    variables dwl_mean, dwl_tot, emis_mean, and emis_tot.
    1. Relative outcomes, e.g. emis_mean_rel_pct, are on a scale from 0 (BAU) to
    100 (first best). Note that first-best emissions won't be zero kg.
    2. Difference in outcomes, e.g. emis_reduce_mean, from BAU.
    """
    # Don't construct these from strings so I can grep for them later.
    var_to_normalize = {
        "dwl_mean": "dwl_mean_rel_pct",
        "dwl_tot": "dwl_tot_rel_pct",
        "emis_mean": "emis_mean_rel_pct",
        "emis_tot": "emis_tot_rel_pct",
    }
    # Isolate the BAU outcomes (one row per draw) and select the variables we
    # want to calculate from, plus the draw_id
    cols_to_keep = list(var_to_normalize.keys())
    cols_to_keep.append("draw_id")

    expected_shapes = (df_optimal.shape == df_bau.shape == df_policy.shape) and (
        df_bau.shape[0] > 0
    )
    if not expected_shapes:
        raise ValueError(
            "Bad shapes for first best / BAU outcomes. See output in log   ."
        )

    df = df_policy.merge(
        df_bau,
        how="inner",
        on="draw_id",
        validate="1:1",
        copy=True,
        suffixes=(None, "_bau"),
    ).merge(
        df_optimal,
        how="inner",
        on="draw_id",
        validate="1:1",
        copy=True,
        suffixes=(None, "_best"),
    )
    vars_to_drop = []
    # Calculate the relative change:
    for current, new_var in var_to_normalize.items():
        bau = current + "_bau"
        best = current + "_best"
        # For both DWL and emissions: 0 <= best <= current <= BAU
        # (Because audit costs are included, it's possible that current > BAU
        # for very ineffective policies. Because of numerical issues, we can
        # also end up with very small negative numbers.)
        df[new_var] = 100 * (df[bau] - df[current]) / (df[bau] - df[best])
        vars_to_drop.append(bau)
        vars_to_drop.append(best)
        assert df[new_var].between(-1, 110, inclusive="both").all()
    var_to_diff = {
        "emis_mean": "emis_reduce_mean",
        "emis_tot": "emis_reduce_tot",
    }
    # Check that the bau_outcomes df has the keys we need. (<= is subset)
    assert set(var_to_diff.keys()) <= set(cols_to_keep)
    # Now also calculate the change in levels:
    for current, new_var in var_to_diff.items():
        bau = current + "_bau"
        df[new_var] = df[bau] - df[current]
    # Clean up the *_bau and *_best variables
    df.drop(columns=vars_to_drop, inplace=True)
    return df


def conf_low(x):
    """Return the 2.5% quantile of x"""
    return x.quantile(0.025)


def conf_high(x):
    """Return the 97.5% quantile of x"""
    return x.quantile(0.975)


def summarize_outcomes(df):
    """
    For every outcome, calculate the mean and 95% CI.
    """
    assert df.notna().all().all()
    agg_fn = ["mean", conf_low, conf_high]
    summ = (
        df.drop(columns="draw_id")
        .groupby(list(AuditInfo._fields))
        .aggregate(agg_fn)
        .reset_index()
    )
    # Here pandas has made the columns into a MultiIndex...
    # summ.rename fails here because it won't handle the column name pairs
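    # e.g. ("dwl_mean", "conf_low") becomes "dwl_mean_conf_low", while grouping
    # columns like ("audit_rule", "") keep their original single-level name.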
    new_columns = []
    for c in summ.columns:
        assert isinstance(c, tuple)
        assert len(c) == 2
        if c[1] == "":
            # keep original
            new_columns.append(c[0])
        else:
            new_columns.append("_".join(c))
    summ.columns = new_columns
    summ.rename(columns={"τT": "tau_T"}, inplace=True)
    assert "shadow_price_mean" in summ.columns
    return summ


def set_memory_limit(mem_mb):
    """
    Limit the available memory to `mem_mb`. Only works on unix systems.
    "available memory" includes memory_map files.
    """
    import resource

    mem_bytes = mem_mb * 1024 ** 2
    _, lim_hard = resource.getrlimit(resource.RLIMIT_AS)
    new_lim = (mem_bytes, lim_hard)
    resource.setrlimit(resource.RLIMIT_AS, new_lim)


def extract_regex_match(str_to_search, regex):
    """Convenience wrapper to pull out one regex group"""
    match = re.search(regex, str_to_search)
    if not match:
        raise ValueError(f"Could not match {regex} in string {str_to_search}")
    return match.group(1)


def parse_period_wildcard(wildcard_str):
    """Parse the 'time_period' Snakemake wildcard.

    If the wildcard is empty (""), return 8760.0.
    """
    if wildcard_str == "":
        return 8760.0
    time_period_hr = float(extract_regex_match(wildcard_str, r"-period_(\d+)_hours"))
    return time_period_hr


def parse_τT_wildcard(wildcard_str):
    # numbers (strict: no commas or scientific notation, decimals require 0)
    τT_number_matches = re.search(r"^\-?(\d+\.)?\d+$", wildcard_str)
    if τT_number_matches:
        return float(wildcard_str)

    τT_regex = f"({'|'.join(TAU_LEVELS.keys())})-({'|'.join(T_LEVELS.keys())})"
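    # e.g. a wildcard "<tau_level>-<T_level>" (keys of TAU_LEVELS and T_LEVELS)
    # resolves to TAU_LEVELS[tau_level] * T_LEVELS[T_level].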
    τT_regex_match = re.search(τT_regex, wildcard_str)
    if not τT_regex_match:
        raise ValueError(f"Failed to parse τT from {wildcard_str}")

    τ_str = τT_regex_match.group(1)
    T_str = τT_regex_match.group(2)
    τT = TAU_LEVELS[τ_str] * T_LEVELS[T_str]
    return τT


def parse_audit_info(wildcards):
    """
    Here we parse some values out of a filename like this:
    "{model_name}{prior_only}{bootstrap}{time_period}" /
    "audit_outcome_summary_rule={audit_rule}_frac={audit_amount}_tauT={audit_tauT}.parquet"
    into an AuditInfo tuple
    """
    # Here we translate what it means for detect to be "high" or "low" into numbers.
    const = read_constants()
    detect_opts = {
        "low": const["POLICY_DETECT_THRESHOLD_LOW"],
        "high": const["POLICY_DETECT_THRESHOLD_HIGH"],
    }
    # audit_frac_opts = {"low": 0.0, "med": 0.01, "high": 0.10}
    # audit_cost_opts = {"low": 0.0, "med": 100.0, "high": 600.0}
    # Parse (none|uniform|remote_low|remote_high|target_x|target_e_low|target_e_high)
    # into an audit rule (e.g. "target_e") and detection threshold (e.g. "low")
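    # e.g. "remote_high" -> rule "remote" with detect_threshold set to
    # POLICY_DETECT_THRESHOLD_HIGH; a rule with no suffix (e.g. "uniform")
    # falls back to the low threshold value below.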
    match_audit = re.search(
        "^(none|uniform|remote|target_x|target_e)_?(low|high)?$",
        wildcards["audit_rule"],
    )
    if not match_audit:
        raise ValueError("Failed to match audit pattern")
    audit_rule = match_audit.group(1)
    match_detect_threshold = match_audit.group(2)
    if not match_detect_threshold:
        # Note: detect_threshold isn't relevant to some rules, but still need
        # to fill in a value
        match_detect_threshold = "low"
    detect_threshold = detect_opts[match_detect_threshold]

    # parse "audit_amount", which has a wildcard constraint of
    # "(0pct|1pct|10pct|optimal-100usd|optimal-600usd)"
    if wildcards["audit_amount"].endswith("pct"):
        audit_frac = extract_regex_match(wildcards["audit_amount"], r"^(0?\.?\d+)pct$")
        audit_frac = float(audit_frac) / 100.0
        audit_cost = 0.0
    elif wildcards["audit_amount"].startswith("optimal-"):
        audit_frac = 0.0
        audit_cost = float(
            extract_regex_match(wildcards["audit_amount"], r"^optimal-(\d+)usd$")
        )
    else:
        raise ValueError(f"Failed to parse {wildcards['audit_amount']}")
    τT = parse_τT_wildcard(wildcards["audit_tauT"])
    audit_info = AuditInfo(
        audit_rule=audit_rule,
        audit_frac=audit_frac,
        τT=τT,
        detect_threshold=detect_threshold,
        audit_cost=audit_cost,
    )

    # Run a few checks to make sure we're not requesting a nonsense audit policy
    if (
        audit_frac < 0
        or audit_frac > 1
        or audit_cost < 0
        or detect_threshold < 0
        or τT < 0
        or audit_rule not in {"none", "uniform", "target_x", "target_e", "remote"}
    ):
        logging.error(wildcards)
        logging.error(audit_info)
        raise ValueError("Impossible audit values")
    if (audit_frac == 0 and audit_cost == 0) or (audit_frac > 0 and audit_cost > 0):
        if audit_rule in {"uniform", "target_x", "target_e"}:
            logging.error(wildcards)
            raise ValueError(
                f"Bad combo of audit_frac ({audit_frac}) and audit_cost ({audit_cost}) for rule {audit_rule}"
            )
    if (
        audit_rule in ("none", "uniform", "target_x")
        and detect_threshold != detect_opts["low"]
    ):
        logging.error(wildcards)
        raise ValueError(
            f"Audit rules that don't depend on detect_threshold should set to {detect_opts['low']}"
        )
    if audit_rule in ("none", "remote") and (audit_cost != 0 or audit_frac != 0):
        logging.error(wildcards)
        raise ValueError(
            "Remote or no audits should set audit_cost = 0 and audit_frac = 0"
        )
    if audit_cost > 0 and detect_threshold > 0:
        logging.error(audit_info)
        raise NotImplementedError(
            "audit_cost > 0 and detect_threshold > 0 isn't implemented."
        )
    return audit_info


class Timer(object):
    """
    Context manager to time an expression (once; not for benchmarking)
    https://stackoverflow.com/a/5849861
    https://stackoverflow.com/a/3427051
    """

    def __init__(self, name=None):
        self.name = name

    def __enter__(self):
        self.tstart = datetime.datetime.now().replace(microsecond=0)

    def __exit__(self, type, value, traceback):
        if self.name:
            logging.info("[%s]" % self.name)
        elapsed = datetime.datetime.now().replace(microsecond=0) - self.tstart
        logging.info(f"Elapsed: {elapsed}")


if not "snakemake" in globals():
    logging.warn("Using placeholder snakemake")
    data_generated = Path(pyprojroot.here("data/generated"))
    wildcards = {
        "model_name": "08_twopart_lognormal_heterog_alpha",
        # "model_name": "01_twopart_lognormal",
        "bootstrap": "-bootstrap",
        "time_period": "-period_8760_hours",
        # "time_period": "",
        "prior_only": "",
    }
    stan_fits = (
        data_generated
        / f"stan_fits/{wildcards['model_name']}{wildcards['prior_only']}{wildcards['bootstrap']}{wildcards['time_period']}"
    )
    SnakemakePlaceholder = namedtuple(
        "SnakemakePlaceholder", ["input", "output", "threads", "resources", "wildcards"]
    )
    snakemake = SnakemakePlaceholder(
        input={
            "leak_size_draw": stan_fits / "leak_size_draw.parquet",
            "leak_size_expect": stan_fits / "leak_size_expect.parquet",
            "stan_data_json": stan_fits / "stan_data.json",
        },
        output={"results_summary": stan_fits / "audit_outcome_summary.parquet"},
        threads=4,
        resources={"mem_mb": 7000},
        wildcards=wildcards,
    )


if __name__ == "__main__":

    if snakemake.wildcards["model_name"] not in MODEL_NAMES["all"]:
        raise ValueError(f"Unknown model {snakemake.wildcards['model_name']}")
    try:
        set_memory_limit(snakemake.resources["mem_mb"])
    except Exception:
        logging.warning("Note: failed to set memory limit.")

    logging.basicConfig(filename=snakemake.log[0], level=logging.INFO, encoding="utf-8")
    with Timer(snakemake.output["results_summary"]):
        results = calc_all_outcomes_all_draws(snakemake).pipe(summarize_outcomes)
        results.to_parquet(snakemake.output["results_summary"])
source(here::here("code/shared_functions.r"))
source(here::here("code/distribution_model_data_prep.r"))
source(here::here("code/stan_helper_functions.r"))

set.seed(8675309)
# LEAK_SIZE_DEF defines what leak size we're talking about for a "major leak".
# This needs to match the shift_amount in distribution_model_data_prep
LEAK_SIZE_DEF <- 5

if (!exists("snakemake")) {
  message("Using placeholder snakemake")
  TEX_FRAGMENTS <- fs::fs_path(here::here("output/tex_fragments"))
  STAN_FITS <- fs::fs_path(here::here("data/generated/stan_fits"))
  snakemake <- SnakemakePlaceholder(
    input=list(
      # NOTE: the order here (and in the snakemake file) is the order of the
      # columns in the table.
      distribution_fits = c(
        STAN_FITS / "01_twopart_lognormal-bootstrap/model_fit.rds",
        STAN_FITS / "03_twopart_lognormal_meas_err-bootstrap/model_fit.rds",
        # STAN_FITS / "05_twopart_normal_qr/model_fit.rds",
        STAN_FITS / "02_twopart_lognormal_alpha-bootstrap-period_8760_hours/model_fit.rds",
        STAN_FITS / "08_twopart_lognormal_heterog_alpha-bootstrap-period_8760_hours/model_fit.rds"
      ),
      measurements = here::here("data/generated/methane_measures/matched_wells_all.rds"),
      stan_data_json = c(
        STAN_FITS / "01_twopart_lognormal-bootstrap/stan_data.json",
        STAN_FITS / "03_twopart_lognormal_meas_err-bootstrap/stan_data.json",
        # STAN_FITS / "05_twopart_normal_qr/stan_data.json",
        STAN_FITS / "02_twopart_lognormal_alpha-bootstrap/stan_data.json",
        STAN_FITS / "08_twopart_lognormal_heterog_alpha-bootstrap/stan_data.json"
      )
    ),
    output = list(
      model_prob_leak_plot = "graphics/model_prob_leak_plot.pdf",
      model_cost_vs_q_plot = "graphics/model_cost_vs_q_plot.pdf",
      model_cost_vs_q_dwl_plot = "graphics/model_cost_vs_q_dwl_plot.pdf",
      model_coef_obs_leak  = TEX_FRAGMENTS / "model_parameters_obs_leak.tex",
      model_coef_leak_size = TEX_FRAGMENTS / "model_parameters_leak_size.tex",
      model_coef_footer_obs_leak  = TEX_FRAGMENTS / "model_parameters_footer_obs_leak.tex",
      model_coef_footer_leak_size = TEX_FRAGMENTS / "model_parameters_footer_leak_size.tex",
      # model_cost_alpha_by_leak_size_bin_plot = "graphics/model_cost_alpha_by_leak_size_bin_plot.pdf",
      # model_leak_size_by_leak_size_bin_plot  = "graphics/model_leak_size_by_leak_size_bin_plot.pdf"
      model_prob_size_above_threshold_histogram = "graphics/model_prob_size_above_threshold_histogram.pdf",
      model_cost_alpha_histogram = "graphics/model_cost_alpha_histogram.pdf"
    ),
    threads=1,
    resources=list(mem_mb = 13000)
  )
}

memory_limit(snakemake@resources[["mem_mb"]])
# Set warn = 2 only after memory_limit, because memory_limit may not work
# (and may warn) on some systems.
options(scipen = 99, mc.cores=snakemake@threads, warn = 2)


load_samples <- function(parquet_file, n_draws) {
  if (n_draws == "all") {
    return(arrow::read_parquet(parquet_file))
  }
  stopifnot(n_draws >= 1)
  # Load some columns from a parquet file of draws.
  n_col <- arrow::ParquetFileReader$create(parquet_file)$GetSchema()$num_fields
  if (n_col < n_draws) {
    stop("Only ", n_col, " columns are available. (Use n_draws='all' to load all)")
  }
  # Column names are just numbers, "1", "2", ...
  idx <- sample.int(n_col, size=n_draws, replace=FALSE) %>% as.character()
  arrow::read_parquet(parquet_file, col_select=!!idx)
}
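
# Example usage (a sketch; `prob_leak_file` stands in for any parquet of draws):
# draws_200 <- load_samples(prob_leak_file, n_draws = 200)    # 200 random draw columns
# draws_all <- load_samples(prob_leak_file, n_draws = "all")  # every draw column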


# Plot the modeled prob_leak
plot_model_prob_leak <- function(prob_leak_file, outfile) {
  stopifnot(fs::path_file(prob_leak_file) == "prob_leak.parquet")

  # We have a lot of draws that went into the uncertainty. Pick a subset.
  n_draws <- 200
  draws_to_plot <- load_samples(prob_leak_file, n_draws=n_draws)
  # Reset names:
  colnames(draws_to_plot) <- as.character(seq_len(n_draws))
  draws_to_plot %<>% tibble::as_tibble() %>%
    tidyr::pivot_longer(dplyr::everything(), names_to="draw_id", values_to="prob_leak") %>%
    dplyr::mutate(prob_leak_pct = winsorize(prob_leak * 100, trim=c(0, 0.02)))
  # Different lines, but same color for every line.
  colors <- rep_len(RColorBrewer::brewer.pal(n=3, name="Dark2")[1], n_draws)
  plt <- ggplot2::ggplot(draws_to_plot, ggplot2::aes(x=prob_leak_pct, color=draw_id)) +
    ggplot2::stat_density(geom="line", alpha=0.1) +
    ggplot2::scale_color_manual(values=colors, guide="none") +
    ggplot2::theme_bw() +
    ggplot2::labs(x="Probability of leaking (%; winsorized)", y="Density")

  save_plot(plt, outfile, reproducible=TRUE)
}

read_dist_fit <- function(distribution_fit, depvar = c("leak size", "obs leak")) {
  model_name <- filename_to_model_name(distribution_fit)
  stopifnot(model_name %in% MODEL_NAMES$all)
  depvar <- match.arg(depvar)
  varnames <- get_shared_varnames(model_name = model_name)
  if (depvar == "leak size") {
    # Order matters here -- we're going to match up stan coefs and varnames by
    # position!
    stan_coefs <- c("b_y_intercept", "b_y", "sigma_y")
    varnames %<>% c("sigma")
  } else if (depvar == "obs leak") {
    stan_coefs <- c("b_obs_intercept", "b_obs")
    if (model_name == "02_twopart_lognormal_alpha") {
      return(NULL)
    }
  } else {
    stop("programming error")
  }
  if (depvar == "obs leak" && model_name %in% MODEL_NAMES$rhs_ehat_models) {
    varnames %<>% c("e_hat") # e for emissions
  }
  stopifnot(length(distribution_fit) == 1, grepl('intercept', stan_coefs[1]))
  draws <- readRDS(distribution_fit)$draws %>%
    posterior::subset_draws(stan_coefs)
  list(draws = draws, varnames = varnames, model_name = model_name)
}

clean_varnames <- function(x) {
  # This process could obviously be a lot cleaner:
  pretty_names <- c(
    Intercept = "Intercept",
    asinhgas_avg_mcfd = "IHS of gas prod (mcfd)",
    asinhoil_avg_bbld = "IHS of oil prod (bbld)",
    basinSanJoaquin = "Basin: San Joaquin",
    basinSanJuan = "Basin: San Juan",
    basinOtherCalifornia = "Basin: Other California",
    prod_oil_frac = "Oil prod share",
    asinhage_yr = "IHS of age (yr)",
    drill_typeH = "Drill: Horizontal",
    drill_typeU = "Drill: Unknown",
    drill_typeV = "Drill: Vertical",
    drill_typeD = "Drill: Directional",
    sigma = "$\\sigma$",
    e_hat = "$\\hat{e_i}$"
  )
  missing_names <- setdiff(unique(x), names(pretty_names))
  if (length(missing_names) > 0) {
    stop("Missing pretty names for these variables: ", paste(missing_names, collapse=", "))
  }
  pretty_names[x]
}

get_shared_varnames <- function(model_name,
    measurements_file = snakemake@input[["measurements"]]
  ) {
  stopifnot(length(measurements_file) == 1)
  # Load the data the same way it's loaded for the model fitting.
  # These are the names R creates with model.matrix, so they're messy for factors
  # Could memoize this function, but it's already fast.
  coef_names <- prep_measurement_data(measurements_file) %>%
    purrr::chuck("aviris_all") %>%
    prep_custom_model_data(model_name = model_name) %>% #
    purrr::chuck("X") %>%
    colnames()
  coef_names
}

summarize_coefs_once <- function(draws, varnames) {
  stopifnot(posterior::is_draws(draws), is.character(varnames))
  # col_order <- order(names(df))
  # out <- tidyr::pivot_longer(df, dplyr::everything(), names_to="term", values_to="est") %>%
  #   dplyr::group_by(term) %>%
  #   dplyr::summarize(
  #     estimate = signif(mean(est), 3),
  #     conf_low = signif(quantile_(est, 0.025), 2),
  #     conf_high = signif(quantile_(est, 0.975), 2),
  #     .groups="drop"
  #   ) %>%
  #   dplyr::mutate(term = clean_varnames(term))
  est = function(x) signif(mean(x), 3)
  conf95_low  = function(x) signif(quantile(x, 0.025, names=FALSE, type=8), 2)
  conf95_high = function(x) signif(quantile(x, 0.975, names=FALSE, type=8), 2)

  out <- posterior::summarize_draws(draws,
    estimate = est, conf_low = conf95_low, conf_high = conf95_high
  ) %>%
  dplyr::mutate(term = clean_varnames(varnames))
  stopifnot(
    # Check names match expectations. (dplyr already checks lengths)
    grepl("intercept", out$variable[1], ignore.case=TRUE),
    grepl("intercept", varnames[1], ignore.case=TRUE),
    grepl("[1]", out$variable[2], fixed=TRUE),
    grepl("mcfd", varnames[2], fixed=TRUE)
  )
  # Re-sort rows to match the original column order
  # stopifnot(nrow(out) == length(col_order))
  # dplyr::arrange(out, col_order)
  out
}

.bayes_R2 <- function(y, ypred) {
  # Borrowed from brms
  # Subtract y from every column of ypred and multiply by -1
  # (as.array to make sweep a little stricter)
  e <- -1 * sweep(ypred, MARGIN=1, STATS=as.array(y), FUN="-")
  # These are rowVars in the original brms code, but we have transposed the
  # results so each column is an MCMC draw and each row is a well.
  # We want to take the variance of each column.
  var_ypred <- matrixStats::colVars(ypred)
  var_e <- matrixStats::colVars(e)
  stopifnot(length(var_ypred) == length(var_e))
  # I'm pretty sure this works for binary outcomes too.
  return(var_ypred / (var_ypred + var_e))
}

calc_r2 <- function(model_dir, outcome_name) {
  stopifnot(
    length(model_dir) == 1,
    length(outcome_name) == 1
  )
  stan_data_json <- file.path(model_dir, "stan_data.json")
  y_raw <- jsonlite::read_json(stan_data_json, simplifyVector=TRUE)$Y %||% stop("Missing 'Y' in sdata")
  obs_indicator <- (!is.na(y_raw)) & (y_raw > LEAK_SIZE_DEF)
  generated_dir <- fs::path_dir(stan_data_json) %>% fs::fs_path()
  generated_file <- generated_dir / paste0(outcome_name, ".parquet")

  # pred value matrix, one row per well, one col per draw.
  ypred <- arrow::read_parquet(generated_file) %>% as.matrix()
  if (outcome_name == "prob_leak") {
    y <- as.numeric(obs_indicator)
  } else if (outcome_name == "leak_size_expect") {
    # For leak size, we can only compare the actual observed leak sizes.
    # rows of ypred are wells.
    y <- y_raw[obs_indicator]
    ypred <- ypred[obs_indicator, , drop=FALSE]
  } else {
    stop("Unknown outcome name: ", outcome_name)
  }
  stopifnot(length(y) == nrow(ypred), ncol(ypred) > 1)
  r2_by_draw <- .bayes_R2(y, ypred)
  stopifnot(noNAs(r2_by_draw), r2_by_draw >= 0, r2_by_draw <= 1)
  # could do CI if we wanted.
  mean(r2_by_draw)
}

get_N <- function(model_dir) {
  stopifnot(length(model_dir) == 1)
  stan_data_json <- file.path(model_dir, "stan_data.json")
  N <- jsonlite::read_json(stan_data_json, simplifyVector=FALSE)$N
  N
}

calc_outcome_mean <- function(model_dir, depvar) {
  stopifnot(length(model_dir) == 1, length(depvar) == 1)
  stan_data_json <- file.path(model_dir, "stan_data.json")
  y_raw <- jsonlite::read_json(stan_data_json, simplifyVector=TRUE)$Y %||% stop("Missing 'Y' in sdata")
  obs_indicator <- (!is.na(y_raw)) & (y_raw > LEAK_SIZE_DEF)
  if (depvar == "obs leak") {
    y <- as.numeric(obs_indicator)
  } else if (depvar == "leak size") {
    y <- y_raw[obs_indicator]
  } else {
    stop("Unknown outcome name: ", depvar)
  }
  mean(y)
}

write_coefs_table <- function(snakemake, depvar) {
  # fit_info is a list of lists. Outer list has one element per file in
  # distribution_fits. For each of those, inner list has elements `draws`,
  # `varnames`, and `model_name`
  if (depvar == "leak size") {
    outfile <- snakemake@output$model_coef_leak_size %||% stop("missing outfile")
  } else if (depvar == "obs leak") {
    outfile <- snakemake@output$model_coef_obs_leak  %||% stop("missing outfile")
  } else {
    stop("bad depvar")
  }
  fit_files <- snakemake@input$distribution_fits %||% stop("missing fit files")
  fit_info <- purrr::map(fit_files, read_dist_fit, depvar=depvar) %>%
    purrr::compact()
  model_names <- extract_from_each(fit_info, "model_name") %>% as.character()
  varnames_lst <- extract_from_each(fit_info, "varnames")
  summary_lst <- extract_from_each(fit_info, "draws") %>%
    purrr::map2(varnames_lst, summarize_coefs_once)
  tab <- summary_lst %>%
    purrr::map(format_estimate_above_interval, align="@") %>%
    merge_estimates_df() %>%
    make_table_fragment(escaped=FALSE, add_comments = model_names) %>%
    writeLines(outfile)
  write_coef_footer(snakemake, depvar)
  invisible(summary_lst)
}

write_coef_footer <- function(snakemake, depvar) {
  model_dirs <- dirname(snakemake@input$stan_data_json)
  model_names <- file.path(model_dirs, "model_fit.rds") %>% filename_to_model_name()
  if (depvar == "leak size") {
    outfile <- snakemake@output$model_coef_footer_leak_size
    outcome_name <- "leak_size_expect"
  } else if (depvar == "obs leak") {
    outfile <- snakemake@output$model_coef_footer_obs_leak
    outcome_name <- "prob_leak"
    # No b_obs coef for this model, so no footer for this model:
    model_dirs <- model_dirs[model_names != "02_twopart_lognormal_alpha"]
  } else {
    stop("bad depvar")
  }
  .make_table_line <- function(x) {
    paste(paste(x, collapse = " & "), "\\\\")
  }
  r2 <- purrr::map_dbl(
    model_dirs,
    calc_r2,
    outcome_name=outcome_name
  ) %>% signif(2)
  n <- purrr::map_int(model_dirs, get_N)
  depvar_mean <- purrr::map_dbl(
    model_dirs,
    calc_outcome_mean,
    depvar=depvar
  ) %>% signif(3)
  waic <- rep_len("-", length(model_dirs))
  text_lines <- c(
    .make_table_line(c("$N$", n)),
    .make_table_line(c("$R^2$", r2)),
    # .make_table_line(c("WAIC", waic)),
    .make_table_line(c("Dep. var. mean", depvar_mean))
  )
  writeLines(text_lines, outfile)
}

which_quantile <- function(x, prob) {
  # Get the index of `x` for the value that's closest to the `prob` quantile.
  # Return value has length 1, even if there are ties
  # Analogous to which.min
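  # e.g. which_quantile(c(10, 20, 30, 40), 0.5) returns 2: 20 is the first of
  # the two values tied for closest to the median (assuming quantile_ computes
  # a standard quantile).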
  stopifnot(length(prob) == 1, length(x) >= 1)
  q <- quantile_(x, prob)
  which.min(abs(x - q))
}

plot_model_cost_param <- function(model_dir, outfile_cost, outfile_dwl) {
  # Create two plots here:
  # 1. Cost param with uncertainty
  # 2. Shaded DWL
  model_dir <- fs::fs_path(model_dir)
  model_name <- filename_to_model_name(model_dir / "model_fit.rds")
  time_period_hr <- filename_to_time_period_hr(model_dir)

  stopifnot(
    model_name %in% MODEL_NAMES$cost_coef_models,
    length(outfile_cost) == 1,
    length(outfile_dwl) == 1
  )
  cost_param_A_file <- model_dir / "cost_param_A.parquet"
  cost_param_alpha_file <- model_dir / "cost_param_alpha.parquet"
  leak_size_expect_file <- model_dir / "leak_size_expect.parquet"
  sdata <- jsonlite::read_json(model_dir / "stan_data.json", simplifyVector=TRUE)

  gas_frac_ch4 <- 0.95
  ch4_kg_per_mcf <- 18.8916
  # price in sdata is $ per kg CH4
  price_per_mcf <- median(sdata$price) * ch4_kg_per_mcf / gas_frac_ch4
  fee_per_ton_co2e <- 5

  ton_co2e_per_ton_ch4 <- 29.8
  ton_per_kg <- 1 / 1000
  fee_per_mcf <- (
    fee_per_ton_co2e
    * ton_co2e_per_ton_ch4
    * ton_per_kg
    * ch4_kg_per_mcf
    * gas_frac_ch4
  )
  scm_per_kg <- 2 # ($58/ton CO2e; kinda low!)
  scm_per_mcf <- scm_per_kg * ch4_kg_per_mcf * gas_frac_ch4
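  # For reference: fee_per_mcf = 5 * 29.8 * (1/1000) * 18.8916 * 0.95, roughly
  # $2.67 per mcf, and scm_per_mcf = 2 * 18.8916 * 0.95, roughly $35.9 per mcf.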



  # For 02_twopart_lognormal_alpha the coefficient is constant, so how we
  # aggregate doesn't matter. For 08_twopart_lognormal_heterog_alpha the
  # coefficient varies by well, so aggregation potentially matters.
  # We find the well with the median predicted leak size in draw 1 (arbitrary)
  # and then plot that well's parameters.
  # Doing it this way, rather than taking the median within each draw, gives a
  # better sense of the uncertainty: we plot the variation at one well rather
  # than the variation in a quantile.
  leak_size_kg_per_hr <- arrow::read_parquet(leak_size_expect_file) %>%
    as.matrix()
  well_idx <- which_quantile(leak_size_kg_per_hr[, 1], 0.5)
  cost_param_A <- as.matrix(arrow::read_parquet(cost_param_A_file))[well_idx, , drop=FALSE] %>%
    as.vector()
  cost_param_alpha <- as.matrix(arrow::read_parquet(cost_param_alpha_file))[well_idx, , drop=FALSE] %>%
    as.vector()
  leak_size_kg_per_hr <- as.matrix(arrow::read_parquet(leak_size_expect_file))[well_idx, , drop=FALSE] %>%
    as.vector()
  stopifnot(length(cost_param_A) == 4000, length(cost_param_alpha) == 4000)

  leak_size_kg <- leak_size_kg_per_hr * time_period_hr
  leak_size_mcf <- leak_size_kg / ch4_kg_per_mcf / gas_frac_ch4
  line_private <- price_per_mcf
  line_policy <- (price_per_mcf + fee_per_mcf)
  line_optimal <- (price_per_mcf + scm_per_mcf)

  # To find the uncertainty values we want to plot, find the indexes of the 2.5%
  # and 97.5% MC at q=0.99
  # (I think the function is monotonic in this way, so it doesn't matter that
  # we're only doing one point.)
  cost_param <- dplyr::tibble(
    A = cost_param_A,
    alpha = cost_param_alpha,
    mc_point = A * (1 - 0.99) ^ alpha / !!leak_size_mcf,
  )
  draw_percentile = c(2.5, 50, 97.5)
  draw_idx_of_interest <- c(
    low = which_quantile(cost_param$mc_point, 0.025),
    med = which_quantile(cost_param$mc_point, 0.5),
    high = which_quantile(cost_param$mc_point, 0.975)
  )

  # We plot the y-axis in dollars per mcf for two reasons:
  # 1. It keeps the policy lines stable; otherwise they would vary with the leak size e.
  # 2. It gives the audience numbers on a scale they might be familiar with (e.g. the $/mcf commodity price).
  # (Could do $/CO2e for similar reasons.)

  q_seq <- seq(0.95, 0.99999, length.out=400)
  to_plot <- dplyr::tibble(
      draw_percentile = draw_percentile,
      # Plot quantiles across draws
      alpha = cost_param$alpha[draw_idx_of_interest],
      A     = cost_param$A[draw_idx_of_interest],
      leak_size_mcf = leak_size_mcf[draw_idx_of_interest],
      # Make a list with a copy of q_seq for each value of draw_idx_of_interest
      q = purrr::map(draw_idx_of_interest, ~q_seq),
    ) %>%
    tidyr::unnest(cols="q") %>%
    ensure_id_vars(q, draw_percentile) %>%
    dplyr::mutate(
      marg_cost_per_mcf = A * (1 - q) ^ alpha / leak_size_mcf,
      draw_percentile_fct = factor(draw_percentile),
    ) %>%
    dplyr::filter(marg_cost_per_mcf <= !!line_optimal * 1.2)

  text_labels <- dplyr::tibble(
    text = c("Private", "Policy", "Social"),
    y = c(line_private, line_policy, line_optimal) + c(1.4, 1.7, 1.7),
    x = min(to_plot$q),
  )

  to_plot_median <- dplyr::filter(to_plot, draw_percentile == 50)
  stopifnot(nrow(dplyr::distinct(to_plot_median, A, alpha)) == 1)
  intersection_y <- c(line_private, line_policy, line_optimal) * unique(to_plot_median$leak_size_mcf)
  # unique here because there are duplicates for every q
  intersection_x <- 1 - (intersection_y / unique(to_plot_median$A)) ^ (1 / unique(to_plot_median$alpha))
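  # Derivation: the marginal cost curve is A * (1 - q)^alpha / leak_size_mcf, so
  # intersection_y converts each $/mcf line to dollars at the well
  # (line * leak_size_mcf), and intersection_x solves
  # A * (1 - q)^alpha = intersection_y for q, i.e. q = 1 - (y / A)^(1 / alpha).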

  dwl_region_orig <- to_plot_median %>%
    dplyr::filter(dplyr::between(.data$q, intersection_x[1], intersection_x[3]))
  dwl_region_with_policy <- to_plot_median %>%
    dplyr::filter(dplyr::between(.data$q, intersection_x[2], intersection_x[3]))
  dwl_region_colors <- c("gray", "gray")

  plt_cost <- ggplot2::ggplot(to_plot) +
    ggplot2::geom_line(ggplot2::aes(q, marg_cost_per_mcf, color=draw_percentile_fct)) +
    ggplot2::geom_hline(yintercept=line_private, linetype="dashed", alpha=0.7) +
    ggplot2::geom_hline(yintercept=line_optimal, linetype="dashed", alpha=0.7) +
    ggplot2::theme_bw() +
    ggplot2::scale_color_manual(values=c("gray", "black", "gray"), guide="none") +
    ggplot2::geom_text(ggplot2::aes(x=x, y=y, label=text), data=dplyr::filter(text_labels, text != "Policy"), hjust=0, size=3) +
    ggplot2::labs(
      x="Prob no leak (q)",
      y="",
      subtitle="$ / mcf"
    )
  plt_dwl <- ggplot2::ggplot(to_plot_median, ggplot2::aes(q, marg_cost_per_mcf)) +
    ggplot2::geom_line() +
    ggplot2::geom_line(data=to_plot, alpha=0) + # add to plot to make the axis the same for both graphs.
    ggplot2::geom_ribbon(ggplot2::aes(ymin=marg_cost_per_mcf, ymax=line_optimal), data=dwl_region_orig,        fill=dwl_region_colors[1], alpha=0.6) +
    ggplot2::geom_ribbon(ggplot2::aes(ymin=marg_cost_per_mcf, ymax=line_optimal), data=dwl_region_with_policy, fill=dwl_region_colors[2], alpha=0.8) +
    ggplot2::geom_hline(yintercept=line_private, linetype="dashed", alpha=0.7) +
    ggplot2::geom_hline(yintercept=line_policy, linetype="dashed", alpha=0.7) +
    ggplot2::geom_hline(yintercept=line_optimal, linetype="dashed", alpha=0.7) +
    ggplot2::geom_text(ggplot2::aes(x=x, y=y, label=text), data=text_labels, hjust=0, size=3) +
    ggplot2::theme_bw() +
    ggplot2::labs(
      x="Prob no leak (q)",
      y="",
      subtitle="$ / mcf"
    )

  save_plot(plt_cost, outfile_cost, reproducible=TRUE)
  save_plot(plt_dwl, outfile_dwl, reproducible=TRUE)
}

load_data_by_leak_size_bin <- function(model_dir) {
  # This function is only used by plot_model_by_leak_size_bin, but makes some
  # big temp objects, so put it in a separate function
  model_dir <- fs::fs_path(model_dir)
  model_name <- filename_to_model_name(model_dir / "model_fit.rds")
  time_period_hr <- filename_to_time_period_hr(model_dir)
  stopifnot(model_name %in% MODEL_NAMES$cost_coef_models)
  cost_param_alpha_file <- model_dir / "cost_param_alpha.parquet"
  leak_size_draw_file <- model_dir / "leak_size_draw.parquet"
  leak_size_kg_per_hr <- arrow::read_parquet(leak_size_draw_file)

  leak_size_means_by_well <- leak_size_kg_per_hr %>% as.matrix() %>% rowMeans()
  size_group_ids <- tibble::tibble(
    well_id = seq_along(leak_size_means_by_well),
    size_group = factor(cut(leak_size_means_by_well, breaks=5, labels=FALSE)),
  )
  leak_size_long <- leak_size_kg_per_hr %>%
    dplyr::mutate(well_id = dplyr::row_number()) %>%
    # Melt is a few seconds faster than pivot_longer here
    data.table::as.data.table() %>%
    data.table::melt(id.vars="well_id", variable.name="draw_id", value.name="leak_size")
  data.table::setkey(leak_size_long, well_id, draw_id)
  cost_param_alpha_long <- cost_param_alpha_file %>%
    arrow::read_parquet() %>%
    dplyr::mutate(well_id = dplyr::row_number()) %>%
    data.table::as.data.table() %>%
    data.table::melt(id.vars="well_id", variable.name="draw_id", value.name="alpha")
  data.table::setkey(cost_param_alpha_long, well_id, draw_id)
  to_plot <- merge(leak_size_long, cost_param_alpha_long, by=c("well_id", "draw_id"), all=TRUE) %>%
    merge(size_group_ids, by="well_id", all=TRUE) %>%
    dplyr::as_tibble() %>%
    dplyr::select(-well_id)
  stopifnot(noNAs(to_plot))

  return(to_plot)
}

plot_model_by_leak_size_bin <- function(model_dir, outfile_alpha, outfile_leak) {
  # Code to generate:
  # "graphics/model_cost_alpha_by_leak_size_bin_plot.pdf",
  # "graphics/model_leak_size_by_leak_size_bin_plot.pdf",
  # Currently not run because it's slow and I went with a simpler graph.
  stopifnot(
    length(model_dir) == 1,
    length(outfile_alpha) == 1,
    length(outfile_leak) == 1
  )
  # Divide wells into quintiles of (mean) leak size
  # Then plot:
  # 1. ridgeplots of leak size (showing variation both across wells and draws within group)
  # 2. ridgeplots of alpha
  to_plot <- load_data_by_leak_size_bin(model_dir)

  plt_alpha <- ggplot2::ggplot(to_plot,
    ggplot2::aes(
      x = -winsorize(alpha, c(0.01, 0)),
      y = size_group,
      height = stat(density)
    )) +
    ggridges::geom_density_ridges(stat = "binline", bins = 100, scale = 0.95, draw_baseline = FALSE) +
    ggplot2::theme_bw() +
    ggplot2::labs(
      title = "Abatement elasticity by leak size quintile",
      x = "Marginal cost elasticity (−α)",
      y = "Leak size quintile"
    )
  plt_leak_size <- ggplot2::ggplot(to_plot, ggplot2::aes(x=leak_size, y=size_group, height = stat(density))) +
    ggridges::geom_density_ridges(stat = "binline", bins = 100, scale = 0.95, draw_baseline = FALSE) +
    ggplot2::theme_bw() +
    ggplot2::scale_x_log10() +
    ggplot2::labs(
      title = "Leak size distribution by quintile of leak size mean",
      x = "Leak size (kg/hr)",
      y = "Leak size quintile"
    )

  save_plot(plt_alpha, outfile_alpha, reproducible=TRUE)
  save_plot(plt_leak_size, outfile_leak, reproducible=TRUE)
}


plot_alpha_histogram <- function(model_dir, outfile_alpha) {
  stopifnot(
    length(model_dir) == 1,
    length(outfile_alpha) == 1
  )
  cost_param_alpha_file <- fs::fs_path(model_dir) / "cost_param_alpha.parquet"
  cost_param_alpha_mean_by_well <- cost_param_alpha_file %>%
    arrow::read_parquet() %>%
    as.matrix() %>%
    rowMeans()
  plt_alpha <- tibble::tibble(alpha = cost_param_alpha_mean_by_well) %>%
    ggplot2::ggplot(ggplot2::aes(
      x = -winsorize(alpha, c(0.01, 0)),
      y = ..density..
    )) +
    ggplot2::geom_histogram(bins=100) +
    ggplot2::theme_bw() +
    ggplot2::labs(
      title = "Abatement elasticity",
      subtitle = "Mean across draws for each well",
      x = "Marginal cost elasticity (−α)",
      y = "Density"
    )
  save_plot(plt_alpha, outfile_alpha, reproducible=TRUE)
}

plot_prob_size_above_threshold_histogram <- function(model_dir, outfile) {
  on.exit(options(warn=2))
  options(warn=1)
  stopifnot(
    length(model_dir) == 1,
    length(outfile) == 1
  )
  pq_file <- fs::fs_path(model_dir) / "prob_size_above_threshold.parquet"
  mean_by_well <- pq_file %>%
    arrow::read_parquet() %>%
    as.matrix() %>%
    rowMeans()
  plt <- tibble::tibble(prob_size_above_threshold = mean_by_well) %>%
    ggplot2::ggplot(ggplot2::aes(
      x = prob_size_above_threshold,
      y = ..density..
    )) +
    ggplot2::geom_histogram(bins=100) +
    ggplot2::theme_bw() +
    ggplot2::xlim(0, 1) +
    ggplot2::labs(
      title = "Probability leak size is greater than detection threshold",
      x = "Pr(e > 100 | X)",
      y = "Density"
    )
  save_plot(plt, outfile, reproducible=TRUE)
}


model_to_plot <- "08_twopart_lognormal_heterog_alpha"
model_dir_to_plot <- snakemake@input$distribution_fits[grepl(model_to_plot, snakemake@input$distribution_fits)] %>%
  dirname() %>%
  fs::fs_path()
stopifnot(length(model_dir_to_plot) == 1)

# plot_model_by_leak_size_bin(
#   model_dir_to_plot,
#   snakemake@output$model_cost_alpha_by_leak_size_bin_plot,
#   snakemake@output$model_leak_size_by_leak_size_bin_plot
# )

plot_prob_size_above_threshold_histogram(
  model_dir_to_plot,
  snakemake@output$model_prob_size_above_threshold_histogram
)

plot_alpha_histogram(
  model_dir_to_plot,
  snakemake@output$model_cost_alpha_histogram
)

plot_model_cost_param(
  model_dir_to_plot,
  snakemake@output$model_cost_vs_q_plot,
  snakemake@output$model_cost_vs_q_dwl_plot
)

plot_model_prob_leak(
  model_dir_to_plot / "prob_leak.parquet",
  outfile = snakemake@output$model_prob_leak_plot
)

write_coefs_table(snakemake, "obs leak")
write_coefs_table(snakemake, "leak size")
library(ggplot2)
library(here)
options(warn=2)

source(