Nanopore Read Clustering and Polishing Workflow


This workflow attempts to find clusters of similar sequences in a set of nanopore reads. It is intended for samples that have been size-selected for viral-like particles.

Overview

Clustering 1: windowed minimap2

The first pass uses minimap2 and mcl to find clusters of similar sequences.

Because an all-vs-all comparison of every read is computationally expensive, and because we're looking for clusters of complete sequences of roughly the same length, we split the reads into size buckets. The size windows overlap, so the per-window distance measures can be combined and MCL run once on all the reads.
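
As a rough illustration (not the workflow's actual code; the window and step sizes below are made-up values), the windowing might look like this:

# Sketch of overlapping size windows (hypothetical sizes): each read is
# assigned to every window that covers its length, so reads near a window
# boundary still get compared to each other.
from Bio import SeqIO

WINDOW = 5000   # window width (assumed)
STEP = 2500     # windows overlap by half (assumed)

def windows_for_length(length, window=WINDOW, step=STEP):
    """Return the start coordinates of all windows containing this length."""
    first = max(0, (length - window) // step + 1)
    last = length // step
    return [k * step for k in range(first, last + 1)]

buckets = {}
for rec in SeqIO.parse("all.reads.fasta", "fasta"):
    for start in windows_for_length(len(rec)):
        buckets.setdefault(start, []).append(rec.id)

# each bucket is then compared all-vs-all with minimap2; the distances are
# combined across windows and mcl is run once on the full read set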

Clustering 2: lastal

Clusters with enough reads are processed with lastal and mcl to find subclusters. An identity cutoff of 85% is used when building the MCL distance matrix.
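
The identity filter can be sketched like this (the column names below are an assumption, a blast-m8-like layout, and are not taken from the workflow's own scripts):

# Sketch: build an .abc edge list for mcl from a tabular lastal hit table,
# keeping only read pairs aligned at >= 85% identity. The column layout is
# assumed, not the workflow's exact format.
import pandas

MIN_PCTID = 85.0

hits = pandas.read_csv(
    "cluster.lastal.tsv", sep="\t", comment="#",
    names=["query", "hit", "pctid", "mlen", "mismatches", "gaps",
           "qstart", "qend", "hstart", "hend", "evalue", "score"])

edges = hits.query("query != hit and pctid >= @MIN_PCTID")[["query", "hit", "score"]]
edges.to_csv("cluster.abc", sep="\t", index=False, header=False)

# the edge list is then clustered separately, e.g.: mcl cluster.abc --abc -o cluster.mcl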

Polishing

Subclusters with enough reads and a tight size distribution are polished with racon and medaka to produce consensus sequences.
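
Here, "a tight size distribution" is judged by fitting a normal distribution to the subcluster's read lengths, roughly as below (the cutoff values are placeholders; the actual ones are workflow parameters):

# Sketch of the selection rule used before polishing: fit a normal
# distribution to the read lengths and keep the subcluster only if the
# spread (sigma) is small enough and there are enough reads.
import numpy
from scipy import stats

def keep_subcluster(read_lengths, sigma_cutoff=500, min_size=10):
    lengths = numpy.asarray(read_lengths, dtype=float)
    _mu, sigma = stats.norm.fit(lengths)   # mean and standard deviation
    return sigma <= sigma_cutoff and len(lengths) >= min_size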

Installation

The workflow is executed via snakemake, which can handle the installation of all required dependencies (with the help of conda).

NOTE: The commands in this section (Installation) are assumed to be run from the repo directory.

Conda

The simplest approach is to install conda if you don't have it, and then to use conda to install snakemake.

Mamba

Mamba is an add-on for conda that can install programs much faster than conda. We recommend installing it, too, but you don't have to. If you do have mamba installed, replace "conda" with "mamba" in the following command.

The snakemake env

Create a conda environment for running snakemake by using the provided configuration file:

$ conda env create -p ./env -f np_read_clustering/conda/snake.yaml

To use snakemake, you'll have to activate the environment in the shell (or script) from which you want to launch the workflow:

$ conda activate ./env

NOTE: The -p ./env option creates the environment in a folder named env in your current directory. You can use any name and location you wish. You can also use -n to name the environment and keep it in your conda installation location. See the conda documentation for more detail on naming and activating environments.

A Test Run

That's it. Now you are ready to go. Test your setup (and pre-install the rest of the dependencies):

$ snakemake --configfile=config.yaml -j 2 -p --use-conda --conda-frontend mamba

Note: this assumes you are running from the repo directory.

You can also run the workflow on the larger test file by overriding key config values:

$ snakemake --configfile=config.yaml -j 20 -p --use-conda --conda-frontend mamba \
 --config all_fasta=test/test.fasta work_dir=test/outputs/nprc

Note: we increased the thread count from 2 to 20 because this is a bigger dataset.

Running

Configuration

All of the configuration parameters are top level, so they can be supplied on the command line or passed in by file, as in the test example above.

Required Parameters

The only strictly necessary input is:

  • all_fasta: a fasta file of all the nanopore reads. This is assumed to be in {work_dir}/all.reads.fasta.

You may also want to specify:

  • work_dir: location to create all files (defaults to 'np_clustering'). This can be outside the repo.

  • name: naming prefix for the final sequences (defaults to 'SEQ')

  • pfam_hmm_path: the PFAM HMM file for gene annotation (pfam annotations are empty otherwise)

Other Parameters

See the example config.yaml for the rest of the parameters and their defaults.

See the provided example and the snakemake documentation for full information on snakemake configuration by file and command line, but the basic ideas are:

Command line config

Configuration options can be supplied on the command line with the --config option:

$ snakemake -s path/to/Snakefile --config name=my_name work_dir=my_Dir ...

Configfiles

Config values can be supplied by file, formatted as either JSON or YAML. Tell snakemake where to find the file with the --configfile= option:

$ snakemake -s /path/to/np_read_clustering/Snakefile --configfile=my.config.yaml

We suggest you copy the provided example into a new file and modify the values accordingly.
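
For example, a minimal config could be written by hand or generated with a short Python snippet like the following (all paths and values here are placeholders):

# Sketch: write a minimal config file for this workflow; the values are
# placeholders, see the provided config.yaml for the full set of options.
import yaml  # PyYAML

my_config = {
    "all_fasta": "/path/to/reads.fasta",
    "work_dir": "/path/to/output",
    "name": "SEQ",
    # "pfam_hmm_path": "/dbs/Pfam-A.hmm",   # optional PFAM annotation
}
with open("my.config.yaml", "w") as handle:
    yaml.safe_dump(my_config, handle)

Then launch with: $ snakemake --configfile=my.config.yaml -j 40 -p --use-conda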

Parallelization and performance

Single node

The -j flag tells snakemake how many threads are available on your computer, and it will run workflow steps in parallel as much as possible (usually).

Multithreaded steps

Some steps in the workflow are multithreaded (e.g., minimap2). You can configure how many threads these steps get in the configuration (e.g., mapping_threads).
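
In the Snakefile, a multithreaded rule might read its thread count from the config roughly as follows (a hypothetical rule sketch, not copied from the workflow):

# Hypothetical Snakemake rule showing how a per-step thread count can come
# from the config; 'mapping_threads' falls back to 4 here as an example.
rule minimap_window:
    input: "windows/{window}.fasta"
    output: "windows/{window}.paf"
    threads: int(config.get("mapping_threads", 4))
    shell: "minimap2 -x ava-ont -t {threads} {input} {input} > {output}"

When snakemake is run with -j, it schedules rules so that the total number of threads in use stays under the -j limit.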

Examples

Note: we don't bother to use the --conda-frontend flag here, assuming the conda environments have already been created during the test run above. Mamba is only needed when creating environments.

Running with a custom config:

$ snakemake --configfile=my.config.yaml -j 40 -p --use-conda

Or use the provided test config and override key values:

$ snakemake --configfile=config.yaml -j 40 -p --use-conda \
 --config all_fasta=/path/to/reads.fasta work_dir=/path/to/output pfam_hmm_path=/dbs/PFAM.hmm

Code Snippets

import os
from hit_tables import agg_hit_table

hit_table = str(snakemake.input)
fmt = str(snakemake.params.format).upper()

if os.path.getsize(hit_table) > 0:
    agg_hit_table(hit_table, format=fmt) \
        .to_csv(str(snakemake.output), sep='\t', index=None)
else:
    # input is empty, touch the output
    with open(str(snakemake.output), 'wt') as out_handle:
        pass
import pandas, numpy, os
from collections import deque
from itertools import cycle
from scipy import stats
from Bio import SeqIO

# load the read lengths from the summary file
read_lens = pandas.read_csv(snakemake.input.read_lens,
                            sep='\t',
                            names=['read_id','sequence_length_template'],
                            index_col='read_id',
                            header=None).sequence_length_template.to_dict()


# process clusters to choose keepers
cluster_data = []
read_clusters = {}
sigma_cutoff = snakemake.params.sigma_cutoff
count_cutoff = snakemake.params.min_cl_size

# loop over clusters in mcl_file
with open(str(snakemake.input.mcl)) as mcl_lines:
    for i,line in enumerate(mcl_lines):

        # get cluster read names
        reads = set(line.strip().split())
        count = len(reads)

        # get cluster read length dist
        cluster_lens = numpy.array([read_lens[r] for r in reads])
        counts, bins = numpy.histogram(cluster_lens, bins=100)
        X = numpy.array([numpy.mean((bins[j], bins[j-1])) for j in range(1,len(bins))])
        mu, sigma = stats.norm.fit(cluster_lens)

        keep = (sigma <= sigma_cutoff and count >= count_cutoff)
        cluster_data.append(dict(num=i, count=count, sigma=sigma, mu=mu,
                                 keep=keep))

        if keep:
            """
            # write read list
            if not os.path.exists(str(snakemake.output.reads)):
                os.makedirs(str(snakemake.output.reads), exist_ok=True)
            with open(f"{output.reads}/cluster.{i}.reads", 'wt') as reads_out:
                reads_out.write("\n".join(reads) + "\n")
            """
            # save cluster id
            for read in reads:
                read_clusters[read] = i

cluster_table = pandas.DataFrame(cluster_data)

## assign groups
# this serves 2 purposes:
#  1) we limit the number of files in each folder (too many files can slow
#     down snakemake)
#  2) we enable running the workflow in chunks (can perform better in some
#     cases)

keepers = cluster_table.query('keep')
num_keepers = keepers.shape[0]

# we want the number of groups, but we can get it from group_size
if 'group_size' in snakemake.config and 'num_groups' not in snakemake.config:
    group_size = snakemake.config['group_size']
    n_groups = int(numpy.ceil(num_keepers/group_size))
else:
    n_groups = snakemake.config.get('num_groups', 100)

# assign a group to each cluster (round-robin)
groups = cycle(range(n_groups))
cluster_groups = {c:next(groups) for c in keepers['num']}
cluster_table['group'] = [cluster_groups.get(c,None) if k else None
                          for c,k in cluster_table[['num','keep']].values]

# write fasta files
if not os.path.exists(str(snakemake.output.reads)):
    os.makedirs(str(snakemake.output.reads), exist_ok=True)

# limit number of open files with
n_open = 250
open_handle_ids = deque([])
handles = {}
def open_cluster_fasta(i):
    """
    checks for an open handle for this cluster and returns it if found

    otherwise closes the oldest handle and replaces it with a new handle for this cluster
    """
    # return open handle if it exists
    try:
        return handles[i]
    except KeyError:
        pass

    # close handle(s) if we have too many
    while len(handles) > n_open - 1:
        # drop oldest
        drop_id = open_handle_ids.popleft()

        # close and delete
        handles[drop_id].close()
        del handles[drop_id]

    group = cluster_groups[i]
    fasta_file = f"{snakemake.output.reads}/group.{group}/cluster.{i}.fasta"
    fd = os.path.dirname(fasta_file)
    if not os.path.exists(fd):
        os.makedirs(fd)
    handle = open(fasta_file, 'at')
    handles[i] = handle
    open_handle_ids.append(i)
    return handle

# loop over all reads and write out
skipped_read_count = 0
for read in SeqIO.parse(snakemake.input.fasta, 'fasta'):
    try:
        cluster = read_clusters[read.id]
    except KeyError:
        # skip if no cluster
        skipped_read_count += 1
        continue

    open_cluster_fasta(cluster).write(read.format('fasta'))

# add row for unclustered reads (cluster id -1)
for k, v in dict(num=-1, count=skipped_read_count, keep=False).items():
    cluster_table.loc[-1, k] = v

# save cluster table
cluster_table.to_csv(str(snakemake.output.stats), sep='\t', index=False)
import numpy
import os
import pandas
import re

from collections import Counter, defaultdict
from functools import partial
from itertools import zip_longest

from scipy import stats
import matplotlib
matplotlib.use('pdf')
from matplotlib import pyplot as plt, cm, colors
from matplotlib.patches import Polygon
from matplotlib.backends.backend_pdf import PdfPages
from snakemake.rules import Namedlist
from Bio import SeqIO

from hit_tables import parse_blast_m8, BLAST_PLUS
from edl import blastm8

BLACK = (0, 0, 0, 1)

def grouper_trim(iterable, n):
    "Collect data into fixed-length chunks or blocks and trim fill values from the last chunk"
    # grouper_trim('ABCDEFG', 3) --> ABC DEF G
    args = [iter(iterable)] * n
    return ([i for i in group if i is not None] for group in zip_longest(*args, fillvalue=None))

def main(input, output, params):
    """
    input output should be namedlists (from snakemake)
    params should be a dict (so we can fall back to defaults)
    """

    # prodigal amino acid output 
    faa_file = input.faa
    # cluster fasta reads
    fasta_file = input.fasta
    # PFAM results (non-overlapping)
    dom_tbl_U = input.domtbl
    # mcl results
    mcl_file = input.mcl
    # lastal table raw
    lastal_file = input.lastal
    # lastal table aggregated
    agg_file = input.agg

    ## load clusters (just a list of reads in each cluster, sorted by size)
    subclusters = load_clusters(mcl_file, params.get('min_sub_size', 10))

    ## load the fasta, keeping dict of lengths
    cluster_reads = {r.id:r for r in SeqIO.parse(fasta_file, 'fasta')}
    # use just the first word in the read id as a short name
    read_lens = {r.id.split('-')[0]:len(r) for r in cluster_reads.values()}

    ## plot all cluster hists, applying sigma cutoff
    subcluster_ids = plot_cluster_hists(subclusters, read_lens, agg_file, output.stats, output.hist_pdf, params)

    ## make the synteny plots for each good subcluster
    plot_subcluster_synteny(subcluster_ids, subclusters, read_lens, lastal_file, faa_file, dom_tbl_U, output.gene_pdf, params)

    ## write out sub cluster fasta
    os.makedirs(str(output.fasta_dir), exist_ok=True)
    for subcluster_id in subcluster_ids:
        with open(str(output.fasta_dir) + f"/subcluster.{subcluster_id}.fasta", 'wt') as fasta_out:
            for read_id in subclusters[subcluster_id]:
                fasta_out.write(cluster_reads[read_id].format('fasta'))

def plot_cluster_hists(subclusters,
                       read_lens,
                       agg_file,
                       stats_file,
                       pdf_file,
                       params
                      ):
    """
    For each subcluster plot:
     * histogram of all read-read mfracs
     * histogram of all read lengths with overlaid normal dist
    """

    # open PDF file
    pdf = PdfPages(pdf_file)

    mx_len = max(read_lens.values())
    mn_len = min(read_lens.values())
    window = [mn_len, mx_len]

    sigma_cutoff = params.get('sigma_cutoff', -1)

    # first pass to choose subclusters to keep and plot
    cluster_stats = {}
    for i, subcluster in enumerate(subclusters):
        keep = True
        if len(subcluster) < params.get('min_sub_size', 10):
            break

        # calculate best normal fit to length dist
        cluster_lens = numpy.array([read_lens[r.split('-')[0]] for r in subcluster])
        counts, bins = numpy.histogram(cluster_lens, bins=100, range=window)
        #from scipy import stats
        mu, sigma = stats.norm.fit(cluster_lens)

        if sigma_cutoff > 0 and sigma > sigma_cutoff:
            keep = False

        # calculate the stats
        X = numpy.array([numpy.mean((bins[i], bins[i-1])) for i in range(1,len(bins))])
        tot_in, tot_out, n_in, n_out = numpy.zeros(4)
        for x, count in zip(X, counts):
            if x < mu - sigma or x > mu + sigma:
                tot_out += count
                n_out += 1
            else:
                tot_in += count
                n_in += 1
        mean_in = tot_in / n_in
        mean_out = tot_out / n_out if n_out > 0 else 0
        ratio = mean_in / mean_out
        n_ratio = n_in / (n_out + n_in)

        cluster_stats[i] = dict(zip(
            ['mu', 'sigma', 'ratio', 'n_ratio', 'N', 'keep', 'counts', 'bins', 'X'],
            [mu, sigma, ratio, n_ratio, len(subcluster), keep, counts, bins, X]
        ))

    # build cluster stats table
    write_cols = ['mu', 'sigma', 'ratio', 'n_ratio', 'N', 'keep']
    cl_st_table = pandas.DataFrame([[i,] + [d[k] for k in write_cols] 
                                    for i,d in cluster_stats.items()],
                                   columns=['index'] + write_cols)
    # write stats to file
    cl_st_table.to_csv(stats_file, sep='\t', index=None)

    # pull out list of good subclusters
    subcluster_ids = list(cl_st_table.query('keep').index)

    # load agg hits
    agg_table = pandas.read_csv(agg_file, sep='\t')

    # max 8 per page
    mx_rows = 8
    for page_sc_ids in grouper_trim(cluster_stats.keys(), mx_rows):
        N = len(page_sc_ids)
        fig, axes = plt.subplots(N, 4, figsize=[11 * N / mx_rows, 8.5], sharey="col", sharex="col", squeeze=False)
        fig.subplots_adjust(hspace=.7, wspace=.6)

        ax_rows = iter(axes)
        for i, subcluster_id in enumerate(page_sc_ids):

            axs = next(ax_rows)

            # remove axes from top and right
            for ax in axs:
                for side in ['top', 'right']:
                    ax.spines[side].set_visible(False)

            ax_sc_mf, ax_sc_id, ax_h_mf, ax_h_ln = axs

            # get the subset of the agg table for this subcluster
            subcluster = set(subclusters[subcluster_id])
            sub_slice = (agg_table['query'].apply(lambda q: q in subcluster)
                         & agg_table.hit.apply(lambda h: h in subcluster))
            agg_hits_cluster = agg_table[sub_slice] \
                .eval('mean_len = (hlen + qlen) / 2') \
                .eval('frac = mlen / mean_len')
            mfrac_dict = agg_hits_cluster.set_index(['query','hit']).mfrac.to_dict()

            # scatter plot mfrac and mean length
            ax_sc_mf.scatter(agg_hits_cluster.mfrac.values,
                             agg_hits_cluster.mean_len.values,
                             marker='.',
                             alpha=.5
                            )
            ax_sc_mf.set_ylabel ('mean_len')

            # scatter plot of pctid and matched fraction
            ax_sc_id.scatter(agg_hits_cluster.pctid.values,
                             agg_hits_cluster.frac.values,
                             marker='.',
                             alpha=.5
                            )
            ax_sc_id.set_ylabel ('frac aln')


            # plot hist of pairwise mfracs
            h = ax_h_mf.hist(get_mfracs(subcluster, mfrac_dict=mfrac_dict), bins=100, range=[50,100])

            # plot hist of read lens
            sc_stats = cluster_stats[subcluster_id]
            counts = sc_stats['counts']
            X = sc_stats['X']

            # recreate histogram from counts and X
            ax_h_ln.bar(X, counts, color='blue')

            # overlay norm dist
            best_fit_line = stats.norm.pdf(X, sc_stats['mu'], sc_stats['sigma'])
            best_fit_line = best_fit_line * counts.sum() / best_fit_line.sum()
            p = ax_h_ln.plot(X, best_fit_line, color='red', alpha=.5)


            ax_h_mf.set_ylabel(f"s.cl: {subcluster_id}")
            ax_h_ln.set_ylabel(f"{len(subcluster)} {int(sc_stats['sigma'])}")

            if i == N - 1:
                xl = ax_sc_mf.set_xlabel("score")
                xl = ax_h_ln.set_xlabel("length")
                xl = ax_sc_id.set_xlabel ('match %ID')
                xl = ax_h_mf.set_xlabel ('score')

        # close plot and go to next pdf page
        pdf.savefig(bbox_inches='tight')
        plt.close()

    pdf.close()

    # save stats to file, but drop extra data first
    write_cols = ['mu', 'sigma', 'ratio', 'n_ratio', 'N']
    pandas.DataFrame([[i,] + [d[k] for k in write_cols] 
                      for i,d in cluster_stats.items()],
                     columns=['index'] + write_cols).to_csv(stats_file, sep='\t', index=None)

    return subcluster_ids

def get_N_colors(N, cmap_name='Dark2'):
    """ given N and a colormap, get N evenly spaced colors"""
    color_map=plt.get_cmap(cmap_name)
    return [color_map(c) for c in numpy.linspace(0, 1, N)]

def get_scaled_color(value, minv=0, maxv=1, alpha=.75, reverse=False, cmap_name='cool'):
    colormap = plt.get_cmap(cmap_name)
    if reverse:
        maxv, minv = minv, maxv
    rangev = maxv - minv
    color = colormap((value - minv) / rangev)
    return color[:3] + (alpha,)

def get_mfracs(reads, mfrac_dict):
    return [mfrac_dict.get((r1, r2), 0)
            for r1 in reads
            for r2 in reads
            if r2 > r1
           ]

def plot_subcluster_synteny(subcluster_ids,
                            subclusters,
                            read_lens,
                            lastal_file,
                            faa_file,
                            dom_tbl_U,
                            pdf_file,
                            params
                           ):
    """
    For each subcluster:
     * identify the N genes that appear in the most reads
     * identify the M reads that have the most of the top genes
     * plot
    """

    ## load the gene annotations
    # first get positions from faa headers
    read_genes = {}
    for gene in SeqIO.parse(faa_file, 'fasta'):
        gene_id, start, end, strand, _ = [b.strip() for b in gene.description.split("#")]
        read, name, gene_no = re.search(r'^((\w+)-[^_]+)_(\d+)', gene_id).groups()

        read_genes.setdefault(name, []).append(dict(
            gene_id=gene_id,
            start=int(start),
            end=int(end),
            strand=int(strand),
            num=int(gene_no),
            pfam=None,
        ))

    # convert to dict of DataFrames from dict of lists of dicts
    read_genes_tables = {read:pandas.DataFrame(genes).set_index('gene_id')
                         for read, genes in read_genes.items()}

    # and add PFAM annotations
    for read, hits in blastm8.generate_hits(dom_tbl_U, format='hmmsearchdom'):
        read_id = read.split("-")[0]
        read_genes_table = read_genes_tables[read_id]

        for hit in hits:
            gene_id = hit.read

            # only assign PFAM if it's the first hit for the gene
            if pandas.isna(read_genes_table.loc[gene_id, 'pfam']):
                pfam = hit.hit
                read_genes_table.loc[gene_id, 'pfam'] = pfam

    # load all the read to read hits
    read_hits = parse_blast_m8(lastal_file, format=BLAST_PLUS)

    # now open the PDF file
    pdf = PdfPages(pdf_file)


    # for each good subcluster
    for subcluster_id in subcluster_ids:
        subcluster = set(subclusters[subcluster_id])
        subcluster_names = {r.split('-')[0]:r for r in subcluster}

        fig = plot_subcluster_genes(subcluster_id, subcluster_names, read_genes_tables, read_hits, read_lens, params)

        # close plot and go to next pdf page
        pdf.savefig(bbox_inches='tight')
        plt.close()

    pdf.close()

def plot_subcluster_genes(subcluster_id, subcluster_names, read_genes_tables, read_hits, read_lens, params):
    """
    make a plot of gene positions:
     ax1 has a scatter plot of mean position by pfam
     ax2 has aligned genomes with top pfams colored
    """

    # get the positions of the named PFAMs
    pf_positions = defaultdict(list)
    for read, gene_table in read_genes_tables.items():
        if read in subcluster_names:
            # do we want to flip the read dir? (too many strand < 1)
            reverse = gene_table.eval('glen = strand * (end - start)').glen.sum() < 1
            for start, end, pfam in gene_table[['start','end','pfam']].values:
                if pandas.isna(pfam):
                    continue
                if reverse:
                    start, end = [read_lens[read] - p for p in (start, end)]
                # add mean position to the list for this pfam
                pf_positions[pfam].append((end + start) / 2)

    # choose which genes to color
    N = params.get('max_colored_genes', 8)
    sorted_genes = sorted(pf_positions.keys(), key=lambda k: len(pf_positions[k]), reverse=True)
    top_N_pfams = sorted_genes[:N]
    gene_color_dict = dict(zip(top_N_pfams, get_N_colors(N, cmap_name=params.get('gene_cmap', 'Dark2'))))


    # choose which reads to draw
    M = params.get('max_synteny_reads', 20)
    def count_top_pfams_in_read(read):
        if read in read_genes_tables:
            return sum(1 for p in read_genes_tables[read].pfam.values
                       if p in top_N_pfams)
        return 0
    top_M_reads = sorted(subcluster_names, 
                         key=count_top_pfams_in_read,
                         reverse=True,
                        )[:M]
    m = len(top_M_reads)

    # calculate the sizes necessary to draw genes using the matplotlib arrow function
    align_height = (7 * (m-1) / (M-1)) #use up to 7 in
    figsize = [8.5, 4 + align_height]

    fig, axes = plt.subplots(2,1, figsize=figsize, gridspec_kw={'height_ratios':[4,align_height]}, sharex='col')
    fig.subplots_adjust(hspace=.1,)

    ## draw gene positions
    ax = axes[0]

    ax.set_title(f'PFAM annotations in subcluster {subcluster_id}')

    n = params.get('max_plotted_genes', 18)
    sorted_pf = sorted([p for p in sorted_genes[:n] if len(pf_positions[p]) > 1], 
                       key=lambda p: numpy.mean(list(pf_positions[p])))
    for i, p in enumerate(sorted_pf):
        x,y = zip(*((gp,i) for gp in pf_positions[p]))
        ax.scatter(x,y, 
                   c=len(y) * [gene_color_dict.get(p, BLACK)], 
                   ec=None, alpha=.5)
    yt = ax.set_yticks(range(len(sorted_pf)))
    ytl = ax.set_yticklabels(sorted_pf)
    for label in ytl:
        label.set_color(gene_color_dict.get(label.get_text(), BLACK))

    ## draw alignments
    ax = axes[-1]
    min_x = 0
    max_x = max(read_lens[r] for r in subcluster_names)
    range_x = max_x - min_x
    range_y = M
    thickness = .5
    head_length = range_x * (thickness / range_y) * (figsize[1] / figsize[0])

    cmap = params.get('read_cmap','cool')

    min_pctid = read_hits.pctid.min()
    pctid_range = 100 - min_pctid

    get_conn_color = partial(get_scaled_color, minv=min_pctid, maxv=100, alpha=.75, cmap_name=cmap)

    y = 0
    pad = .1
    prev_read = None
    for name in top_M_reads:
        read = subcluster_names[name]
        read_length = read_lens[name]
        if name in read_genes_tables:

            gene_table = read_genes_tables[name]

            # do we want to flip the read dir? (too many strand < 1)
            reverse = gene_table.eval('glen = strand * (end - start)').glen.sum() < 1

            # draw genes
            for start, end, strand, pfam in gene_table[['start','end','strand','pfam']].values:
                if reverse:
                    strand = -1 * strand
                    start = read_length - start
                    end = read_length - end

                strand = int(strand)
                hl = min(head_length, end-start)
                al = max((end - start) - hl, .0001) * strand
                ast = start if al > 0 else end
                color = gene_color_dict.get(pfam, 'k')
                plt.arrow(ast, y, al, 0, fc=color, ec=color, 
                          lw=0,
                          width=thickness, head_width=thickness, 
                          head_length=hl, 
                          head_starts_at_zero=(int(strand) > 0))
        else:
            reverse=False

        # connect matched segments for read pairs
        if prev_read is not None:
            # get hits between reads
            pair_hits = read_hits.query(f'(hit == "{read}" and query == "{prev_read}") or '
                                        f'(query == "{read}" and hit == "{prev_read}")') \
                                 .query('hit != query') \
                                 .sort_values('score', ascending=True)
            # loop over hits
            cols = ['query', 'hit', 'qstart', 'qend', 'hstart', 'hend', 'pctid']
            for query, hit, qstart, qend, hstart, hend, pctid in pair_hits[cols].values:
                # if hit was recorded the other way, flip hit/query
                if query == prev_read:
                    qstart, qend, hstart, hend = hstart, hend, qstart, qend

                # if either read is reversed, flip x coordinates 
                if reverse:
                    qstart = read_length - qstart
                    qend = read_length - qend
                if prev_rev:
                    hstart = prev_len - hstart
                    hend = prev_len - hend

                # draw connecting parallelogram
                color = get_conn_color(pctid, alpha=.9)
                xy = numpy.array([(hstart, y-1+pad),
                                  (qstart, y-pad),
                                  (qend, y-pad),
                                  (hend, y-1+pad)])
                ax.add_patch(Polygon(xy, fc=(.6,.6,.6,.2), ec=color))   

        # save read info for next one
        prev_read = read
        prev_rev = reverse
        prev_len = read_length

        # increment y value
        y += 1

    x = plt.xlim(min_x - 50, max_x + 50)
    y = plt.ylim(-.5, y - .5)

    plt.yticks(list(range(m)), top_M_reads)
    plt.xlabel('read position')

    cax = plt.axes([0.95, 0.15, 0.025, 0.4 * (align_height / 7)])
    plt.colorbar(mappable=cm.ScalarMappable(norm=colors.Normalize(min_pctid, 100), cmap=cmap), cax=cax)    
    cl = cax.set_ylabel('alignment %ID')

    return fig

def load_clusters(mcl_file, size_cutoff=10):
    with open(mcl_file) as mcl_lines:
        return [c for c in [line.strip().split() for line in mcl_lines] if len(c) >= size_cutoff]

# scriptify
if __name__ == "__main__":
    try:
        # assume this is called from snakemake
        input = snakemake.input
        output = snakemake.output
        params = dict(snakemake.params.items())
    except NameError:
        # TODO: fallback to argparse if we call from the command line (for testing)
        import argparse
        raise Exception("Currently only works from snakemake, sorry")

    main(input, output, params)
import logging
import os, re, pandas
from Bio import SeqIO

# set up a basic logger for the debug messages below
logger = logging.getLogger(__name__)

def main(input, output, params):
    stats = {}
    files = {}
    groups = {}

    logger.debug("Collecting polished sequences")

    # get fasta files and gene stats
    for sc_comp_file in input:
        sc_dir = os.path.dirname(sc_comp_file)
        group, cluster, subcluster = \
            re.search(r'group.(\d+).+cluster\.(\d+).+subcluster\.(\d+)', sc_dir).groups()
        cl_sc_id = f"{cluster}_{subcluster}"
        files[cl_sc_id] = dict(
            fasta=sc_dir + "/medaka.fasta",
            faa  =sc_dir + "/medaka.faa"
        )
        gene_stats = \
            pandas.read_csv(sc_dir + "/medaka.v.drafts.gene.lengths",
                            sep='\t',
                            index_col=0) \
                  .loc['medaka']
        stats.setdefault(cluster, {})[subcluster] = \
            {f"gene_{k}": v for k,v in gene_stats.items()}
        groups[cluster] = group

    logger.debug("Found {len(files)} poilished seqs from {len(stats)} clusters")

    # get read len stats for subclusters
    for cluster in stats:
        sc_tsv = (f"{params.work_dir}/refine_lastal/group.{groups[cluster]}"
                  f"/cluster.{cluster}/subclusters/cluster_stats.tsv")

        sc_stats = pandas.read_csv(sc_tsv, sep='\t', index_col=0)
        for sc, row in sc_stats.iterrows():
            sc = str(sc)
            if sc not in stats[cluster]:
                continue
            stats[cluster][sc].update(dict(
                read_len_mean=row['mu'],
                read_len_dev=row['sigma'],
                read_count=row['N']))

    # convert stats to table
    df = pandas.DataFrame({f"{cl}_{sc}": sc_stats
                        for cl, cl_stats in stats.items()
                        for sc, sc_stats in cl_stats.items()},) \
            .T

    # polished fasta
    with open(str(output.fasta), 'wt') as fasta_out:
        for cl_sc_id in files:
            for read in SeqIO.parse(files[cl_sc_id]['fasta'], 'fasta'):
                np_read = read.id
                read.id = f"{params.name}_{cl_sc_id}"
                df.loc[cl_sc_id, 'length'] = len(read)
                N = df.loc[cl_sc_id, 'read_count']
                read.description = f"n_reads={N};rep_read={np_read}"
                fasta_out.write(read.format('fasta'))

    # write out stats table
    df.to_csv(str(output.stats), sep='\t')

    # polished gene predictions: rename each gene to match the renamed sequence
    with open(str(output.faa), 'wt') as faa_out:
        for cl_sc_id in files:
            N = 0
            for gene in SeqIO.parse(files[cl_sc_id]['faa'], 'fasta'):
                N += 1
                gene.id = f"{params.name}_{cl_sc_id}_{N}"
                faa_out.write(gene.format('fasta'))

if __name__ == "__main__":
    main(snakemake.input, snakemake.output, snakemake.params)
import pandas, numpy

def main(input, output, params):
    with open(str(output.report), 'wt') as output_handle:
        cluster_stats = pandas.read_csv(str(input.mcl_stats), sep='\t',
                                        index_col=0)
        n_reads = cluster_stats['count'].sum()
        n_clusters = cluster_stats.shape[0]
        n_gt_size_cutoff = cluster_stats.query(f'count >= {params.min_cl_size}').shape[0]
        n_kept = cluster_stats.query('keep').shape[0]

        output_handle.write(f"Cluster Search Results:\n"
                            f"  minimap2 clusters:\n"
                            f"    reads: {n_reads}\n"
                            f"    clusters: {n_clusters}\n"
                            f"    gt_{params.min_cl_size}: {n_gt_size_cutoff}\n"
                            f"    kept_clusters: {n_kept}\n\n")

        # count raw subclusters in mcl files
        n_scs, n_sc_gt_cutoff = 0, 0
        for sc_mcl_file in input.sc_mcls:
            with open(sc_mcl_file) as mcl_lines:
                for line in mcl_lines:
                    n_scs += 1
                    if len(line.strip().split()) > params.min_cl_size:
                        n_sc_gt_cutoff += 1

        # get stats from polished subclusters
        pol_stats = pandas.read_csv(str(input.pol_stats), sep='\t', index_col=0)
        pol_lens = pol_stats[pol_stats['length'].notna()]['length'].values
        n_sc_kept = len(pol_lens)

        output_handle.write(f"  lastal subclusters:\n"
                            f"    subclusters: {n_scs}\n"
                            f"    gt_{params.min_cl_size}: {n_sc_gt_cutoff}\n"
                            f"    kept_subclusters: {n_sc_kept}\n\n")

        output_handle.write(f"  polished seqs:\n"
                            f"    count: {len(pol_lens)}\n"
                            f"    mean: {pol_lens.mean()}\n"
                            f"    max: {pol_lens.max()}\n"
                            f"    min: {pol_lens.min()}\n"
                            f"    median: {numpy.median(pol_lens)}\n"
                            f"    stddev: {pol_lens.std()}\n\n")


if __name__ == "__main__":
    main(snakemake.input, snakemake.output, snakemake.params)
import os
from hit_tables import agg_hit_table

hit_table = str(snakemake.input)
fmt = str(snakemake.params.format).upper()

# identify sequences that hit longer sequences at > 95%
subseqs = set()
if os.path.getsize(hit_table) > 0:
    aggs = agg_hit_table(hit_table, format=fmt)
    count = 0
    values = aggs.query('hit != query')[['query', 'hit', 'matches', 'qlen', 'hlen']].values
    for query, hit, matches, qlen, hlen in values:
        if hlen < qlen:
            query, hit, qlen, hlen = hit, query, hlen, qlen
        mfracq = matches / qlen
        if mfracq > .95:
            if query.split('_')[1] != hit.split('_')[1]:
                subseqs.add(query)
                count += 1

# write fragment IDs to list
with open(str(snakemake.output), 'wt') as out_handle:
    for seq_id in subseqs:
        out_handle.write(f"{seq_id}\n")
import pandas
from hit_tables import parse_blast_m8, PAF
from Bio import SeqIO

# pick a best read
hits = parse_blast_m8(str(snakemake.input.paf),format=PAF)
hit_matches = hits.groupby(['hit','query']).agg({'matches':sum})
mean_matches = {r:hit_matches.query(f'hit != query and (hit == "{r}" or query == "{r}")').matches.mean() 
                for r in set(i[0] for i in hit_matches.index).union(i[1] for i in hit_matches.index)}
best_matches = sorted(mean_matches.keys(), key=lambda r: mean_matches[r], reverse=True)
ref_read = best_matches[0]

# write out to 2 files
with open(str(snakemake.output.ref), 'wt') as ref_out:
    with open(str(snakemake.output.others), 'wt') as others_out:
        for read in SeqIO.parse(str(snakemake.input.fasta), 'fasta'):
            if read.id == ref_read:
                ref_out.write(read.format('fasta'))
            else:
                others_out.write(read.format('fasta'))
import pandas, numpy, os
from scipy import stats
import matplotlib
matplotlib.use('pdf')
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

sigma_cutoff = snakemake.params.sigma_cutoff
count_cutoff = snakemake.params.min_cl_size

# load the read lengths from the summary file
read_lens = pandas.read_csv(snakemake.input.read_lens,
                            sep='\t', 
                            names=['read_id','sequence_length_template'], 
                            index_col='read_id', 
                            header=None).sequence_length_template.to_dict()

# load the all.v.all mfrac values (but just map pairs where q > h)
mfrac_dict = {tuple(sorted(i)):m
              for i,m in pandas.read_csv(snakemake.input.abc, sep='\t',
                                         names=['q', 'h', 'mfrac'],
                                         header=None, index_col=['q','h']) \
                               .mfrac.items()}

def get_mfracs(reads, mfrac_dict=mfrac_dict):
    return [mfrac_dict.get((r1, r2), 0)
            for r1 in reads
            for r2 in reads
            if r2 > r1
           ]

# load the clusters
with open(snakemake.input.mcl) as mcl_lines:
    all_clusters = [line.strip().split() for line in mcl_lines]

# plots
pdf=PdfPages(snakemake.output.pdf)

ROWS = 20
COLS = 5
N0=0

while True:
    clusters = all_clusters[N0:N0+ROWS*COLS]
    if len(clusters) == 0:
        break
    rows = int(numpy.ceil(len(clusters)/COLS))

    fig, axes = plt.subplots(rows, COLS*2, figsize=[COLS*4,rows], sharex=False,
                            squeeze=False)
    fig.subplots_adjust(hspace=.7, wspace=.6)

    axes_list = axes.flatten()
    for i, cluster in enumerate(clusters):
        j = i*2
        ax1, ax2 = axes_list[j:j+2]

        # plot hist of pairwise mfracs
        h = ax1.hist(get_mfracs(cluster, mfrac_dict=mfrac_dict), bins=100,
                     range=[0,100])

        # plot hist of read lens
        cluster_lens = numpy.array([read_lens[r] for r in cluster])
        counts, bins, h_line = ax2.hist(cluster_lens, bins=100, histtype='step')
        X = numpy.array([numpy.mean((bins[j], bins[j-1])) for j in range(1,len(bins))])
        mu, sigma = stats.norm.fit(cluster_lens)

        # overlay norm dist
        best_fit_line = stats.norm.pdf(X, mu, sigma)
        best_fit_line = best_fit_line * counts.sum() / best_fit_line.sum()
        p = ax2.plot(X, best_fit_line, color='red', alpha=.5)

        keep = (sigma <= sigma_cutoff and len(cluster) >= count_cutoff)
        if keep:
            ax1.set_ylabel('keep')

        ax2.set_ylabel(f"c{i} n={len(cluster)}")
        # only put xlabels on bottom plots
        if i >= len(clusters) - COLS:
            xl = ax1.set_xlabel("score")
            xl = ax2.set_xlabel("length")

        # remove axes from top and right
        for ax in [ax1, ax2]:
            for side in ['top', 'right']:
                ax.spines[side].set_visible(False)

    # hide unused axes (on last page)
    for i in range(j+2, len(axes_list)):
        fig.delaxes(axes_list[i])

    pdf.savefig(bbox_inches='tight')
    plt.close()

    N0 += ROWS * COLS
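    # mcl writes clusters sorted by size, so once the last cluster on this
    # page falls below the size cutoff there is nothing more worth plotting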
    if len(cluster) < count_cutoff:
        break
pdf.close()
shell: 'rm {input}'
From line 213 of main/Snakefile

URL: https://github.com/jmeppley/np_read_clustering
Name: np_read_clustering
Version: 1
Copyright: Public Domain
License: MIT License