QCSD: A QUIC Client-Side Website-Fingerprinting Defence Framework


This repository contains the experiment and evaluation code for the paper "QCSD: A QUIC Client-Side Website-Fingerprinting Defence Framework" (USENIX Security 2022). The Rust code for the QCSD library and test clients can be found at https://github.com/jpcsmith/neqo-qcsd.

Software Requirements

  • Ubuntu 20.04 with bash: All code was tested on a fresh installation of Ubuntu 20.04.

  • git, git-lfs: Used to clone the code repository and install Python packages.

  • Python 3.8 with virtual envs: Used to create a Python 3.8 virtual environment to run the evaluation and collection scripts. Install with sudo apt-get install python3.8 python3.8-venv python3-venv.

  • docker >= 20.10 (sudo-less): Used to isolate simultaneous runs of browsers and collection scripts, as well as to allow multiple Wireguard clients. The current non-root user must be able to manage containers (see Docker's installation and post-installation steps).

  • tcpdump >= 4.9.3 (sudo-less): Used to capture traffic traces.

  • rust (rustc, cargo) == 1.51: Used to compile the QCSD library and test clients, which are written in Rust.

  • Others: The following packages are additionally required to build the QCSD library and test clients, and can be installed with the Ubuntu package manager, apt.

    sudo apt-get install build-essential mercurial gyp ninja-build libz-dev clang tshark texlive-xetex
    

Getting Started

1. Clone the repository

# Clone the repository
git clone https://github.com/jpcsmith/qcsd-experiments.git
# Change to the code directory
cd qcsd-experiments
# Download resources/alexa-1m-2021-07-18.csv.gz
git lfs pull

2. Install required Python packages

# Create and activate a virtual environment
python3.8 -m venv env
source env/bin/activate
# Ensure that pip and wheel are the latest version
python -m pip install --upgrade pip wheel
# Install the requirements using pip
python -m pip install --no-cache-dir -r requirements.txt

3. Setup

The experiments can be run either locally or distributed across multiple machines:

  • The file ansible/distributed contains an example of the configuration required for running the experiments distributed on multiple hosts.

  • The file ansible/local contains the configuration for running the experiments locally, and is used in these instructions.

Perform the following steps:

  1. Set the gateway_ip variable in ansible/local to the non-loopback IP address of the host, for example, the LAN IP address.

  2. Change the exp_path variable to a path on the (local) filesystem; it can be the same path to which the repository was cloned. (An illustrative sketch of this inventory file appears after these steps.)

  3. Run the following command, which will:

    ansible-playbook -i ansible/local ansible/setup.yml
    
    • set up the Docker image used for creating the web-page graphs with Chromium;

    • create, start, and test the Docker images for the Wireguard gateways and clients; and

    • download and build the QCSD library and test clients.

    The QCSD source code is cloned on the remote host into the third-party/ directory of the folder identified by the exp_path variable in the hosts file (ansible/local or ansible/distributed).
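
    For reference, below is a minimal sketch of what the ansible/local inventory might contain. Only gateway_ip and exp_path come from the steps above; the group name and connection settings are illustrative assumptions, so defer to the file shipped with the repository.

        # ansible/local -- illustrative sketch only; the repository's file is authoritative
        [experiment_hosts]
        localhost ansible_connection=local

        [experiment_hosts:vars]
        # Non-loopback IP of this host, e.g. its LAN address
        gateway_ip=192.168.1.10
        # Path for experiment data; the QCSD code is cloned into third-party/ under this path
        exp_path=/home/user/qcsd-experiments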

Running Experiments

Ensure that the environment is set up before running the experiments.

# Activate the virtual environment if not already active
source env/bin/activate
# Set the NEQO_BIN, NEQO_BIN_MP, and LD_PATH environment vars
source env_vars
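
As a quick, illustrative sanity check, confirm that the variables were set before launching a run (their exact values depend on the env_vars file and on where the QCSD clients were built):

# Each of these should print a non-empty value
echo "NEQO_BIN=$NEQO_BIN"
echo "NEQO_BIN_MP=$NEQO_BIN_MP"
echo "LD_PATH=$LD_PATH"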

Overview

The results and plots in the paper were produced using snakemake. Like GNU make, snakemake will run all dependent rules necessary to build the final target. The general syntax is

snakemake -j <cores> --configfile=<filename> <rulename>

where <filename> is either config/test.yaml or config/final.yaml, and <rulename> is the name of one of the snakemake rules found in workflow/rules/ or a target filename. The config file can also be set in workflow/Snakefile to avoid specifying it repeatedly on the command line.
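
For example, to build the FRONT and Tamaraw shaping-evaluation plots (rule shaping_eval__all in the table below) on four cores with the test configuration:

snakemake -j 4 --configfile=config/test.yaml shaping_eval__all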

Mapping of Figures to Snakemake Rules

The table below details the figures and tables in the paper and the rule used to produce them. The listed output files can be found in the results/ directory.

| Section | Figure | Rule name | Output file |
| --- | --- | --- | --- |
| 5. Shaping Case Studies: FRONT & Tamaraw | Figure 3 | shaping_eval__all | plots/shaping-eval-front.png, plots/shaping-eval-tamaraw.png |
| | Table 2 | overhead_eval__table | tables/overhead-eval.tex |
| 6.1. Defending Single Connections | Figure 4 | ml_eval_conn__all | plots/ml-eval-conn-tamaraw.png, plots/ml-eval-conn-front.png |
| 6.2. Defending Full Web-Page Loads | Figure 5 | ml_eval_mconn__all | plots/ml-eval-mconn-tamaraw.png, plots/ml-eval-mconn-front.png |
| | Figure 6 | ml_eval_brows__all | plots/ml-eval-brows-front.png |
| E. Overhead in the Multi-connection Setting | Table 3 | overhead_eval_mconn__table | tables/overhead-eval-mconn.tex |
| F. Server Compliance with Shaping | Figure 8 | None. Instead see workflow/notebooks/failure-analysis.ipynb | plots/failure-rate.png |
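
Individual output files can also be requested directly by passing their path as the target. For example, assuming the outputs are written under results/ as noted above:

snakemake -j 4 --configfile=config/final.yaml results/plots/shaping-eval-front.png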

Licence

The code and associated data in this repository are released under the MIT licence found in the LICENCE file.

Code Snippets

shell:
    "curl http://s3.amazonaws.com/alexa-static/top-1m.csv.zip"
    " | zcat | gzip --stdout > {output}"
shell:
    "set +o pipefail;"
    " zcat {input} | head -n {params.topn} | gzip --stdout > {output}"
shell:
    "workflow/scripts/evaluate_tuned_kfp.py --verbose 0 --n-jobs {threads}"
    " --cv-results-path {log[cv_results]} --feature-importance {output[feature_importance]}"
    " {input} > {output[0]} 2> {log[0]}"
shell:
    "workflow/scripts/evaluate_tuned_varcnn.py --hyperparams {wildcards.hyperparams}"
    " {wildcards.feature_type} {input} > {output} 2> {log}"
From line 106 of rules/common.smk.
shell:
    "workflow/scripts/evaluate_tuned_df.py --hyperparams {wildcards.hyperparams}"
    " {input} > {output} 2> {log}"
From line 122 of rules/common.smk.
shell:
    "workflow/scripts/extract_kfp_features.py {input} > {output} 2> {log}"
From line 136 of rules/common.smk.
run:
    combine_varcnn_predictions(input, output)
script:
    "../scripts/create_dataset.py"
script:
    "../scripts/create_dataset.py"
script:
    "../scripts/run_browser_collection.py"
run:
    combine_varcnn_predictions(input, output)
script:
    "../scripts/create_dataset.py"
script:
    "../scripts/create_dataset.py"
script:
    "../scripts/run_collection.py"
script:
    "../scripts/create_dataset.py"
run:
    combine_varcnn_predictions(input, output)
script:
    "../scripts/create_dataset.py"
shell:
    "workflow/scripts/run_collectionv2.py --configfile {params.configfile}"
    " --n-monitored {params.n_monitored} --n-instances {params.n_instances}"
    " --n-unmonitored {params.n_unmonitored} --max-failures {params.max_failures}"
    " --timeout {params.timeout} --use-multiple-connections"
    " {input} {output} -- {params.neqo_args} 2> {log}"
script:
    "../scripts/overhead_to_latex.py"
shell:
    "workflow/scripts/calculate_overhead.py {wildcards.defence} {input}"
    " --n-jobs {threads} --tamaraw-config '{params.tamaraw_config}'"
    " > {output} 2> {log}"
script:
    "../scripts/run_paired_collection.py"
shell:
  "workflow/scripts/calculate_overhead.py --tamaraw-config '{params.tamaraw_config}'"
  " --n-jobs {threads} {wildcards.defence} {input} > {output} 2> {log}"
script:
    "../scripts/overhead_to_latex.py"
script:
    "../scripts/run_paired_collection.py"
script:
    "../scripts/calculate_score.py"
script:
    "../scripts/version_scan.py"
script:
    "../scripts/filter_versions.py"
shell:
    "set +o pipefail;"
    " tail -n +{params.start} {input} | head -n {params.batch_size} > {output}"
shell:
    "workflow/scripts/docker-dep-fetch --ranks --max-attempts 1 --force-quic-on-all"
    " {input} 2> {log} | gzip --stdout > {output}"
shell:
    "mkdir -p {output}"
    " && python3 workflow/scripts/url_dependency_graph.py --no-origin-filter"
    " '{output}/' {input} 2> {log}"
shell:
    "set +o pipefail;"
    " tail -n +{params.start} {input} | head -n {params.batch_size} > {output}"
shell:
    "workflow/scripts/docker-dep-fetch --ranks --max-attempts 1 {input} 2> {log}"
    " | gzip --stdout > {output}"
shell:
    "mkdir -p {output}"
    " && python3 workflow/scripts/url_dependency_graph.py '{output}/' {input} 2> {log}"
import json
import logging
import itertools
import functools
from pathlib import Path
import multiprocessing
import multiprocessing.pool
from typing import Optional

import numpy as np
import pandas as pd
from lab import tracev2
from lab.defences import tamaraw

import common
from common.doceasy import doceasy, Use, Or

_LOGGER = logging.getLogger(__name__)


def main(
    input_dir: Path,
    defence: str,
    tamaraw_config: Optional[str],
    n_jobs: int
):
    """Calculate the overhead using multiple processes."""
    common.init_logging()
    _LOGGER.info("Using parameters: %s", locals())

    if defence == "tamaraw":
        if tamaraw_config is None:
            raise ValueError("Tamaraw configuration required.")
        tamaraw_config = json.loads(tamaraw_config)

    assert input_dir.is_dir(), f"invalid path {input_dir}"
    directories = sorted(
        [x.parent for x in Path(input_dir).glob("**/defended/")]
    )
    _LOGGER.info("Found %d samples", len(directories))

    func = functools.partial(
        _calculate_overhead, defence=defence, tamaraw_config=tamaraw_config)
    if n_jobs > 1:
        chunksize = max(len(directories) // (n_jobs * 2), 1)
        with multiprocessing.pool.Pool(n_jobs) as pool:
            scores = list(
                pool.imap_unordered(func, directories, chunksize=chunksize)
            )
    else:
        # Run in the main process
        scores = [func(x) for x in directories]
    _LOGGER.info("Overhead calculation complete")

    results = pd.DataFrame.from_records(itertools.chain.from_iterable(scores))
    print(results.to_csv(header=True, index=False), end="")


def _parse_duration(path):
    """Return the time taken to download all of the application's HTTP
    resources in ms, irrespective of any additional padding or traffic,
    as logged by the run.
    """
    tag = "[FlowShaper] Application complete after "  # xxx ms
    found = None
    with (path / "stdout.txt").open(mode="r") as stdout:
        found = [line for line in stdout if line.startswith(tag)][-1]
    assert found, f"Run never completed! {path}"

    # Parse the next word as an integer
    return int(found[len(tag):].split()[0])


def _calculate_overhead(dir_, *, defence: str, tamaraw_config):
    try:
        control = np.sort(tracev2.from_csv(dir_ / "undefended" / "trace.csv"))
        defended = np.sort(tracev2.from_csv(dir_ / "defended" / "trace.csv"))
        schedule = np.sort(tracev2.from_csv(dir_ / "defended" / "schedule.csv"))
    except Exception as err:
        raise ValueError(f"Error loading files in {dir_}") from err

    undefended_size = np.sum(np.abs(control["size"]))
    defended_size = np.sum(np.abs(defended["size"]))
    simulated_size = np.sum(np.abs(schedule["size"]))
    simulated_size_alt = None
    # Add the undefended size, since padding-only defences list only the
    # padding in the schedule.
    if defence == "front":
        simulated_size += undefended_size

    if defence == "tamaraw":
        tamaraw_trace = tamaraw.simulate(
            control,
            packet_size=tamaraw_config["packet_size"],
            rate_in=tamaraw_config["rate_in"] / 1000,
            rate_out=tamaraw_config["rate_out"] / 1000,
            pad_multiple=tamaraw_config["packet_multiple"],
        )
        simulated_size_alt = np.sum(np.abs(tamaraw_trace["size"]))

    assert control["time"][0] == 0, "trace must start at 0s"
    # The trace should also already be sorted and in seconds
    undefended_ms = int(control["time"][-1] * 1000)
    defended_ms = _parse_duration(dir_ / "defended")
    if defence == "front":
        simulated_ms = undefended_ms
    elif defence == "tamaraw":
        unpadded_tamaraw = tamaraw.simulate(
            control,
            packet_size=tamaraw_config["packet_size"],
            rate_in=tamaraw_config["rate_in"] / 1000,
            rate_out=tamaraw_config["rate_out"] / 1000,
            pad_multiple=1,
        )
        simulated_ms = int(unpadded_tamaraw["time"][-1] * 1000)
    else:
        raise ValueError(f"Unsupported defence: {defence}")

    return [
        {
            "sample": str(dir_),
            "overhead": "bandwidth",
            "setting": "collected",
            "value": (defended_size - undefended_size) / undefended_size
        },
        {
            "sample": str(dir_),
            "overhead": "bandwidth",
            "setting": "simulated",
            "value": (simulated_size - undefended_size) / undefended_size
        },
        {
            "sample": str(dir_),
            "overhead": "bandwidth",
            "setting": "simulated-alt",
            "value": ((simulated_size_alt - undefended_size) / undefended_size
                      if simulated_size_alt is not None else None)
        },
        {
            "sample": str(dir_),
            "overhead": "latency",
            "setting": "collected",
            "value": (defended_ms - undefended_ms) / undefended_ms
        },
        {
            "sample": str(dir_),
            "overhead": "latency",
            "setting": "simulated",
            "value": (simulated_ms - undefended_ms) / undefended_ms
        },
    ]


if __name__ == "__main__":
    main(**doceasy(__doc__, {
        "DEFENCE": Or("tamaraw", "front"),
        "INPUT_DIR": Use(Path),
        "--n-jobs": Use(int),
        "--tamaraw-config": Or(None, str),
    }))
import logging
import itertools
import functools
from pathlib import Path
import multiprocessing
import multiprocessing.pool
from typing import Sequence, Dict, Optional

import numpy as np
import pandas as pd
import tslearn.metrics
from scipy.stats import pearsonr
from scipy.spatial.distance import euclidean

import common
from common import timeseries

_LOGGER = logging.getLogger(__name__)


def main(
    input_,
    output,
    *,
    defence: str,
    ts_offset: Dict[str, int],
    resample_rates: Sequence["str"],
    lcss_eps: int,
    filter_below: Sequence[int] = (0,),
    jobs: Optional[int] = None,
):
    """Score how close a defended time series is to the theoretical."""
    common.init_logging()
    _LOGGER.info("Using parameters: %s", locals())

    directories = sorted([x.parent for x in Path(input_).glob("**/defended/")])
    _LOGGER.info("Found %d samples", len(directories))

    jobs = jobs or (multiprocessing.cpu_count() or 4)
    func = functools.partial(
        _calculate_score,
        defence=defence,
        ts_offset=ts_offset,
        resample_rates=resample_rates,
        filter_below=filter_below,
        lcss_eps=lcss_eps,
    )

    if jobs > 1:
        chunksize = max(len(directories) // (jobs * 2), 1)
        with multiprocessing.pool.Pool(jobs) as pool:
            scores = list(pool.imap_unordered(func, directories, chunksize=chunksize))
    else:
        # Run in the main process
        scores = list(map(func, directories))
    _LOGGER.info("Score calculation complete")

    pd.DataFrame.from_records(itertools.chain.from_iterable(scores)).to_csv(
        output, header=True, index=False
    )


def _calculate_score(
    dir_,
    *,
    defence: str,
    ts_offset,
    resample_rates,
    filter_below,
    lcss_eps,
):
    """Score how close a padding-only defended time series is to
    padding schedule.
    """
    assert defence in ("front", "tamaraw")
    pad_only = defence == "front"

    schedule_ts = timeseries.from_csv(dir_ / "defended" / "schedule.csv")
    defended_ts = timeseries.from_csv(dir_ / "defended" / "trace.csv")
    control_ts = timeseries.from_csv(dir_ / "undefended" / "trace.csv")

    offsets = range(ts_offset["min"], ts_offset["max"], ts_offset["inc"])
    simulated_ts = simulate_with_lag(
        control_ts, schedule_ts, defended_ts, offsets, pad_only=pad_only
    )

    results = []
    for rate, direction, min_pkt_size in itertools.product(
        resample_rates, ("in", "out"), filter_below
    ):
        series = pd.DataFrame(
            {
                "a": timeseries.resample(
                    _filter(defended_ts[direction], min_pkt_size), rate
                ),
                "b": timeseries.resample(
                    _filter(simulated_ts[direction], min_pkt_size), rate
                ),
                "c": timeseries.resample(
                    _filter(control_ts[direction], min_pkt_size), rate
                ),
            }
        ).fillna(0)

        _LOGGER.debug("Series length at rate %s: %d", rate, len(series))
        _LOGGER.debug("Series summary at rate %s: %s", rate, series.describe())

        if len(series["a"]) > 30_000:
            _LOGGER.warning(
                "Refusing to allocate >= 8 GB for sample: %s, %s, %s",
                dir_,
                rate,
                direction,
            )
            continue

        results.append(
            {
                "sample": str(dir_),  # Sample name
                "rate": rate,
                "dir": direction,
                "min_pkt_size": min_pkt_size,
                "pearsonr": pearsonr(series["a"], series["b"])[0],
                "lcss": tslearn.metrics.lcss(series["a"], series["b"], eps=lcss_eps),
            }
        )
    return results


def _filter(column, below):
    assert column.min() >= 0
    return column[column >= below]


def simulate_with_lag(
    control,
    schedule,
    defended,
    offsets,
    pad_only: bool,
    rate="5ms",
):
    """Find the best offset such that the simulated trace formed by
    combining the control and schedule where the schedule is lagged
    the offset, has the lowest euclidean distance to the defended trace.

    When pad_only is False, the schedule is taken as the already
    simulated trace
    """
    defended_in = timeseries.resample(defended["in"], rate)
    simulated_out = (
        control["out"].append(schedule["out"]) if pad_only else schedule["out"]
    )

    best_offset = 0
    best_distance = np.inf
    best_simulated = None

    for offset in offsets:
        shifted_schedule = pd.Series(
            schedule["in"].values,
            index=(schedule["in"].index + pd.Timedelta(f"{offset}ms")),
        )
        simulated_in = (
            control["in"].append(shifted_schedule) if pad_only else shifted_schedule
        )

        # Put together in a dataframe to ensure they have the same length and
        # indices
        frame = pd.DataFrame(
            {
                "defended": defended_in,
                "simulated": timeseries.resample(simulated_in, rate),
            }
        ).fillna(0)

        distance = euclidean(frame["defended"], frame["simulated"])
        if distance < best_distance:
            best_offset = offset
            best_distance = distance
            best_simulated = simulated_in

    assert best_simulated is not None

    _LOGGER.debug("Using an incoming offset of %d ms", best_offset)
    return pd.DataFrame(
        {
            # Take the sum of any given time instance to handle rare duplicates
            "in": best_simulated.groupby("time").sum(),
            "out": simulated_out.groupby("time").sum(),
        }
    ).fillna(0)


if __name__ == "__main__":
    main(
        str(snakemake.input[0]),
        str(snakemake.output[0]),
        **snakemake.params,
        jobs=snakemake.threads,
    )
import logging
from pathlib import Path
from itertools import product, islice
from typing import Optional, Dict
import h5py
import numpy as np
import lab.tracev2 as trace
from lab.defences import front
import common

LABELS_DTYPE = np.dtype([("class", "i4"), ("region", "i4"), ("sample", "i4")])


def main(
    input_: str,
    output: str,
    *,
    n_monitored: int,
    n_instances: int,
    n_unmonitored: int,
    n_regions: int,
    simulate: str = "",
    simulate_kws: Optional[Dict] = None,
):
    """Create an HDF5 dataset at output using the files found in the
    directory input_.  The CSV traces should be found under

        <input_>/<page_id>/<region>_<sample>/

    """
    common.init_logging()
    rng = None
    if simulate_kws and "seed" in simulate_kws:
        rng = np.random.default_rng(simulate_kws["seed"])
        del simulate_kws["seed"]
    if simulate_kws and "use_empty_resources" in simulate_kws:
        del simulate_kws["use_empty_resources"]

    sample_dirs = [p for p in Path(input_).iterdir() if p.is_dir()]
    # Sort such that the dirs with the most samples come first, with ties
    # broken by lower sample ids
    sample_dirs.sort(key=lambda p: (sum(-1 for _ in p.glob("*")), int(p.stem)))

    if len(sample_dirs) < n_monitored + n_unmonitored:
        raise ValueError(f"Insufficient samples: {len(sample_dirs)}")
    sample_dirs = sample_dirs[:(n_monitored + n_unmonitored)]

    n_rows = n_monitored * n_instances + n_unmonitored
    labels = np.recarray(n_rows, dtype=LABELS_DTYPE)
    sizes = np.full(n_rows, None, dtype=object)
    timestamps = np.full(n_rows, None, dtype=object)

    index = 0
    for label, sample_path in enumerate(sample_dirs[:n_monitored]):
        # Iterator that takes one sample from each region before taking
        # the second sample from any region. This is akin to how we
        # generated the samples.
        sample_iter = product(range(n_instances), range(n_regions))

        # Take only the first n_instances (instance, region) pairs from the iterator
        for (id_, region_id) in islice(sample_iter, n_instances):
            index = (label * n_instances) + (id_ * n_regions) + region_id
            labels[index] = (label, region_id, id_)

            sample = _maybe_simulate(
                sample_path / f"{region_id}_{id_}", simulate, simulate_kws, rng
            )
            sizes[index] = sample["size"]
            timestamps[index] = sample["time"]

    region_index = 0
    for sample_path in sample_dirs[n_monitored:(n_monitored + n_unmonitored)]:
        paths = [(region_id, sample_path / f"{region_id}_0")
                 for region_id in range(n_regions)]
        paths = [(region_id, p) for (region_id, p) in paths if p.exists()]
        assert paths, f"no samples in directory? {sample_path}"

        # Select a path from the list, cycling among the regions in use if
        # there is more than one region, otherwise returning an index into
        # the paths list
        region_id, path = paths[(region_index % n_regions) % len(paths)]
        region_index += 1

        index = index + 1
        labels[index] = (-1, region_id, 0)

        sample = _maybe_simulate(path, simulate, simulate_kws, rng)
        sizes[index] = sample["size"]
        timestamps[index] = sample["time"]
    assert index == (n_rows - 1), "not enough samples"

    order = labels.argsort(order=["class", "region"])
    with h5py.File(output, mode="w") as h5out:
        h5out.create_dataset("/labels", dtype=LABELS_DTYPE, data=labels[order])
        h5out.create_dataset(
            "/sizes", dtype=h5py.vlen_dtype(np.dtype("i4")), data=sizes[order]
        )
        h5out.create_dataset(
            "/timestamps", dtype=h5py.vlen_dtype(np.dtype(float)),
            data=timestamps[order]
        )


def _maybe_simulate(path: Path, simulate: str, simulate_kws, rng):
    try:
        if not simulate:
            return trace.from_csv(path / "trace.csv")

        if simulate == "tamaraw":
            return trace.from_csv(path / "schedule.csv")
        if simulate == "front":
            simulate_kws = {
                key: value for (key, value) in simulate_kws.items()
                if key in [
                    "max_client_packets", "max_server_packets", "packet_size",
                    "peak_minimum", "peak_maximum", "random_state"
                ]
            }
            assert rng is not None
            assert "seed" not in simulate_kws
            assert all(kw in simulate_kws for kw in [
                "max_client_packets", "max_server_packets", "packet_size",
                "peak_minimum", "peak_maximum",
            ])

            padding = front.generate_padding(**simulate_kws, random_state=rng)
            # We should be pointing to an undefended trace.csv file
            baseline = trace.from_csv(path / "trace.csv")
            return front.simulate(baseline, padding)
        raise ValueError(f"Unrecognised simulation: {simulate}")
    except Exception:
        logging.error("Failed on sample: %s", path)
        raise


if __name__ == "__main__":
    snakemake = globals().get("snakemake", None)
    main(
        input_=str(snakemake.input[0]),
        output=str(snakemake.output[0]),
        n_regions=snakemake.config["wireguard"]["n_regions"],
        **snakemake.params
    )
import time
import logging
import dataclasses
from pathlib import Path
from typing import Optional, ClassVar, Sequence, Union

import h5py
import numpy as np
from sklearn.metrics import make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.model_selection import StratifiedKFold, train_test_split
from lab.feature_extraction.trace import ensure_non_ragged
from lab.classifiers import dfnet
from lab.metrics import rprecision_score, recall_score
import tensorflow

from common import doceasy
from common.doceasy import Use, Or, And

PAPER_EPOCHS: int = 30
MAX_RAND_SEED: int = 10_000

# Hyperparameters from the paper
DEFAULT_N_PACKETS: int = 5000
DEFAULT_LEARNING_RATE: float = 0.002
DEFAULT_EPOCHS: int = 30


def main(outfile, **kwargs):
    """Create and run the experiment and write the CSV results to outfile."""
    logging.basicConfig(
        format='[%(asctime)s] [%(levelname)s] %(name)s - %(message)s',
        level=logging.INFO
    )
    # Use autoclustering for speedup
    tensorflow.config.optimizer.set_jit("autoclustering")

    (probabilities, y_true, classes) = Experiment(**kwargs).run()

    y_pred = classes[np.argmax(probabilities, axis=1)]
    score = rf1_score(y_true, y_pred)
    logging.info("r_20 f1-score = %.4g", score)

    outfile.writerow(["y_true"] + list(classes))
    outfile.writerows(
        np.hstack((np.reshape(y_true, (-1, 1)), probabilities)))


@dataclasses.dataclass
class Experiment:
    """An experiment to evalute hyperparameter tuned DF classifier.
    """
    # The path to the dataset of sizes and times
    dataset_path: Path

    # The fraction of samples to use for testing the final model
    test_size: float = 0.1

    # Number of folds used in the stratified k-fold cross validation
    n_folds: int = 3

    # Level of debugging output from sklearn and tensorflow
    verbose: int = 0

    # Random seed for the experiment
    seed: int = 114155

    # Hyperparams to use if not "tune"
    hyperparams: Union[str, dict] = "tune"

    # Hyperparameters to search
    n_packet_parameters: Sequence[int] = (5_000, 7_500, 10_000)
    tuned_parameters: dict = dataclasses.field(default_factory=lambda: {
        "epochs": [30],
        "learning_rate": [0.002],
    })

    # Other seeds which are chosen for different operations
    seeds_: Optional[dict] = None

    logger: ClassVar = logging.getLogger("Experiment")

    def run(self):
        """Run hyperparameter tuning for the DeepFingerprinting classifier
        and return the prediction probabilities for the best chosen
        classifier.
        """
        # Generate and set random seeds
        rng = np.random.default_rng(self.seed)
        self.seeds_ = {
            "train_test": rng.integers(MAX_RAND_SEED),
            "kfold_shuffle": rng.integers(MAX_RAND_SEED),
            "tensorflow": rng.integers(MAX_RAND_SEED),
        }
        tensorflow.random.set_seed(self.seeds_["tensorflow"])

        self.logger.info("Running %s", self)
        start = time.perf_counter()

        # Load the dataset
        X, y = self.load_dataset()
        n_classes = len(np.unique(y))
        self.logger.info("Dataset shape=%s, n_classes=%d", X.shape, n_classes)

        # Generate our training and final testing set
        x_train, x_test, y_train, y_test = train_test_split(
            X, y, test_size=self.test_size, stratify=y, shuffle=True,
            random_state=self.seeds_["train_test"]
        )

        if self.hyperparams == "tune":
            self.logger.info("Performing hyperparameter tuning ...")
            # Tune other hyperparameters and fit the final estimator
            classifier = self.tune_hyperparameters(
                x_train, y_train, n_classes=n_classes
            )
        else:
            assert isinstance(self.hyperparams, dict)
            n_packets = self.hyperparams.get("n_packets", DEFAULT_N_PACKETS)
            learning_rate = self.hyperparams.get(
                "learning_rate", DEFAULT_LEARNING_RATE
            )
            epochs = self.hyperparams.get("epochs", DEFAULT_EPOCHS)
            self.logger.info(
                "Using n_packets=%s, learning_rate=%.3g, and epochs=%d",
                n_packets, learning_rate, epochs
            )

            x_train = first_n_packets(x_train, n_packets=n_packets)
            x_test = first_n_packets(x_test, n_packets=n_packets)

            classifier = dfnet.DeepFingerprintingClassifier(
                n_classes=n_classes, verbose=min(self.verbose, 1),
                n_features=n_packets, epochs=epochs, learning_rate=learning_rate
            )
            classifier.fit(x_train, y_train)

        # Predict the classes for the test set
        probabilities = classifier.predict_proba(x_test)
        self.logger.info(
            "Experiment complete in %.2fs.", (time.perf_counter() - start))

        return (probabilities, y_test, classifier.classes_)

    def tune_hyperparameters(self, x_train, y_train, *, n_classes):
        """Perform hyperparameter tuning on the learning rate."""
        assert self.seeds_ is not None, "seeds must be set"
        pipeline = Pipeline([
            ("first_n_packets", FunctionTransformer(first_n_packets)),
            ("dfnet", dfnet.DeepFingerprintingClassifier(
                n_classes=n_classes, verbose=min(self.verbose, 1),
            ))
        ])
        param_grid = [
            {
                "first_n_packets__kw_args": [{"n_packets": n_packets}],
                "dfnet__n_features": [n_packets],
                **{
                    f"dfnet__{key}": values
                    for key, values in self.tuned_parameters.items()
                }
            }
            for n_packets in self.n_packet_parameters
        ]
        cross_validation = StratifiedKFold(
            self.n_folds, shuffle=True,
            random_state=self.seeds_["kfold_shuffle"]
        )

        grid_search = GridSearchCV(
            pipeline, param_grid, cv=cross_validation, error_score="raise",
            scoring=make_scorer(rf1_score), verbose=self.verbose, refit=True,
        )
        grid_search.fit(x_train, y_train)

        self.logger.info("hyperparameter results = %s", grid_search.cv_results_)
        self.logger.info("hyperparameter best = %s", grid_search.best_params_)

        return grid_search.best_estimator_

    def load_dataset(self):
        """Load the features and classes from the dataset. Slices the
        packet features to the maximum evaluated.
        """
        max_packets = max(self.n_packet_parameters)
        with h5py.File(self.dataset_path, mode="r") as h5in:
            features = ensure_non_ragged(h5in["sizes"], dimension=max_packets)
            classes = np.asarray(h5in["labels"]["class"][:])
        return features, classes


def first_n_packets(features, *, n_packets: int):
    """Return the first n_packets packets along with the meta features."""
    return features[:, :n_packets]


def rf1_score(y_true, y_pred, *, negative_class=-1, ratio=20):
    """Compute the F1-score using the r-precisions with the specified ratio
    and recall.
    """
    precision = rprecision_score(
        y_true, y_pred, negative_class=negative_class, ratio=ratio,
        # If we're dividing by zero it means there were no true positives and
        # thus recall will be zero and the F1 score below will be zero.
        zero_division=1
    )
    recall = recall_score(y_true, y_pred, negative_class=negative_class)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    main(**doceasy.doceasy(__doc__, {
        "OUTFILE": doceasy.CsvFile(mode="w", default="-"),
        "DATASET_PATH": Use(Path),
        "--verbose": Use(int),
        "--hyperparams": Or("tune", And(doceasy.Mapping(), {
            doceasy.Optional("n_packets"): Use(int),
            doceasy.Optional("epochs"): Use(int),
            doceasy.Optional("learning_rate"): Use(float),
        }))
    }))
import time
import logging
import dataclasses
from pathlib import Path
from typing import Optional, ClassVar

import numpy as np
import pandas as pd
from sklearn.metrics import make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold, train_test_split
from lab.metrics import rprecision_score, recall_score
from lab.classifiers import kfingerprinting

from common import doceasy
from common.doceasy import Use, Or

IDEAL_N_JOBS_KFP: int = 4
MAX_RAND_SEED: int = 10_000


def main(outfile, **kwargs):
    """Create and run the experiment and write the CSV results to outfile."""
    logging.basicConfig(
        format='[%(asctime)s] [%(levelname)s] %(name)s - %(message)s',
        level=logging.INFO
    )

    (probabilities, y_true, classes) = Experiment(**kwargs).run()

    y_pred = classes[np.argmax(probabilities, axis=1)]
    score = rf1_score(y_true, y_pred)
    logging.info("r_20 f1-score = %.4g", score)

    outfile.writerow(["y_true"] + list(classes))
    outfile.writerows(
        np.hstack((np.reshape(y_true, (-1, 1)), probabilities)))


@dataclasses.dataclass
class Experiment:
    """An experiment to evalute hyperparameter tuned k-FP classifier.
    """
    # The path to the dataset of features
    features_path: Path

    # The output path for the cross-validation results
    cv_results_path: Optional[Path]

    # The output path for feature importances
    feature_importance: Optional[Path]

    # The fraction of samples to use for testing the final model
    test_size: float = 0.2

    # Number of folds used in the stratified k-fold cross validation
    n_folds: int = 3

    # Level of debugging output from sklearn and tensorflow
    verbose: int = 0

    # Total number of jobs to use in both the gridsearch and RF classifier
    n_jobs: int = 1

    # Random seed for the experiment
    seed: int = 3410

    # Hyperparameters to search
    tuned_parameters: dict = dataclasses.field(default_factory=lambda: {
        "n_neighbours": [2, 3, 6],
        "forest__n_estimators": [100, 150, 200, 250],
        "forest__max_features": ["sqrt", "log2", 20, 30],
        "forest__oob_score": [True, False],
        "forest__max_samples": [None, 0.5, 0.75, 0.9],
    })

    # Other seeds which are chosen for different operations
    seeds_: Optional[dict] = None

    # Derived number of jobs
    n_jobs_: Optional[dict] = None

    logger: ClassVar = logging.getLogger("Experiment")

    def run(self):
        """Run hyperparameter tuning for the VarCNN classifier and return
        the prediction probabilities for the best chosen classifier.
        """
        # Generate and set random seeds
        rng = np.random.default_rng(self.seed)
        self.seeds_ = {
            "train_test": rng.integers(MAX_RAND_SEED),
            "kfold_shuffle": rng.integers(MAX_RAND_SEED),
            "kfp": rng.integers(MAX_RAND_SEED),
        }

        self.logger.info("Running %s", self)
        start = time.perf_counter()

        # Load the dataset
        X, y = self.load_dataset()
        self.logger.info(
            "Dataset shape=%s, n_classes=%d", X.shape, len(np.unique(y))
        )

        # Generate our training and final testing set
        x_train, x_test, y_train, y_test = train_test_split(
            X, y, test_size=self.test_size, stratify=y, shuffle=True,
            random_state=self.seeds_["train_test"]
        )

        # Tune other hyperparameters and fit the final estimator
        (results, classifier) = self.tune_hyperparameters(x_train, y_train)
        if self.cv_results_path is not None:
            pd.DataFrame(results).to_csv(
                self.cv_results_path, header=True, index=False
            )

        if self.feature_importance is not None:
            pd.DataFrame({
                "feature": kfingerprinting.ALL_DEFAULT_FEATURES,
                "weight": classifier.forest_.feature_importances_
            }).to_csv(self.feature_importance, header=True, index=False)

        # Predict the classes for the test set
        probabilities = classifier.predict_proba(x_test)
        self.logger.info(
            "Experiment complete in %.2fs.", (time.perf_counter() - start))

        return (probabilities, y_test, classifier.classes_)

    def tune_hyperparameters(self, x_train, y_train):
        """Perform hyperparameter tuning."""
        assert self.seeds_ is not None, "seeds must be set"

        # Determine the number of jobs to use for the grid search and for kFP
        self.n_jobs_ = {"tune": 1, "kfp": self.n_jobs}
        if self.n_jobs > IDEAL_N_JOBS_KFP:
            self.n_jobs_["tune"] = max(1, self.n_jobs // IDEAL_N_JOBS_KFP)
            self.n_jobs_["kfp"] = IDEAL_N_JOBS_KFP
        self.logger.info("Using jobs: %s", self.n_jobs_)

        estimator = kfingerprinting.KFingerprintingClassifier(
            unknown_label=-1, random_state=self.seeds_["kfp"],
            n_jobs=self.n_jobs_["kfp"]
        )
        cross_validation = StratifiedKFold(
            self.n_folds, shuffle=True,
            random_state=self.seeds_["kfold_shuffle"]
        )

        grid_search = GridSearchCV(
            estimator, self.tuned_parameters, cv=cross_validation,
            error_score="raise", scoring=make_scorer(rf1_score),
            verbose=self.verbose, refit=True, n_jobs=self.n_jobs_["tune"],
        )
        grid_search.fit(x_train, y_train)

        self.logger.info("hyperparameter best = %s", grid_search.best_params_)

        return (grid_search.cv_results_, grid_search.best_estimator_)

    def load_dataset(self):
        """Load the features and classes from the dataset."""
        frame = pd.read_csv(self.features_path)
        assert frame.columns.get_loc("y_true") == 0, "y_true not first column?"

        classes = frame.iloc[:, 0].to_numpy()
        features = frame.iloc[:, 1:].to_numpy()

        return features, classes


def rf1_score(y_true, y_pred, *, negative_class=-1, ratio=20):
    """Compute the F1-score using the r-precisions with the specified ratio
    and recall.
    """
    precision = rprecision_score(
        y_true, y_pred, negative_class=negative_class, ratio=ratio,
        # If we're dividing by zero it means there were no true positives and
        # thus recall will be zero and the F1 score below will be zero.
        zero_division=1
    )
    recall = recall_score(y_true, y_pred, negative_class=negative_class)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    main(**doceasy.doceasy(__doc__, {
        "OUTFILE": doceasy.CsvFile(mode="w", default="-"),
        "FEATURES_PATH": Use(Path),
        "--cv-results-path": Or(None, Use(Path)),
        "--feature-importance": Or(None, Use(Path)),
        "--n-jobs": Use(int),
        "--verbose": Use(int),
    }))
import time
import logging
import dataclasses
from pathlib import Path
from typing import Optional, ClassVar, Sequence, Union

import h5py
import numpy as np
from sklearn.metrics import make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.model_selection import StratifiedKFold, train_test_split
from lab.feature_extraction.trace import (
    ensure_non_ragged, extract_metadata, Metadata, extract_interarrival_times
)
from lab.classifiers import varcnn
from lab.metrics import rprecision_score, recall_score
import tensorflow

from common import doceasy
from common.doceasy import Use, Or, And

MAX_RAND_SEED: int = 10_000
N_META_FEATURES: int = 12
MAX_EPOCHS: int = 150
FEATURE_EXTRACTION_BATCH_SIZE: int = 5000

# Hyperparameters from the paper
DEFAULT_N_PACKETS: int = 5000
DEFAULT_LEARNING_RATE: float = 0.001
DEFAULT_LR_DECAY: float = np.sqrt(0.1)


def main(outfile, **kwargs):
    """Create and run the experiment and write the CSV results to outfile."""
    logging.basicConfig(
        format='[%(asctime)s] [%(levelname)s] %(name)s - %(message)s',
        level=logging.INFO
    )
    # Use autoclustering for speedup
    tensorflow.config.optimizer.set_jit("autoclustering")

    (probabilities, y_true, classes) = Experiment(**kwargs).run()

    y_pred = classes[np.argmax(probabilities, axis=1)]
    score = rf1_score(y_true, y_pred)
    logging.info("r_20 f1-score = %.4g", score)

    outfile.writerow(["y_true"] + list(classes))
    outfile.writerows(
        np.hstack((np.reshape(y_true, (-1, 1)), probabilities)))


@dataclasses.dataclass
class Experiment:
    """An experiment to evalute hyperparameter tuned Var-CNN classifier.
    """
    # The path to the dataset of sizes and times
    dataset_path: Path

    # The features to use, either "time" or "sizes"
    feature_type: str

    # The fraction of samples to use for testing the final model
    test_size: float = 0.2

    validation_split: float = 0.1

    # Number of folds used in the stratified k-fold cross validation
    n_folds: int = 3

    # Level of debugging output from sklearn and tensorflow
    verbose: int = 0

    # Random seed for the experiment
    seed: int = 7121

    # Hyperparams to use if not "tune"
    hyperparams: Union[str, dict] = "tune"

    # Hyperparameters to search
    n_packet_parameters: Sequence[int] = (5_000, 7_500, 10_000)
    learning_rate_parameters: Sequence[float] = (0.001,)
    lr_decay_parameters: Sequence[float] = (np.sqrt(0.1),)
    # lr_decay_parameters: Sequence[float] = (0.45, np.sqrt(0.1), 0.15)

    # Other seeds which are chosen for different operations
    seeds_: Optional[dict] = None

    logger: ClassVar = logging.getLogger("Experiment")

    def run(self):
        """Run hyperparameter tuning for the VarCNN classifier and return
        the prediction probabilities for the best chosen classifier.
        """
        # Generate and set random seeds
        rng = np.random.default_rng(self.seed)
        self.seeds_ = {
            "train_test": rng.integers(MAX_RAND_SEED),
            "kfold_shuffle": rng.integers(MAX_RAND_SEED),
            "tensorflow": rng.integers(MAX_RAND_SEED),
            "train_val": rng.integers(MAX_RAND_SEED),
        }
        tensorflow.random.set_seed(self.seeds_["tensorflow"])

        self.logger.info("Running %s", self)
        start = time.perf_counter()

        # Load the dataset
        X, y = self.load_dataset()
        n_classes = len(np.unique(y))
        self.logger.info("Dataset shape=%s, n_classes=%d", X.shape, n_classes)

        # Generate our training and final testing set
        x_train, x_test, y_train, y_test = train_test_split(
            X, y, test_size=self.test_size, stratify=y, shuffle=True,
            random_state=self.seeds_["train_test"]
        )

        if self.hyperparams == "tune":
            self.logger.info("Performing hyperparameter tuning ...")
            # Tune other hyperparameters and fit the final estimator
            classifier = self.tune_hyperparameters(x_train, y_train)
        else:
            assert isinstance(self.hyperparams, dict)
            n_packets = self.hyperparams.get("n_packets", DEFAULT_N_PACKETS)
            learning_rate = self.hyperparams.get(
                "learning_rate", DEFAULT_LEARNING_RATE
            )
            lr_decay = self.hyperparams.get("lr_decay", DEFAULT_LR_DECAY)
            self.logger.info(
                "Using n_packets=%s, learning_rate=%.3g, and lr_decay=%.3g",
                n_packets, learning_rate, lr_decay
            )

            x_train = first_n_packets(x_train, n_packets=n_packets)
            x_test = first_n_packets(x_test, n_packets=n_packets)

            classifier = varcnn.VarCNNClassifier(
                n_meta_features=N_META_FEATURES, n_packet_features=n_packets,
                callbacks=varcnn.default_callbacks(lr_decay=lr_decay),
                epochs=MAX_EPOCHS, tag=f"varcnn-{self.feature_type}",
                learning_rate=learning_rate, verbose=min(self.verbose, 1),
            )
            classifier.fit(
                x_train, y_train, validation_split=self.validation_split
            )

        # Predict the classes for the test set
        probabilities = classifier.predict_proba(x_test)
        self.logger.info(
            "Experiment complete in %.2fs.", (time.perf_counter() - start))

        return (probabilities, y_test, classifier.classes_)

    def tune_hyperparameters(self, x_train, y_train):
        """Perform hyperparameter tuning."""
        assert self.seeds_ is not None, "seeds must be set"
        pipeline = Pipeline([
            ("first_n_packets", FunctionTransformer(first_n_packets)),
            ("varcnn", varcnn.VarCNNClassifier(
                n_meta_features=N_META_FEATURES,
                epochs=MAX_EPOCHS, tag=f"varcnn-{self.feature_type}",
                validation_split=self.validation_split,
                verbose=min(self.verbose, 1),
            ))
        ])
        param_grid = [
            {
                "first_n_packets__kw_args": [{"n_packets": n_packets}],
                "varcnn__n_packet_features": [n_packets],
                "varcnn__learning_rate": self.learning_rate_parameters,
                "varcnn__callbacks": [
                    varcnn.default_callbacks(lr_decay=lr_decay)
                    for lr_decay in self.lr_decay_parameters
                ]
            }
            for n_packets in self.n_packet_parameters
        ]

        cross_validation = StratifiedKFold(
            self.n_folds, shuffle=True,
            random_state=self.seeds_["kfold_shuffle"]
        )

        grid_search = GridSearchCV(
            pipeline, param_grid, cv=cross_validation, error_score="raise",
            scoring=make_scorer(rf1_score), verbose=self.verbose, refit=True,
        )
        grid_search.fit(x_train, y_train)

        self.logger.info("hyperparameter results = %s", grid_search.cv_results_)
        self.logger.info("hyperparameter best = %s", grid_search.best_params_)
        self.logger.info(
            "hyperparameter best lr_decay %f",
            grid_search.best_params_["varcnn__callbacks"][0].factor
        )

        return grid_search.best_estimator_

    def load_dataset(self):
        """Load the features and classes from the dataset. Slices the
        packet features to the maximum evaluated.
        """
        self.logger.info("Loading dataset ...")
        with h5py.File(self.dataset_path, mode="r") as h5in:
            times = np.asarray(
                [x.astype("float32") for x in h5in["timestamps"]], dtype=object
            )
            sizes = np.asarray(h5in["sizes"])
            classes = np.asarray(h5in["labels"]["class"][:])
        self.logger.info("Extracting features ...")
        return self.extract_features(sizes=sizes, timestamps=times), classes

    def extract_features(
        self, *, sizes: np.ndarray, timestamps: np.ndarray
    ) -> np.ndarray:
        """Extract the features acrroding to the specified feature_type.

        Slices the packet features to the maximum evalauted amount.
        """
        assert self.feature_type in ("sizes", "time")

        meta_features = extract_metadata(
            sizes=sizes,
            timestamps=timestamps,
            metadata=(Metadata.COUNT_METADATA | Metadata.TIME_METADATA
                      | Metadata.SIZE_METADATA),
            batch_size=FEATURE_EXTRACTION_BATCH_SIZE
        )
        assert meta_features.shape[1] == N_META_FEATURES, "wrong # of metadata?"

        max_packets = max(self.n_packet_parameters)
        features = (
            ensure_non_ragged(sizes, dimension=max_packets)
            if self.feature_type == "sizes"
            else extract_interarrival_times(
                timestamps, dimension=max_packets
            ).astype("float32")
        )

        return np.hstack((features, meta_features))


def first_n_packets(features, *, n_packets: int):
    """Return the first n_packets packets along with the meta features."""
    n_features = features.shape[1]
    idx = np.r_[:n_packets, (n_features - N_META_FEATURES):n_features]
    return features[:, idx]


def rf1_score(y_true, y_pred, *, negative_class=-1, ratio=20):
    """Compute the F1-score using the r-precisions with the specified ratio
    and recall.
    """
    precision = rprecision_score(
        y_true, y_pred, negative_class=negative_class, ratio=ratio,
        # If we're dividing by zero it means there were no true positives and
        # thus recall will be zero and the F1 score below will be zero.
        zero_division=1
    )
    recall = recall_score(y_true, y_pred, negative_class=negative_class)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    main(**doceasy.doceasy(__doc__, {
        "OUTFILE": doceasy.CsvFile(mode="w", default="-"),
        "FEATURE_TYPE": Or("time", "sizes"),
        "DATASET_PATH": Use(Path),
        "--verbose": Use(int),
        "--hyperparams": Or("tune", And(doceasy.Mapping(), {
            doceasy.Optional("n_packets"): Use(int),
            doceasy.Optional("learning_rate"): Use(float),
            doceasy.Optional("lr_decay"): Use(float)
        }))
    }))
import logging
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
from lab.classifiers import kfingerprinting

from common import doceasy


def extract_features(infile: Path):
    """Perform the extraction."""
    logging.basicConfig(
        format='[%(asctime)s] [%(levelname)s] %(name)s - %(message)s',
        level=logging.INFO)

    with h5py.File(infile, mode="r") as h5in:
        logging.info("Loading dataset...")
        sizes = np.asarray(h5in["sizes"], dtype=object)
        times = np.asarray(h5in["timestamps"], dtype=object)
        labels = np.asarray(h5in["labels"]["class"])

    for i in range(len(times)):
        idx = np.argsort(times[i])
        sizes[i] = sizes[i][idx]
        times[i] = times[i][idx]
        assert times[i][0] == 0, "first should be zero"

    # Extract time and size related features
    features = kfingerprinting.extract_features_sequence(
        sizes=sizes, timestamps=times, n_jobs=None
    )
    frame = pd.DataFrame(features, columns=kfingerprinting.ALL_DEFAULT_FEATURES)
    frame.insert(0, "y_true", labels)
    print(frame.to_csv(index=False, header=True), end="")


if __name__ == "__main__":
    extract_features(**doceasy.doceasy(__doc__, {
        "INFILE": doceasy.Use(Path),
    }))
import logging
from typing import Sequence, Final, Optional

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from publicsuffixlist import PublicSuffixList

import common
from common import doceasy, alt_svc

_DEFAULT_THRESHOLD: Final = 50
_LOGGER: Final = logging.getLogger("profile_domains.filter")


def _maybe_plot_counts(
    series, filename: Optional[str], top: int = 20, x_scale: str = None
):
    if filename is None:
        return

    fig, axes = plt.subplots()
    order = series.value_counts().index[:top]
    sns.countplot(y=series, order=order, ax=axes)
    if x_scale is not None:
        axes.set_xscale(x_scale)
    fig.savefig(filename, bbox_inches="tight")


def remove_similar_roots(
    domains: Sequence[str], ranks: Sequence[int], plot_filename=None
) -> Sequence[int]:
    """Remove domains which only differ in the public suffix and return
    the remaining indices, preferring domains with lower ranks.

    We note that mail.ch and mail.co.uk may be entirely different
    domains. However, since many domains have localisations that differ
    only in their public suffix, we opt to simply remove all such
    occurrences.
    """
    frame = (pd.Series(domains)
             .reset_index(drop=True).rename("netloc")
             .str.replace(r":\d+", "", regex=True)
             .to_frame()
             .assign(ranks=ranks)
             .sort_values(by="ranks"))

    psl = PublicSuffixList()
    frame.loc[:, "public_sfx"] = frame['netloc'].apply(psl.publicsuffix)
    frame.loc[:, "private_part"] = frame.apply(
        lambda x: x["netloc"][:-(len(x["public_sfx"]) + 1)], axis=1)

    _maybe_plot_counts(frame["private_part"], plot_filename)

    frame = frame.groupby('private_part', group_keys=False).head(n=1)
    return np.sort(frame.index)


def reduce_representation(
    domains: Sequence["str"],
    ranks: Sequence[int],
    sld_domains: Sequence[str],
    whitelist: str = r"(com|co|org|ac)\..*",
    threshold: int = 50,
    plot_filename: Optional[str] = None,
) -> Sequence[int]:
    """Reduce the representation of the specified sld_domains
    second-level domains to within the threshold.

    Any domains which are still not within the threshold, but are not
    excluded by the whitelist will result in an error.

    Return the index of the reduced sample, lower ranks take precedence
    when selecting the sample.
    """
    frame = (pd.Series(domains).reset_index(drop=True).rename("netloc")
             .str.replace(r":\d+", "", regex=True)
             .to_frame()
             .assign(ranks=ranks)
             .sort_values(by="ranks"))

    # Split the domain parts and rejoin everything from the second level domain
    # to the top-level domain
    frame["2LD"] = frame["netloc"].str.split(".").apply(
        lambda x: ".".join(x[-2:]))

    whitelist_mask = frame["2LD"].str.match(whitelist)
    exceptions_mask = whitelist_mask | ~frame["2LD"].isin(sld_domains)

    _maybe_plot_counts(frame["2LD"], plot_filename, x_scale="log")

    # Check that there are no others over the threshold when only considering
    # the whitelist
    non_filtered = (frame[~whitelist_mask & ~frame["2LD"].isin(sld_domains)]
                    .loc[:, "2LD"]
                    .value_counts())
    if (non_filtered > threshold).any():
        unaccounted = non_filtered[non_filtered > threshold].to_dict()
        raise ValueError(f"The provided sld_domains ({sld_domains}) and "
                         f"whitelist ({whitelist}) did not account for all "
                         f"excessive domains: {unaccounted}.")

    # Now perform the downsampling
    samples_idx = (frame[~exceptions_mask]
                   .groupby("2LD", group_keys=False)
                   .head(threshold)
                   .index)

    frame.drop(columns="2LD", inplace=True)
    return np.sort(np.concatenate((frame[exceptions_mask].index, samples_idx)))


# pylint: disable=too-many-arguments
def main(
    infile: str,
    outfile,
    versions: Sequence[str],
    sld_domains: Sequence[str],
    public_root_plot=None,
    sld_plot=None,
):
    """Perform filtering of ranked domains with alt_svc entries.
    """
    common.init_logging()

    data = pd.read_csv(infile, usecols=["rank", "domain", "alt_svc"]).dropna()
    _LOGGER.info("Loaded %d domains with alt-svc records", len(data))

    data["alt_svc"] = data["alt_svc"].apply(alt_svc.parse, errors="log")
    data = data.explode(column="alt_svc", ignore_index=True)
    data[["protocol", "authority"]] = pd.DataFrame(
        data["alt_svc"].tolist(), index=data.index)

    # Drop entries with a non-443 authority
    data = data[data["authority"] == ":443"]

    # Drop entries that do not support a desired version
    data = data[data["protocol"].isin(versions)]
    # Select one entry per domain
    data = data.groupby("rank").head(n=1).set_index("rank")[["domain"]]
    _LOGGER.info("%d domains support the specified versions",
                 data["domain"].nunique())

    filtered_idx = remove_similar_roots(
        data["domain"], data.index, plot_filename=public_root_plot)
    data = data.iloc[filtered_idx]
    _LOGGER.info("Filtered similar private domains to %d domains", len(data))

    data = data.iloc[reduce_representation(
        data["domain"], data.index, sld_domains=sld_domains,
        threshold=_DEFAULT_THRESHOLD, plot_filename=sld_plot
    )]
    _LOGGER.info("Reduced representation of %s in the dataset. Now %d domains",
                 sld_domains, len(data))

    data.sort_index().to_csv(outfile, header=False, index=True)


if __name__ == "__main__":
    try:
        KW_ARGS = {
            "infile": snakemake.input[0],                    # type: ignore
            "outfile": open(snakemake.output[0], mode="w"),  # type: ignore
            "versions": snakemake.params["versions"],        # type: ignore
            "sld_domains": snakemake.params["sld_domains"],  # type: ignore
            "public_root_plot": snakemake.output[1],         # type: ignore
            "sld_plot": snakemake.output[2],                 # type: ignore
        }
    except NameError:
        KW_ARGS = doceasy.doceasy(__doc__, {
            "<infile>": str,
            "<outfile>": doceasy.File(mode="w", default="-"),
            "--versions": [str],
            "--sld-domains": [str],
        })
    main(**KW_ARGS)
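remove_similar_roots collapses domains that differ only in their public suffix and keeps the best-ranked one. The sketch below reproduces the core grouping on a toy frame; it assumes only pandas and publicsuffixlist, which the script already imports.

import pandas as pd
from publicsuffixlist import PublicSuffixList

psl = PublicSuffixList()
frame = pd.DataFrame({
    "netloc": ["example.com", "example.co.uk", "other.org"],
    "ranks": [10, 50, 20],
}).sort_values(by="ranks")

frame["public_sfx"] = frame["netloc"].apply(psl.publicsuffix)
frame["private_part"] = frame.apply(
    lambda x: x["netloc"][:-(len(x["public_sfx"]) + 1)], axis=1)

# Keep only the lowest-ranked entry per private part
print(frame.groupby("private_part", group_keys=False).head(n=1))
# example.co.uk is dropped because example.com has a lower rank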
import textwrap
import pandas as pd


def main(front: str, tamaraw: str, output: str):
    """Create a latex table of the overhead plots."""
    with open(output, mode="w") as table:
        table.write(textwrap.dedent(r"""
            \begin{tabular}{@{}lrr@{}}
            \toprule
            & FRONT & Tamaraw \\
            \midrule
            """))

        data = pd.concat([pd.read_csv(front), pd.read_csv(tamaraw)],
                         keys=["front", "tamaraw"], names=["defence"])
        data = (data.groupby(["defence", "overhead", "setting"])["value"]
                    .describe())
        data = data.rename(index={
            'simulated': 'Simulated', "collected": "Defended",
            "bandwidth": "Bandwidth", "latency": "Latency",
        })

        for overhead in ["Bandwidth", "Latency"]:
            table.write(r"\textbf{%s} \\" % overhead)
            table.write("\n")

            settings = ["Defended", "Simulated"]
            if overhead == "Bandwidth":
                settings.append("simulated-alt")

            for setting in settings:
                table.write(r"\quad %s" % setting)
                for defence in ["front", "tamaraw"]:
                    median = data["50%"][defence][overhead][setting]
                    iqr1 = data["25%"][defence][overhead][setting]
                    iqr3 = data["75%"][defence][overhead][setting]
                    table.write(
                        f" &${median:.2f}$ (${iqr1:.2f}\\text{{--}}{iqr3:.2f}$)"
                    )
                table.write(r" \\" + "\n")
        table.write("\\bottomrule\n\\end{tabular}\n")
        with pd.option_context('display.max_colwidth', None):
            table.write(textwrap.indent(
                data.drop(
                    columns=["count", "mean", "std", "min", "max"]).to_csv(),
                "% "))


if __name__ == "__main__":
    snakemake = globals().get("snakemake", None)
    main(front=snakemake.input["front"], tamaraw=snakemake.input["tamaraw"],
         output=snakemake.output[0])
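The table cells above are read out of the DataFrame returned by describe(), whose percentile columns ("25%", "50%", "75%") supply the median and interquartile range. A small sketch of that lookup with made-up overhead values:

import pandas as pd

data = pd.DataFrame({
    "defence": ["front"] * 4,
    "overhead": ["Bandwidth"] * 4,
    "setting": ["Defended", "Defended", "Simulated", "Simulated"],
    "value": [1.2, 1.4, 1.1, 1.3],
})
stats = data.groupby(["defence", "overhead", "setting"])["value"].describe()
median = stats["50%"]["front"]["Bandwidth"]["Defended"]
iqr1 = stats["25%"]["front"]["Bandwidth"]["Defended"]
iqr3 = stats["75%"]["front"]["Bandwidth"]["Defended"]
print(f"${median:.2f}$ (${iqr1:.2f}\\text{{--}}{iqr3:.2f}$)")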
import json
import hashlib
import logging
import functools
import subprocess
from subprocess import DEVNULL
from pathlib import Path
from typing import Dict, List

import lab.tracev2 as trace

import common
from common import neqo
from common.collect import Collector

_LOGGER = logging.getLogger(__name__)


def collect_with_args(
    input_file: Path,
    output_dir: Path,
    region_id: int,
    client_id: int,
    neqo_args: List[str],
    config: Dict,
    timeout: float,
    skip_neqo: bool,
) -> bool:
    """Collect a trace using NEQO and return True iff it was successful."""
    #: Copy the args, since it is shared among all method instances
    neqo_args = neqo_args + [
        "--header", "user-agent", config["user_agent"],
        "--url-dependencies-from", str(input_file),
    ]
    if (
        "tamaraw" in neqo_args
        or "front" in neqo_args
        or "schedule" in neqo_args
    ):
        neqo_args += ["--defence-event-log", str(output_dir / "schedule.csv")]

    if "front" in neqo_args:
        # Take a 4 byte integer from the output directory
        dir_bytes = str(output_dir).encode('utf-8')
        seed = int(hashlib.sha256(dir_bytes).hexdigest(), 16) & 0xffffffff
        neqo_args += ["--defence-seed", str(seed)]

    client_port = config["wireguard"]["client_ports"][region_id][client_id]
    interface = config["wireguard"]["interface"]

    url = _get_main_url(input_file)

    try:
        pcap = neqo.run_alongside(
            neqo_args,
            lambda: subprocess.Popen(
                "sleep .5 && workflow/scripts/docker-dep-fetch-vpn"
                f" {region_id} {client_id} --max-attempts 1 --single-url {url}",
                shell=True, stderr=DEVNULL, stdout=DEVNULL),
            neqo_exe=[
                "workflow/scripts/neqo-client-vpn", str(region_id),
                str(client_id)
            ],
            stdout=str(output_dir / "stdout.txt"),
            stderr=str(output_dir / "stderr.txt"),
            env={"RUST_LOG": config["neqo_log_level"]},
            tcpdump_kw={
                "capture_filter": f"udp port {client_port}", "iface": interface,
            },
            timeout=timeout,
            skip_neqo=skip_neqo,
        )
    except subprocess.TimeoutExpired as err:
        _LOGGER.debug("Neqo timed out: %s", err)
    except subprocess.CalledProcessError as err:
        _LOGGER.debug("Neqo/browser failed with error: %s", err)
    else:
        assert pcap is not None
        # (output_dir / "trace.pcapng").write_bytes(result.pcap)
        traffic = trace.from_pcap(pcap, client_port=client_port)
        if len(traffic) == 0:
            _LOGGER.debug("Failed with empty trace.")
            return False
        trace.to_csv((output_dir / "trace.csv"), traffic)
        return True
    return False


def _get_main_url(input_file):
    graph = json.loads(Path(input_file).read_text())
    return next(n for n in graph["nodes"] if n["id"] == 0)["url"]


def main(
    input_,
    output,
    config: Dict,
    *,
    neqo_args: List[str],
    n_instances: int,
    n_monitored: int,
    n_unmonitored: int = 0,
    max_failures: int = 3,
    timeout: float = 120,
    skip_neqo: bool = False,
):
    """Collect all the samples for the speicified arguments."""
    common.init_logging(name_thread=True, verbose=True)

    neqo_args = [str(x) for x in neqo_args]
    n_regions = config["wireguard"]["n_regions"]
    n_clients_per_region = min(config["wireguard"]["n_clients_per_region"], 4)

    Collector(
        functools.partial(
            collect_with_args, neqo_args=neqo_args, config=config,
            timeout=timeout, skip_neqo=skip_neqo),
        n_regions=n_regions,
        n_clients_per_region=n_clients_per_region,
        n_instances=n_instances,
        n_monitored=n_monitored,
        n_unmonitored=n_unmonitored,
        max_failures=max_failures,
        input_dir=input_,
        output_dir=output,
    ).run()


if __name__ == "__main__":
    snakemake = globals().get("snakemake", None)
    main(input_=str(snakemake.input[0]), output=str(snakemake.output[0]),
         config=snakemake.config, **snakemake.params)
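When FRONT is among the arguments, each sample receives a deterministic 32-bit defence seed derived from its output directory, so repeated runs of the same sample reuse the same padding schedule. A minimal sketch of that derivation (the path below is hypothetical):

import hashlib

output_dir = "results/front/region-0/sample-0001"  # hypothetical path
dir_bytes = str(output_dir).encode("utf-8")
seed = int(hashlib.sha256(dir_bytes).hexdigest(), 16) & 0xffffffff
print(seed)  # the same 32-bit value every time this path is used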
import os
import hashlib
import logging
import functools
import subprocess
from pathlib import Path
from typing import Dict, List

import lab.tracev2 as trace

import common
from common import neqo
from common.collect import Collector

_LOGGER = logging.getLogger(__name__)


def collect_with_args(
    input_file: Path,
    output_dir: Path,
    region_id: int,
    client_id: int,
    neqo_args: List[str],
    config: Dict,
    timeout: float,
) -> bool:
    """Collect a trace using NEQO and return True iff it was successful."""
    #: Copy the args, since it is shared among all method instances
    neqo_args = neqo_args + [
        "--header", "user-agent", config["user_agent"],
        "--url-dependencies-from", str(input_file),
    ]
    if (
        "tamaraw" in neqo_args
        or "front" in neqo_args
        or "schedule" in neqo_args
    ):
        neqo_args += ["--defence-event-log", str(output_dir / "schedule.csv")]

    if "front" in neqo_args:
        # Take a 4 byte integer from the output directory
        dir_bytes = str(output_dir).encode('utf-8')
        seed = int(hashlib.sha256(dir_bytes).hexdigest(), 16) & 0xffffffff
        neqo_args += ["--defence-seed", str(seed)]

    client_port = config["wireguard"]["client_ports"][region_id][client_id]
    interface = config["wireguard"]["interface"]

    try:
        (result, pcap) = neqo.run(
            neqo_args,
            neqo_exe=[
                "workflow/scripts/neqo-client-vpn", str(region_id),
                str(client_id)
            ],
            check=True,
            stdout=str(output_dir / "stdout.txt"),
            stderr=str(output_dir / "stderr.txt"),
            pcap=neqo.PIPE,
            env={"RUST_LOG": config["neqo_log_level"]},
            tcpdump_kw={
                "capture_filter": f"udp port {client_port}", "iface": interface,
            },
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as err:
        _LOGGER.debug("Neqo timed out: %s", err)
    except subprocess.CalledProcessError as err:
        _LOGGER.debug("Neqo failed with error: %s", err)
    else:
        assert result.returncode == 0
        assert pcap is not None
        # (output_dir / "trace.pcapng").write_bytes(result.pcap)
        trace.to_csv((output_dir / "trace.csv"),
                     trace.from_pcap(pcap, client_port=client_port))
        _LOGGER.debug("Neqo succeeded.")
        return True
    return False


def main(
    input_,
    output,
    config: Dict,
    *,
    neqo_args: List[str],
    n_instances: int,
    n_monitored: int,
    n_unmonitored: int = 0,
    max_failures: int = 3,
    timeout: float = 120,
    use_multiple_connections: bool = False,
):
    """Collect all the samples for the speicified arguments."""
    common.init_logging(name_thread=True, verbose=True)

    neqo_args = [str(x) for x in neqo_args]
    n_regions = config["wireguard"]["n_regions"]
    n_clients_per_region = config["wireguard"]["n_clients_per_region"]

    _LOGGER.info("Env variable NEQO_BIN=%s", os.environ["NEQO_BIN"])
    _LOGGER.info("Env variable NEQO_BIN_MP=%s", os.environ["NEQO_BIN_MP"])
    if use_multiple_connections:
        os.environ["NEQO_BIN"] = os.environ["NEQO_BIN_MP"]
        _LOGGER.info("Env variable updated NEQO_BIN=%s", os.environ["NEQO_BIN"])

    Collector(
        functools.partial(
            collect_with_args, neqo_args=neqo_args, config=config,
            timeout=timeout),
        n_regions=n_regions,
        n_clients_per_region=n_clients_per_region,
        n_instances=n_instances,
        n_monitored=n_monitored,
        n_unmonitored=n_unmonitored,
        max_failures=max_failures,
        input_dir=input_,
        output_dir=output,
    ).run()


if __name__ == "__main__":
    snakemake = globals().get("snakemake", None)
    main(input_=str(snakemake.input[0]), output=str(snakemake.output[0]),
         config=snakemake.config, **snakemake.params)
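The main function above switches between the single- and multi-connection clients by overwriting NEQO_BIN with NEQO_BIN_MP, since downstream helpers read the binary path from the environment. A small sketch of that switch; the paths are placeholders, the real values come from the env_vars file sourced during setup.

import os

os.environ["NEQO_BIN"] = "/path/to/neqo-client"        # placeholder
os.environ["NEQO_BIN_MP"] = "/path/to/neqo-client-mp"  # placeholder

use_multiple_connections = True
if use_multiple_connections:
    os.environ["NEQO_BIN"] = os.environ["NEQO_BIN_MP"]
print(os.environ["NEQO_BIN"])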
import os
import asyncio
import hashlib
import logging
import functools
import subprocess
from pathlib import Path
from typing import Dict, List

import yaml
import lab.tracev2 as trace

import common
from common import neqo
from common.collectv2 import Collector
from common.doceasy import doceasy, Use, Schema

_LOGGER = logging.getLogger(__name__)


def collect_with_args(
    input_file: Path,
    output_dir: Path,
    region_id: int,
    client_id: int,
    neqo_args: List[str],
    config: Dict,
    timeout: float,
) -> bool:
    """Collect a trace using NEQO and return True iff it was successful."""
    #: Copy the args, since it is shared among all method instances
    neqo_args = neqo_args + [
        "--header", "user-agent", config["user_agent"],
        "--url-dependencies-from", str(input_file),
    ]
    if (
        "tamaraw" in neqo_args
        or "front" in neqo_args
        or "schedule" in neqo_args
    ):
        neqo_args += ["--defence-event-log", str(output_dir / "schedule.csv")]

    if "front" in neqo_args:
        # Take a 4 byte integer from the output directory
        dir_bytes = str(output_dir).encode('utf-8')
        seed = int(hashlib.sha256(dir_bytes).hexdigest(), 16) & 0xffffffff
        neqo_args += ["--defence-seed", str(seed)]

    client_port = config["wireguard"]["client_ports"][region_id][client_id]
    interface = config["wireguard"]["interface"]

    try:
        (result, pcap) = neqo.run(
            neqo_args,
            neqo_exe=[
                "workflow/scripts/neqo-client-vpn", str(region_id),
                str(client_id)
            ],
            check=True,
            stdout=str(output_dir / "stdout.txt"),
            stderr=str(output_dir / "stderr.txt"),
            pcap=neqo.PIPE,
            env={"RUST_LOG": config["neqo_log_level"]},
            tcpdump_kw={
                "capture_filter": f"udp port {client_port}", "iface": interface,
            },
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as err:
        _LOGGER.debug("Neqo timed out: %s", err)
    except subprocess.CalledProcessError as err:
        _LOGGER.debug("Neqo failed with error: %s", err)
    else:
        assert result.returncode == 0
        assert pcap is not None
        # (output_dir / "trace.pcapng").write_bytes(result.pcap)
        trace.to_csv((output_dir / "trace.csv"),
                     trace.from_pcap(pcap, client_port=client_port))
        _LOGGER.debug("Neqo succeeded.")
        return True

    _check_for_user_error(output_dir / "stderr.txt")
    return False


class MisconfigurationError(RuntimeError):
    """Raised when the experiment is misconfigured."""


def _check_for_user_error(stderr: Path):
    """Check for errors that need to be fixed before the collection
    can continue.
    """
    err_txt = stderr.read_text()

    if "No such file or directory" in err_txt:
        raise MisconfigurationError(
            f"Collection run failed due to files being missing. See {stderr!s}"
            "for more details."
        )

    if (
        "error: Found argument" in err_txt
        or "error: Invalid value for" in err_txt
        or "USAGE:" in err_txt
    ):
        raise MisconfigurationError(
            "Collection run failed due to invalid arguments or parameters."
            f"See {stderr!s} for more details."
        )


async def main(
    input_dir: Path,
    output_dir: Path,
    configfile: Path,
    *,
    neqo_args: List[str],
    n_instances: int,
    n_monitored: int,
    n_unmonitored: int,
    max_failures: int,
    timeout: float,
    use_multiple_connections: bool = False,
):
    """Collect all the samples for the speicified arguments."""
    common.init_logging(name_thread=True, verbose=True)

    config = yaml.safe_load(configfile.read_text())
    n_regions = config["wireguard"]["n_regions"]
    n_clients_per_region = config["wireguard"]["n_clients_per_region"]

    _LOGGER.info("Env variable NEQO_BIN=%s", os.environ["NEQO_BIN"])
    _LOGGER.info("Env variable NEQO_BIN_MP=%s", os.environ["NEQO_BIN_MP"])
    if use_multiple_connections:
        os.environ["NEQO_BIN"] = os.environ["NEQO_BIN_MP"]
        _LOGGER.info("Env variable updated NEQO_BIN=%s", os.environ["NEQO_BIN"])

    await Collector(
        functools.partial(
            collect_with_args, neqo_args=neqo_args, config=config,
            timeout=timeout),
        n_regions=n_regions,
        n_clients_per_region=n_clients_per_region,
        n_instances=n_instances,
        n_monitored=n_monitored,
        n_unmonitored=n_unmonitored,
        max_failures=max_failures,
        input_dir=input_dir,
        output_dir=output_dir,
    ).run()


if __name__ == "__main__":
    asyncio.run(main(**doceasy(__doc__, Schema({
        "INPUT_DIR": Use(Path),
        "OUTPUT_DIR": Use(Path),
        "--configfile": Use(Path),
        "--use-multiple-connections": bool,
        "--n-monitored": Use(int),
        "--n-instances": Use(int),
        "--n-unmonitored": Use(int),
        "--max-failures": Use(int),
        "--timeout": Use(float),
        "NEQO_ARGS": [str],
    }, ignore_extra_keys=True))))
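_check_for_user_error inspects the client's stderr for misconfiguration symptoms (missing files, invalid arguments) and raises instead of letting the collector retry a doomed run. A self-contained sketch of that check against a temporary stderr file:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    stderr = Path(tmp) / "stderr.txt"
    stderr.write_text("error: Found argument '--bad-flag' which wasn't expected")

    err_txt = stderr.read_text()
    if (
        "error: Found argument" in err_txt
        or "error: Invalid value for" in err_txt
        or "USAGE:" in err_txt
    ):
        print(f"would raise MisconfigurationError, see {stderr!s} for details")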
import os
import hashlib
import logging
import functools
import subprocess
from pathlib import Path
from typing import Dict, List

import lab.tracev2 as trace

import common
from common import neqo
from common.collect import Collector

_LOGGER = logging.getLogger(__name__)


def collect_with_args(
    input_file: Path,
    output_dir: Path,
    region_id: int,
    client_id: int,
    neqo_args: List[str],
    config: Dict,
    timeout: float,
) -> bool:
    """Collect a trace using NEQO and return True iff it was successful."""
    #: Copy the args, since it is shared among all method instances
    common_args = [
        "--url-dependencies-from", str(input_file),
        "--header", "user-agent", config["user_agent"],
    ]
    defended_args = neqo_args + common_args + [
        "--defence-event-log", str(output_dir / "defended" / "schedule.csv"),
    ]

    client_port = config["wireguard"]["client_ports"][region_id][client_id]
    interface = config["wireguard"]["interface"]

    try:
        for setting, args, directory in [
            ("defended", defended_args, (output_dir / "defended")),
            ("undefended", common_args, (output_dir / "undefended")),
        ]:
            directory.mkdir(exist_ok=True)

            # front needs a seed, make one based on the directory and save it
            if setting == "defended" and "front" in args:
                # Take a 4 byte integer from the output directory
                seed = int(
                    hashlib.sha256(str(directory).encode('utf-8')).hexdigest(),
                    16
                ) & 0xffffffff
                args += ["--defence-seed", seed]

            args = [str(arg) for arg in args]

            _LOGGER.debug("Collecting setting %r", setting)
            (result, pcap) = neqo.run(
                args,
                neqo_exe=[
                    "workflow/scripts/neqo-client-vpn", str(region_id),
                    str(client_id)
                ],
                check=True,
                stdout=str(directory / "stdout.txt"),
                stderr=str(directory / "stderr.txt"),
                pcap=neqo.PIPE,
                env={"RUST_LOG": config["neqo_log_level"]},
                tcpdump_kw={
                    "capture_filter": f"udp port {client_port}",
                    "iface": interface,
                },
                timeout=timeout,
            )
            assert result.returncode == 0
            assert pcap is not None
            trace.to_csv((directory / "trace.csv"),
                         trace.from_pcap(pcap, client_port=client_port))
    except subprocess.TimeoutExpired as err:
        _LOGGER.debug("Neqo timed out on setting %r: %s", setting, err)
        return False
    except subprocess.CalledProcessError as err:
        _LOGGER.debug("Neqo failed in setting %r with error: %s", setting, err)
        return False
    return True


def main(
    input_,
    output,
    config: Dict,
    *,
    neqo_args: List[str],
    n_instances: int = 0,
    n_monitored: int = 0,
    n_unmonitored: int = 0,
    max_failures: int = 3,
    timeout: float = 120,
    use_multiple_connections: bool = False,
):
    """Collect all the samples for the speicified arguments."""
    common.init_logging(name_thread=True, verbose=True)

    if (
        "tamaraw" not in neqo_args
        and "front" not in neqo_args
        and "schedule" not in neqo_args
    ):
        raise ValueError("The arguments must correspond to a defence.")

    n_regions = config["wireguard"]["n_regions"]
    n_clients_per_region = config["wireguard"]["n_clients_per_region"]

    _LOGGER.info("Env variable NEQO_BIN=%s", os.environ["NEQO_BIN"])
    _LOGGER.info("Env variable NEQO_BIN_MP=%s", os.environ["NEQO_BIN_MP"])
    if use_multiple_connections:
        os.environ["NEQO_BIN"] = os.environ["NEQO_BIN_MP"]
        _LOGGER.info("Env variable updated NEQO_BIN=%s", os.environ["NEQO_BIN"])

    Collector(
        functools.partial(
            collect_with_args, neqo_args=neqo_args, config=config,
            timeout=timeout),
        n_regions=n_regions,
        n_clients_per_region=n_clients_per_region,
        n_instances=n_instances,
        n_monitored=n_monitored,
        n_unmonitored=n_unmonitored,
        max_failures=max_failures,
        input_dir=input_,
        output_dir=output,
    ).run()


if __name__ == "__main__":
    snakemake = globals().get("snakemake", None)
    main(input_=str(snakemake.input[0]), output=str(snakemake.output[0]),
         config=snakemake.config, **snakemake.params)
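This variant fetches each page twice, once with the defence arguments and once undefended, writing each run to its own subdirectory so overheads can be computed on matched traces. A sketch of how the two settings and their output directories are laid out; the argument values are placeholders.

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    output_dir = Path(tmp)
    common_args = ["--url-dependencies-from", "graph.json"]  # placeholder input
    defended_args = ["front"] + common_args + [
        "--defence-event-log", str(output_dir / "defended" / "schedule.csv"),
    ]

    for setting, args, directory in [
        ("defended", defended_args, output_dir / "defended"),
        ("undefended", common_args, output_dir / "undefended"),
    ]:
        directory.mkdir(exist_ok=True)
        print(setting, directory, len(args))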
import json
import logging
import fileinput
import urllib.parse
from collections import Counter
from typing import Set, Optional, Iterator, Dict, List, Any
from pathlib import Path
import networkx as nx

import common
from common import doceasy

#: The minimum number of URLs in each dependency graph
N_MININMUM_URLS = 1
#: Allow error codes for too many requests (429) and server timeout (522) since
#: these are transient.
ALLOWED_HTTP_ERRORS = [429, 522, ]

_LOGGER = logging.getLogger("url-dep-graph")


def origin(url: str) -> str:
    """Return the origin of the URL."""
    parts = urllib.parse.urlsplit(url)
    return f"{parts[0]}://{parts[1]}"


class InvalidGraphError(RuntimeError):
    """Raised when construction would result in an empty graph."""


def _rget(mapping, keys: List[Any], default=None):
    """Recrusvie get."""
    assert isinstance(keys, list)
    for key in keys[:-1]:
        mapping = mapping.get(key, {})
    return mapping.get(keys[-1], default)


class _DependencyGraph:
    def __init__(self, browser_log, origin_: Optional[str] = None):
        self.logs = browser_log
        self.origin = origin_
        self.graph = nx.DiGraph()
        #: A mapping of URLs to node_ids to account for redirections
        self._url_node_ids: Dict[str, List[str]] = dict()
        self._ignored_requests: Set[str] = set()

        self._construct()

    def _construct(self):
        msgs = self.logs["http_trace"]
        msgs.sort(key=lambda msg: msg["timestamp"])

        for msg in msgs:
            msg = msg["message"]["message"]

            if msg["method"] == "Network.requestWillBeSent":
                self._handle_request(msg)
            elif msg["method"] == "Network.responseReceived":
                self._handle_response(msg)
            elif msg["method"] == "Network.dataReceived":
                self._handle_data(msg)
            else:
                continue

        if self_loops := list(nx.nodes_with_selfloops(self.graph)):
            raise InvalidGraphError(f"Graph contains self loops: {self_loops}")
        if loops := list(nx.simple_cycles(self.graph)):
            raise InvalidGraphError(f"Graph contains loops: {loops}")

        if self.origin is not None:
            to_drop = [node for node, node_origin in self.graph.nodes("origin")
                       if node_origin != self.origin]
            self.graph.remove_nodes_from(to_drop)

            if len(self.graph) == 0:
                raise InvalidGraphError(
                    f"Origin filtering would result in an empty graph:"
                    f" {self.origin}")

        self.graph = nx.relabel_nodes(self.graph, mapping={
            node: i for (i, node) in enumerate(self.graph)
        })

    def to_json(self) -> str:
        """Convert the graph to json."""
        return json.dumps(nx.node_link_data(self.graph), indent=2)

    def roots(self) -> List[str]:
        """Return the roots of the graph."""
        return [node for node, degree in self.graph.in_degree() if degree == 0]

    def _add_node(self, node_id, url, type_):
        if url not in self._url_node_ids:
            self._url_node_ids[url] = [node_id, ]
        elif node_id not in self._url_node_ids[url]:
            self._url_node_ids[url].append(node_id)
        # If this is a redirection it will change the details of the node
        self.graph.add_node(
            node_id, url=url, done=False, type=type_, origin=origin(url),
            content_length=None, data_length=0,
        )

    def _handle_data(self, msg):
        assert msg["method"] == "Network.dataReceived"
        node_id = msg["params"]["requestId"]
        if (
            node_id in self._ignored_requests
            or node_id not in self.graph.nodes
        ):
            return

        self.graph.nodes[node_id]["data_length"] += _rget(
            msg, ["params", "dataLength"], 0)

    def _handle_response(self, msg):
        assert msg["method"] == "Network.responseReceived"
        node_id = msg["params"]["requestId"]
        if (
            node_id in self._ignored_requests
            or node_id not in self.graph.nodes
        ):
            return
        self.graph.nodes[node_id]["done"] = True

        size = _rget(
            msg, ["params", "response", "headers", "content-length"], None
        )
        if size is not None:
            self.graph.nodes[node_id]["content_length"] = int(size)

    def _find_node_by_url(self, url: str) -> Optional[str]:
        """Find the most recent node associated with a get request."""
        # TODO: Check these nodes for which is completed?
        if not (node_ids := self._url_node_ids.get(url, [])):
            return None
        if nid := next(
            (nid for nid in node_ids if self.graph.nodes[nid]["done"]), None
        ):
            return nid
        return node_ids[-1]

    def _add_origin_dependency(self, dep: str, node_id):
        """Add a dependency for node_id to the root with the same
        origin as dep.
        """
        root_node = next((root for root in self.roots()
                          if self.graph.nodes[root]["origin"] == origin(dep)),
                         None)
        assert root_node is not None
        self.graph.add_edge(root_node, node_id)

    def _add_dependency(self, dep, node_id) -> bool:
        if not dep.startswith("http"):
            return True
        if dep_node := self._find_node_by_url(dep):
            self.graph.add_edge(dep_node, node_id)
            return True
        return False

    def _handle_request(self, msg):
        assert msg["method"] == "Network.requestWillBeSent"

        request = msg["params"]["request"]
        node_id = msg["params"]["requestId"]

        if (request["method"] != "GET"
                or not request["url"].startswith("https://")):
            self._ignored_requests.add(node_id)
            return

        self._add_node(node_id, request["url"], msg["params"]["type"])

        if msg["params"]["documentURL"] != request["url"]:
            if not self._add_dependency(msg["params"]["documentURL"], node_id):
                _LOGGER.debug(
                    "Unable to find documentURL dependency of %r: %r",
                    request["url"], msg["params"]["documentURL"])

        if initiator_url := msg["params"]["initiator"].get("url", None):
            if not self._add_dependency(initiator_url, node_id):
                _LOGGER.debug(
                    "Unable to find initiator dependency of %r: %r",
                    request["url"], initiator_url)

        if stack := msg["params"]["initiator"].get("stack", None):
            for stack_frame in stack["callFrames"]:
                if not self._add_dependency(stack_frame["url"], node_id):
                    _LOGGER.debug(
                        "Unable to find documentURL dependency of %r: %r",
                        request["url"], stack_frame["url"])


def extract_graphs(
    fetch_output_generator, use_origin: bool = True
) -> Iterator[nx.DiGraph]:
    """Filter and generate non-empty graphs from the input
    generated of fetch results.
    """
    seen_urls = set()
    dropped_urls: Dict[str, Counter] = {
        "duplicate": Counter(),
        "disconnected": Counter(),
        "insufficient": Counter(),
        "empty": Counter()
    }

    for result in fetch_output_generator:
        url = result["final_url"] or result["url"]

        if result["status"] != "success":
            _LOGGER.debug(
                "Dropping %r with a status of %r.", url, result["status"])
            continue
        if not url.startswith("https"):
            _LOGGER.debug("Dropping %r as it is not HTTPS.", url)
            continue
        if url in seen_urls:
            _LOGGER.debug("Dropping %r as it was already encountered.", url)
            dropped_urls["duplicate"][url] += 1
            continue

        try:
            graph = _DependencyGraph(
                result, origin(url) if use_origin else None)
        except InvalidGraphError as err:
            _LOGGER.debug("Dropping %r: %s", url, err)
            dropped_urls["empty"][url] += 1
            continue

        if len(graph.roots()) > 1:
            _LOGGER.debug("Dropping %r as it is disconnected.", url)
            dropped_urls["disconnected"][url] += 1
        elif len(graph.graph.nodes) < N_MININMUM_URLS:
            _LOGGER.debug("Dropping %r as it has only %d/%d required URLs.",
                          url, len(graph.graph.nodes), N_MININMUM_URLS)
            dropped_urls["insufficient"][url] += 1
        else:
            yield graph
            seen_urls.add(url)

    for type_, counters in dropped_urls.items():
        _LOGGER.debug("Dropped %s urls: %s", type_, dict(counters))
        _LOGGER.info("Dropped %s urls: %d", type_, sum(counters.values()))


def main(infile: List[str], prefix: str, verbose: bool, no_origin_filter: bool):
    """Filter browser URL request logs and extract dependency graphs."""
    common.init_logging(verbosity=int(verbose) + 1)
    _LOGGER.info("Running with arguments: %s.", locals())

    file_id = -1
    with fileinput.input(files=infile, openhook=fileinput.hook_compressed) \
            as json_lines:
        results = (json.loads(line) for line in json_lines)

        for file_id, graph in enumerate(
            extract_graphs(results, use_origin=not no_origin_filter)
        ):
            path = Path(f"{prefix}{file_id:04d}.json")
            if not path.is_file():
                path.write_text(graph.to_json())
            else:
                _LOGGER.info("Refusing to overwrite: %s", path)
    _LOGGER.info("Script complete. Extracted %d dependency graphs.", file_id+1)


if __name__ == "__main__":
    main(**doceasy.doceasy(__doc__, {
        "PREFIX": doceasy.Or(str, doceasy.Use(lambda _: "x")),
        "INFILE": [str],
        "--no-origin-filter": bool,
        "--verbose": bool,
    }))
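The dependency-graph builder reads optional fields out of nested DevTools messages with the _rget helper, which walks a list of keys and falls back to a default when any level is missing. A toy sketch of that lookup:

def _rget(mapping, keys, default=None):
    """Follow nested dictionary keys, returning default if any are missing."""
    for key in keys[:-1]:
        mapping = mapping.get(key, {})
    return mapping.get(keys[-1], default)

msg = {"params": {"response": {"headers": {"content-length": "512"}}}}
print(_rget(msg, ["params", "response", "headers", "content-length"]))   # 512
print(_rget(msg, ["params", "response", "headers", "etag"], "missing"))  # missing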
import time
import asyncio
import socket
import logging
import pathlib
from typing import Sequence, NamedTuple, Optional
from datetime import timedelta

import aiohttp
from aiohttp.resolver import AsyncResolver
import numpy as np
import pandas as pd

import common
from common import doceasy
from common.doceasy import And, Use, AtLeast

_LOGGER = logging.getLogger(pathlib.Path(__file__).name)

NAMESERVERS = [
    # Google
    '8.8.8.8', '8.8.4.4',
    # Cloudflare
    '1.1.1.1', '1.0.0.1',
    # Quad9
    '9.9.9.9', '149.112.112.112',
    # Verisign
    '64.6.64.6', '64.6.65.6',
    # Level-3
    '209.244.0.3', '209.244.0.4',
    # Freenom
    '80.80.80.80', '80.80.81.81',
    # Open DNS
    '208.67.222.222', '208.67.220.220',
    # Yandex
    '77.88.8.8', '77.88.8.7',
    # Comodo
    '8.26.56.26', '8.20.247.20',
]
USER_AGENT = ("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
              "(KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36")


class DomainProfile(NamedTuple):
    """Profiling result of a domain.

    status:
        A positive status indicates an HTTP status code.  A negative
        status indicates one of the custom defined error codes.
    """
    domain: str
    fetch_duration: float
    status: Optional[int] = None
    url: Optional[str] = None
    real_url: Optional[str] = None
    alt_svc: Optional[str] = None
    server: Optional[str] = None
    error: Optional[str] = None
    error_str: Optional[str] = None


async def _profile_domain(
    domain: str, connector, timeout: aiohttp.ClientTimeout
) -> DomainProfile:
    headers = {'user-agent': USER_AGENT}
    async with aiohttp.ClientSession(
        timeout=timeout, connector=connector, connector_owner=False,
        headers=headers,
    ) as client:
        try:
            async with client.get(f'https://{domain}', read_until_eof=False,
                                  allow_redirects=True) as resp:
                return DomainProfile(
                    domain, fetch_duration=np.nan, status=resp.status,
                    url=str(resp.url), real_url=str(resp.real_url),
                    alt_svc=resp.headers.get('alt-svc', None),
                    server=resp.headers.get('server', None))
        except asyncio.TimeoutError:
            return DomainProfile(domain, np.nan, error='timeout')
        except aiohttp.ClientSSLError as err:
            return DomainProfile(domain, np.nan, error='ssl-error',
                                 error_str=str(err))
        except (aiohttp.ClientError, ValueError) as err:
            return DomainProfile(domain, np.nan, error='other-error',
                                 error_str=repr(err))
        except OSError as err:
            return DomainProfile(domain, np.nan, error=f'oserror({err.errno})',
                                 error_str=err.strerror)


async def profile_domain(domain: str, connector, sem: asyncio.Semaphore,
                         timeout: aiohttp.ClientTimeout) -> DomainProfile:
    """Request an https version of the domain and record the resulting
    status code, url, and alternative services.
    """
    async with sem:
        start_time = time.perf_counter()
        result = await _profile_domain(domain, connector, timeout)
        return result._replace(
            fetch_duration=(time.perf_counter() - start_time))


async def run_profiling(domains: Sequence[str], n_outstanding: int,
                        total_timeout: float):
    """Profile the specified domains, with at most n_outstanding
    requests at a time.
    """
    semaphore = asyncio.Semaphore(n_outstanding)
    timeout = aiohttp.ClientTimeout(total=total_timeout)

    # Create a resolver with custom nameservers, and close it manually as it
    # does not support async with
    resolver = AsyncResolver(nameservers=NAMESERVERS, rotate=True, tries=2)
    try:
        # Specifying the family is necessary to avoid network unreachable errs
        async with aiohttp.TCPConnector(
            resolver=resolver, limit=n_outstanding, use_dns_cache=True,
            ttl_dns_cache=(60 * 5), family=socket.AF_INET, force_close=True,
            enable_cleanup_closed=True
        ) as connector:
            return await asyncio.gather(
                *[profile_domain(d, connector, semaphore, timeout)
                  for d in domains])
    finally:
        await resolver.close()


def main(
    domain_file, outfile, max_outstanding: int = 150,
    timeout: float = 30
):
    """Program entry point."""
    data = pd.read_csv(domain_file, squeeze=True, names=["rank", "domain"])

    _LOGGER.info("Profiling %d domains with timeout=%.2f and "
                 "max-outstanding=%d.", len(data), timeout, max_outstanding)

    start_time = time.perf_counter()
    result = asyncio.run(
        run_profiling(data["domain"], max_outstanding, timeout))
    duration = time.perf_counter() - start_time

    data.merge(pd.DataFrame(result), left_on="domain", right_on="domain",
               how="left").to_csv(
                   outfile, header=True, index=False, errors="backslashreplace"
               )

    _LOGGER.info("Profiling complete in %s (%.2fs)",
                 timedelta(seconds=duration), duration)


if __name__ == '__main__':
    try:
        KW_ARGS = {"domain_file": snakemake.input[0],
                   "outfile": snakemake.output[0]}
        common.init_logging(filename=next(iter(snakemake.log), None))
    except NameError:
        KW_ARGS = doceasy.doceasy(__doc__, {
            'DOMAIN_FILE': str,
            'OUTFILE': str,
            '--max-outstanding': And(Use(int), AtLeast(1)),
            '--timeout': And(Use(float), AtLeast(0.0)),
        })
        common.init_logging()

    main(**KW_ARGS)
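profile_domain bounds the number of outstanding HTTPS requests with an asyncio semaphore so that at most n_outstanding fetches run at once. A network-free sketch of that pattern, with sleeps standing in for the requests:

import asyncio

async def fetch(domain: str, sem: asyncio.Semaphore) -> str:
    async with sem:
        await asyncio.sleep(0.01)  # stand-in for the HTTPS request
        return domain

async def run_all(domains, n_outstanding=2):
    sem = asyncio.Semaphore(n_outstanding)
    return await asyncio.gather(*[fetch(d, sem) for d in domains])

print(asyncio.run(run_all(["example.com", "example.org", "example.net"])))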