Workflow for detecting PMCs in confocal images


This repository performs automatic segmentation and detection of primary mesenchyme cells (PMCs) in 3D confocal images. The workflow uses Snakemake and assumes the existence of a trained ilastik model for pixel classification (the model is passed to ilastik's headless mode in the Snakefile rules below).
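For orientation, a minimal sketch of driving the workflow from Python (assumptions: a pre-8 Snakemake release, which exposes the snakemake() function; the repository's Snakefile in the working directory; the flags are illustrative, not taken from the repository docs):

from snakemake import snakemake

# Run the full workflow; targets, inputs, and parameters come from the Snakefile.
snakemake("Snakefile", cores=4, use_conda=True)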

Code Snippets

import pandas as pd

if __name__ == '__main__':
    # `snakemake` is injected by Snakemake's script directive; fall back to
    # None when the script is run outside the workflow.
    try:
        snakemake
    except NameError:
        snakemake = None
    if snakemake is not None:
        # concatenate per-embryo count tables into one CSV
        embryo_counts = pd.concat([pd.read_csv(x) for x in snakemake.input['counts']])
        embryo_counts.to_csv(snakemake.output['final'])
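Outside the workflow, the same aggregation can be sketched with a glob over hypothetical per-embryo CSVs (the file paths are assumptions, not taken from the repository):

import glob

import pandas as pd

# Hypothetical layout: one counts table per embryo.
counts_files = sorted(glob.glob("counts/*_counts.csv"))
combined = pd.concat([pd.read_csv(f) for f in counts_files])
combined.to_csv("all_embryo_counts.csv")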
import json
import os
import sys
from collections import namedtuple
import logging
import itertools

import numpy as np
import xarray as xr
import skimage
from skimage import exposure, io, morphology, measure
from skimage.filters import threshold_otsu

import bigfish.detection as bf_detection
import bigfish.stack as bf_stack
import bigfish.plot as bf_plot

BoundingBox = namedtuple("BoundingBox", ["ymin", "ymax", "xmin", "xmax"])
try:
    sys.path.append(os.path.dirname(__file__))
except NameError:
    pass
import utils


def normalize_zstack(z_stack, bits):
    """Normalize z-slices in a 3D image using contrast stretching

    Parameters
    ----------
    z_stack : numpy.ndarray
        3 dimensional confocal FISH image.
    bits : int
        Bit depth of image.

    Returns
    -------
    numpy.ndarray
        Z-corrected image with each slice minimum and maximum matched
    """
    out = np.array(
        [
            exposure.rescale_intensity(
                x, in_range=(0, 2 ** bits - 1), out_range=(z_stack.min(), z_stack.max())
            )
            for x in z_stack
        ]
    )
    return skimage.img_as_uint(exposure.rescale_intensity(out))


def read_bit_img(img_file, bits=12):
    """Read an image and return as a 16-bit image."""
    img = exposure.rescale_intensity(
        io.imread(img_file), in_range=(0, 2 ** bits - 1)
    )
    return skimage.img_as_uint(img)


def select_signal(image, p_in_focus=0.75, margin_width=10):
    """
    Generate a bounding box for a FISH image to select only areas where signal is present.

    Parameters
    ----------
    image : np.ndarray
        3D FISH image
    p_in_focus : float, optional
        Percent of in-focus slices to retain for 2D projection, by default 0.75.
    margin_width : int, optional
        Number of pixels to pad selection by. Default is 10.

    Returns
    -------
    namedtuple
        minimum and maximum coordinate values of the bounding box in the xy plane
    """
    image = image.astype(np.uint16)
    focus = bf_stack.compute_focus(image)
    selected = bf_stack.in_focus_selection(image, focus, p_in_focus)
    projected_2d = bf_stack.maximum_projection(selected)
    foreground = np.where(projected_2d > threshold_otsu(projected_2d))
    limits = BoundingBox(
        ymin=max(foreground[0].min() - margin_width, 0),
        ymax=min(foreground[0].max() + margin_width, image.shape[1]),
        xmin=max(foreground[1].min() - margin_width, 0),
        xmax=min(foreground[1].max() + margin_width, image.shape[2]),
    )
    return limits


def crop_to_selection(img, bbox):
    """
    Crop image to selection defined by bounding box.

    Crops a 3D image to specified x and y coordinates.

    Parameters
    ----------
    img : np.ndarray
        3-dimensional image to crop.
    bbox : namedtuple
        Tuple defining minimum and maximum coordinates for x and y planes.

    Returns
    -------
    np.ndarray
        3D image cropped to the specified selection.
    """
    return img[:, bbox.ymin : bbox.ymax, bbox.xmin : bbox.xmax]
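
# Usage sketch for select_signal and crop_to_selection (comments only;
# `img` is a hypothetical (Z, Y, X) FISH stack):
#
#     bbox = select_signal(img, p_in_focus=0.75, margin_width=10)
#     cropped = crop_to_selection(img, bbox)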


def count_spots_in_labels(spots, labels):
    """
    Count the number of RNA molecules in specified labels.

    Parameters
    ----------
    spots : np.ndarray
        Coordinates in original image where RNA molecules were detected.
    labels : np.ndarray
        Integer array denoting regions of interest to quantify.
        Each separate region should be uniquely labeled.

    Returns
    -------
    dict
        dictionary containing the number of molecules contained in each labeled region.
    """
    assert spots.shape[1] == len(labels.shape)
    n_labels = len(np.unique(labels)) - 1  # subtract one for background
    counts = {i: 0 for i in range(1, n_labels + 1)}
    for each in spots:
        if len(each) == 3:
            cell_label = labels[each[0], each[1], each[2]]
        else:
            cell_label = labels[each[0], each[1]]
        if cell_label != 0:
            counts[cell_label] += 1
    return counts


def preprocess_image(
    img, smooth_method="gaussian", sigma=7, whitehat=True, selem=None, stretch=99.99
):
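    """Rescale and denoise an image prior to quantification.

    Summary (inferred from the code below): contrast-stretch intensities to
    the `stretch` percentile, then smooth each z-slice with a Laplacian of
    Gaussian ("log") or Gaussian background subtraction ("gaussian") filter;
    if `whitehat`, each smoothed slice is additionally white-tophat filtered
    with structuring element `selem`.
    """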
    scaled = exposure.rescale_intensity(
        img, in_range=tuple(np.percentile(img, [0, stretch]))
    )
    if smooth_method == "log":
        smooth_func = bf_stack.log_filter
        to_smooth = bf_stack.cast_img_float64(scaled)
    elif smooth_method == "gaussian":
        smooth_func = bf_stack.remove_background_gaussian
        to_smooth = bf_stack.cast_img_uint16(scaled)
    else:
        raise ValueError(f"Unsupported background filter: {smooth_method}")
    if whitehat:
        f = lambda x, s: morphology.white_tophat(smooth_func(x, s), selem)
    else:
        f = lambda x, s: smooth_func(x, s)
    smoothed = np.stack([f(img_slice, sigma) for img_slice in to_smooth])
    return bf_stack.cast_img_float64(np.stack(smoothed))


def count_spots(
    smoothed_signal,
    cell_labels,
    voxel_size_nm,
    dot_radius_nm,
    smooth_method="gaussian",
    decompose_alpha=0.5,
    decompose_beta=1,
    decompose_gamma=5,
    verbose=False,
):
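    """Detect spots with big-fish and tally them per labelled cell.

    Summary (inferred from the code below): detect spots in `smoothed_signal`,
    attempt dense-region decomposition, then count the spots falling inside
    each non-zero label of `cell_labels`. Returns a dict of per-label counts
    and a 3D image where each labelled region carries its count.
    """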

    if verbose:
        spot_radius_px = bf_detection.get_object_radius_pixel(
            voxel_size_nm=voxel_size_nm, object_radius_nm=dot_radius_nm, ndim=3
        )
        logging.info("spot radius (z axis): %0.3f pixels", spot_radius_px[0])
        logging.info("spot radius (yx plan): %0.3f pixels", spot_radius_px[-1])
    spots, threshold = bf_detection.detect_spots(
        smoothed_signal,
        return_threshold=True,
        voxel_size=voxel_size_nm,
        spot_radius=dot_radius_nm,
    )
    if verbose:
        logging.info("%d spots detected...", spots.shape[0])
        logging.info("plotting threshold optimization for spot detection...")
        bf_plot.plot_elbow(
            smoothed_signal,
            voxel_size=voxel_size_nm,
            spot_radius=dot_radius_nm,
        )
    decompose_cast = {
        "gaussian": bf_stack.cast_img_uint16,
        "log": bf_stack.cast_img_float64,
    }
    try:
        (
            spots_post_decomposition,
            dense_regions,
            reference_spot,
        ) = bf_detection.decompose_dense(
            decompose_cast[smooth_method](smoothed_signal),
            spots,
            voxel_size=voxel_size_nm,
            spot_radius=dot_radius_nm,
            alpha=decompose_alpha,  # alpha impacts the number of spots per candidate region
            beta=decompose_beta,  # beta impacts the number of candidate regions to decompose
            gamma=decompose_gamma,  # gamma impacts the filtering step used to denoise the image
        )
        logging.info(
            "detected spots before decomposition: %d\n"
            "detected spots after decomposition: %d",
            spots.shape[0],
            spots_post_decomposition.shape[0],
        )
        if verbose:
            print(
                f"detected spots before decomposition: {spots.shape[0]}\n"
                f"detected spots after decomposition: {spots_post_decomposition.shape[0]}\n"
                f"shape of reference spot for decomposition: {reference_spot.shape}"
            )
            bf_plot.plot_reference_spot(reference_spot, rescale=True)
    except RuntimeError:
        logging.warning("decomposition failed, using originally identified spots")
        spots_post_decomposition = spots
    n_labels = len(np.unique(cell_labels)) - 1
    counts = {i: 0 for i in range(1, n_labels + 1)}
    expression_3d = np.zeros_like(smoothed_signal)
    # get slices to account for cropping
    for each in spots_post_decomposition:
        spot_coord = tuple(each)
        cell_label = cell_labels[spot_coord]
        if cell_label != 0:
            counts[cell_label] += 1
    for region in measure.regionprops(cell_labels):
        expression_3d[region.slice][region.image] = counts[region.label]
    return counts, expression_3d


def average_intensity(smoothed_signal, cell_labels):
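    """Compute the mean z-normalized intensity of each labelled cell.

    Summary (inferred from the code below): z-score `smoothed_signal`, then
    record the mean intensity of each region. Returns a dict of per-label
    means and a 3D image where each labelled region carries its mean.
    """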
    n_labels = len(np.unique(cell_labels)) - 1
    intensities = {i: 0 for i in range(1, n_labels + 1)}
    z_normed_smooth = (smoothed_signal - smoothed_signal.mean()) / smoothed_signal.std()
    expression_3d = np.zeros_like(z_normed_smooth)
    for region in measure.regionprops(cell_labels, z_normed_smooth):
        intensities[region.label] = region.mean_intensity
        expression_3d[region.slice][region.image] = region.mean_intensity
    return intensities, expression_3d


def quantify_expression(
    fish_img,
    cell_labels,
    measures=["spots", "intensity"],
    voxel_size_nm=None,
    dot_radius_nm=None,
    whitehat=True,
    whitehat_selem=None,
    smooth_method="gaussian",
    smooth_sigma=1,
    decompose_alpha=0.5,
    decompose_beta=1,
    decompose_gamma=5,
    bits=12,
    crop_image=True,
    verbose=False,
):
    """
    Count the number of molecules in an smFISH image

    Parameters
    ----------
    fish_img : np.ndarray
        Image in which to perform molecule counting
    cell_labels : np.ndarray
        Integer array of same shape as `fish_img` denoting regions of interest to quantify.
        Each separate region should be uniquely labeled.
    measures : list
        Measures to use to quantify expression. Possible values are "spots",
        "intensity", and ["spots", "intensity"]. Default is to measure both
    voxel_size_nm : tuple(float, int), None
        Physical dimensions of each voxel in ZYX order. Required if running spot
        counting.
    dot_radius_nm : tuple(float, int), None
        Physical size of expected dots. Required if running spot
        counting.
    whitehat : bool, optional
        Whether to perform white tophat filtering prior to image de-noising, by default True
    whitehat_selem : [int, np.ndarray], optional
        Structuring element to use for white tophat filtering.
    smooth_method : str, optional
        Method to use for image de-noising. Possible values are "log" and "gaussian" for
        Laplacian of Gaussians and Gaussian background subtraction, respectively. By default "gaussian".
    smooth_sigma : [int, np.ndarray], optional
        Sigma value to use for smoothing function, by default 1
    decompose_alpha : float, optional
        Intensity percentile used to compute the reference spot, between 0 and 1.
        By default 0.5. For more information, see:
        https://big-fish.readthedocs.io/en/stable/detection/dense.html
    decompose_beta : int, optional
        Multiplicative factor for the intensity threshold of a dense region,
        by default 1. For more information, see:
        https://big-fish.readthedocs.io/en/stable/detection/dense.html
    decompose_gamma : int, optional
        Multiplicative factor use to compute a gaussian scale, by default 5.
        For more information, see:
        https://big-fish.readthedocs.io/en/stable/detection/dense.html
    bits : int, optional
        Bit depth of original image. Used for scaling the image while
        maintaining the original dynamic range. By default 12.
    crop_image : bool, optional
        Whether to crop signal. Default is True.
    verbose : bool, optional
        Whether to verbosely print results and progress.

    Returns
    -------
    (dict, np.ndarray)
        dict: mapping from measure name to per-label quantification values.
        np.ndarray: summary image with one leading axis per measure, where
        each labelled region carries its measured value.
    """
    if (voxel_size_nm is None or dot_radius_nm is None) and "spots" in measures:
        raise ValueError(
            "Require `voxel_size_nm` and `dot_radius_nm` when performing spot counting."
        )
    if crop_image:
        if verbose:
            logging.info("Cropping image to signal")
        limits = select_signal(fish_img)
        if verbose:
            logging.info(
                "Cropped image to %d x %d",
                limits.ymax - limits.ymin,
                limits.xmax - limits.xmin,
            )
    else:
        # create a BoundingBox that selects the whole image
        limits = BoundingBox(
            ymin=0, ymax=fish_img.shape[1], xmin=0, xmax=fish_img.shape[2]
        )
    if verbose:
        logging.info("Preprocessing image.")
    cropped_img = skimage.img_as_float64(
        exposure.rescale_intensity(
            crop_to_selection(fish_img, limits),
            in_range=(0, 2 ** bits - 1),
            out_range=(0, 1),
        )
    )
    smoothed = preprocess_image(
        cropped_img, smooth_method, smooth_sigma, whitehat, whitehat_selem, 99.99
    )

    cropped_labels = crop_to_selection(cell_labels, limits)
    quant = dict()
    if "spots" in measures:
        counts, counts_3d = count_spots(
            smoothed,
            cropped_labels,
            voxel_size_nm=voxel_size_nm,
            dot_radius_nm=dot_radius_nm,
            smooth_method=smooth_method,
            decompose_alpha=decompose_alpha,
            decompose_beta=decompose_beta,
            decompose_gamma=decompose_gamma,
            verbose=verbose,
        )
        quant["spots"] = counts
        if crop_image:  # match original shape if cropped
            counts_3d = utils.pad_to_shape(counts_3d, fish_img.shape)
    if "intensity" in measures:
        intensities, intense_3d = average_intensity(smoothed, cropped_labels)
        quant["intensity"] = intensities
        if crop_image:  # match original shape if cropped
            intense_3d = utils.pad_to_shape(intense_3d, fish_img.shape)
    if len(measures) == 2:
        expression_3d = np.stack([counts_3d, intense_3d])
    elif "spots" in measures:
        # keep a leading measure axis so downstream stacking stays consistent
        expression_3d = counts_3d[np.newaxis, ...]
    else:
        expression_3d = intense_3d[np.newaxis, ...]
    return quant, expression_3d


def get_quant_measure(method):
    """Get method of quantification"""
    if method == "spots":
        return ["spots"]
    elif method == "intensity":
        return ["intensity"]
    elif method == "both":
        return ["spots", "intensity"]
    else:
        raise ValueError(f"Unrecognized method {method}.")


if __name__ == "__main__":
    import h5py
    import pandas as pd

    try:
        snakemake
    except NameError:
        snakemake = None
    if snakemake is not None:
        logging.basicConfig(filename=snakemake.log[0], level=logging.INFO)
        raw_img, dimensions = utils.read_image_file(
            snakemake.input["image"], as_nm=True
        )
        if dimensions is None and snakemake.input["image"].endswith(".h5"):
            with open(snakemake.params["dimensions"], "r") as handle:
                data = json.load(handle)
                dimensions = np.array([data[c] for c in "zyx"]) * (10 ** 3)
        labels = np.array(h5py.File(snakemake.input["labels"], "r")["image"])
        logging.info("%d labels detected.", len(np.unique(labels) - 1))
        start = int(snakemake.params["z_start"])
        stop = int(snakemake.params["z_end"])
        if stop < 0:
            stop = raw_img.shape[1]
        gene_params = snakemake.params["gene_params"]

        def has_probe_info(name, gene_params):
            if name not in gene_params.keys():
                logging.warning(
                    "No entry for %s found in gene parameters. Not quantifying signal",
                    name,
                )
                return False
            return True

        channels = {
            x: i
            for i, x in enumerate(snakemake.params["channels"].split(";"))
            if has_probe_info(x, gene_params)
        }
        if len(channels) == 0:
            raise ValueError(
                f"No quantification parameters provided for channels: {snakemake.params['channels'].replace(';', ', ')}"
            )
        genes = list(channels.keys())
        fish_exprs = {}
        summarized_images = [None] * len(channels)
        embryo = snakemake.wildcards["embryo"]
        measures = get_quant_measure(snakemake.params["quant_method"])
        for i, gene in enumerate(genes):
            logging.info(f"Quantifying {gene} signal...")
            fish_data = raw_img[channels[gene], start:stop, :, :]
            quant, image = quantify_expression(
                fish_data,
                labels,
                measures=measures,
                voxel_size_nm=dimensions.tolist(),
                dot_radius_nm=gene_params[gene]["radius"],
                whitehat=True,
                smooth_method="gaussian",
                smooth_sigma=7,
                verbose=True,
                bits=12,
                crop_image=snakemake.params["crop_image"],
            )
            for each in measures:
                fish_exprs[f"{gene}_{each}"] = quant[each]
            summarized_images[i] = image
        # write summarized expression images to netcdf using Xarray to keep
        # track of dims
        out_image = np.array(summarized_images)
        xr.DataArray(
            data=out_image,
            coords={"gene": genes, "measure": measures},
            dims=["gene", "measure", "Z", "Y", "X"],
        ).to_netcdf(snakemake.output["image"])
        exprs_df = pd.DataFrame.from_dict(fish_exprs)
        exprs_df.index.name = "label"
        physical_properties = (
            pd.DataFrame(
                measure.regionprops_table(
                    labels,
                    properties=[
                        "label",
                        "centroid",
                        "area",
                        "equivalent_diameter",
                    ],
                )
            )
            .rename(
                columns={
                    "centroid-0": "Z",
                    "centroid-1": "Y",
                    "centroid-2": "X",
                    "equivalent_diameter": "diameter",
                }
            )
            .set_index("label")
        )
        out = exprs_df.join(physical_properties)
        out["embryo"] = embryo
        out[
            [
                "embryo",
                "area",
                "diameter",
                "Z",
                "Y",
                "X",
            ]
            + [
                f"{gene}_{measure}"
                for (gene, measure) in itertools.product(genes, measures)
            ]
        ].to_csv(snakemake.output["csv"])
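For a quick sanity check of the per-label counting logic, a hedged, self-contained sketch on synthetic data (importing the snippet as `count_spots` assumes it is saved as scripts/count_spots.py, per the Snakefile):

import numpy as np

from count_spots import count_spots_in_labels

# two labelled cells in a (Z, Y, X) volume
labels = np.zeros((10, 64, 64), dtype=int)
labels[2:8, 10:30, 10:30] = 1
labels[2:8, 40:60, 40:60] = 2

# three synthetic spots (z, y, x): two fall in cell 1, one in cell 2
spots = np.array([[4, 15, 15], [5, 20, 20], [4, 45, 45]])
print(count_spots_in_labels(spots, labels))  # {1: 2, 2: 1}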
import logging

import h5py
import numpy as np
from scipy import signal, spatial
from scipy import ndimage as ndi
from scipy.spatial.qhull import QhullError
from skimage import filters, measure, morphology, segmentation


def smooth(x, window_len=11, window="hanning"):
    """
    smooth the data using a window with requested size.

    This method is based on the convolution of a scaled window with the signal.
    The signal is prepared by introducing reflected copies of the signal
    (with the window size) in both ends so that transient parts are minimized
    in the beginning and end part of the output signal.

    input:
        x: the input signal
        window_len: the dimension of the smoothing window; should be an odd integer
        window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
            flat window will produce a moving average smoothing.

    output:
        the smoothed signal

    example:

    t = np.linspace(-2, 2, 50)
    x = np.sin(t) + np.random.randn(len(t)) * 0.1
    y = smooth(x)

    see also:

    numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve
    scipy.signal.lfilter

    Source
    ------
    https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html
    """

    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window_len < 3:
        return x

    if window not in ["flat", "hanning", "hamming", "bartlett", "blackman"]:
        raise ValueError(
            "Window must be one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'"
        )

    s = np.r_[x[window_len - 1 : 0 : -1], x, x[-2 : -window_len - 1 : -1]]
    if window == "flat":  # moving average
        w = np.ones(window_len, "d")
    else:
        w = getattr(np, window)(window_len)

    y = np.convolve(w / w.sum(), s, mode="valid")
    return y[(window_len // 2) : -(window_len // 2 - 1)]


def assign_to_label(src, region, slc, new_label):
    """Assign image region to label

    Parameters
    ----------
    src : numpy.ndarray
        Label-containing image.
    region : skimage.measure.RegionProperties
        Region to label
    slc : slice
        Slice defining which part of the region to label. If None, selects
        entire region
    new_label : int
        New label to set
    """
    if slc is None:
        slc = tuple([slice(None)] * len(region.image.shape))
    src[region.slice][slc][region.image[slc]] = new_label


def get_z_regions(region):
    """Break 3D region into separate 2D regions.

    Parameters
    ----------
    region : skimage.measure.RegionProperties
        3D region to break down.

    Returns
    -------
    list[skimage.measure.RegionProperties]
        Separate 2D regions comprising `region`.
    """
    return [measure.regionprops(x.astype(int))[0] for x in region.image]


def split_labels_by_area(labels, region):
    """Split a label based on local minimas of measured areas.

    Splits a long label along the z-axis at points of locally minimal area.
    Necessary to separate stacked PMCs.

    Parameters
    ----------
    labels : np.ndarray
        3D image containing labelled regions
    region : skimage.measure.RegionProperties
        Long region to split

    Notes
    -----
    Modifies `labels` in place; new labels are appended after the current
    maximum label.
    """
    z_regions = get_z_regions(region)
    areas = np.array([x.area for x in z_regions])
    splits = signal.argrelextrema(smooth(areas, 2, "hamming"), np.less)[0]
    n_labels = labels.max()
    for i, split in enumerate(splits):
        new_label = n_labels + 1
        if split != splits[-1]:
            z_slice = slice(split, splits[i + 1])
        else:
            z_slice = slice(split, None)
        assign_to_label(labels, region, (z_slice, slice(None), slice(None)), new_label)
        n_labels = new_label


def filter_small_ends(labels, region, min_pixels=5):
    """Remove small z-ends of labelled regions

    Parameters
    ----------
    labels : np.ndarray
        3D image containing labelled regions.
    region : skimage.measure.RegionProperties
        Region to filter.
    min_pixels : int, optional
        Minimum number of pixels for a label in each z-slice, by default 5.
    """
    z_regions = get_z_regions(region)
    i = 0
    # forward remove
    while i < len(z_regions) and z_regions[i].area <= min_pixels:
        assign_to_label(labels, region, (i, slice(None), slice(None)), 0)
        i += 1
    # backward remove
    i = len(z_regions) - 1
    while i >= 0 and z_regions[i].area <= min_pixels:
        assign_to_label(labels, region, (i, slice(None), slice(None)), 0)
        i -= 1


def backpropogate_split_labels(z_stack, labels):
    """Propogates split labels in a Z stack to lower slices.

    Parameters
    ----------
    z_stack : numpy.ndarray
        3D image stack from z=0 ... z=Z
    labels : numpy.ndarray
        Labels for current z-slice, should be z=Z slice.

    Returns
    -------
    numpy.ndarray
        Newly labelled `z_stack`
    """
    new_stack = np.zeros_like(z_stack, dtype=int)
    new_stack[-1, :, :] = labels
    for i in range(1, z_stack.shape[0])[::-1]:
        new_stack[i - 1, :, :] = segmentation.watershed(
            np.zeros_like(z_stack[i, :, :]),
            markers=new_stack[i, :, :],
            mask=z_stack[i - 1, :, :],
        )
    return new_stack




def split_z_disconeccted_labels(region, n_labels):
    """Split labels along the z-plane.

    Splits label along the z-plane if labels in individual z-slices have
    multiple connected components.

    Parameters
    ----------
    region : skimage.measure.RegionProperty
        Region to check and possibly split.
    n_labels : int
        Total number of current labels.

    Returns
    -------
    tuple : (numpy.ndarray, int)
        numpy.ndarray : Previous label split along z-axis if the label should be
            split. Otherwise, an array of zeros.
        int : Total number of current labels
    """
    new_labs = np.zeros_like(region.image, dtype=int)
    expand_labels = False
    n_new_labels = 0
    n_split = 0
    for i, z_slice in enumerate(region.image):
        segmented = measure.label(z_slice)
        new_labs[i, :, :] = segmented
        z_regions = measure.regionprops(new_labs[i, :, :])
        if len(z_regions) > 1 and not expand_labels:
            n_split += 1
            centroids = [r.centroid for r in z_regions]
            if n_split > 1 or max(spatial.distance.pdist(centroids)) > 5:
                expand_labels = True
                n_new_labels = len(np.unique(segmented)) - 1  # -1 for zeros
                if i != 0:
                    new_labs[: (i + 1), :, :] = backpropogate_split_labels(
                        new_labs[: (i + 1), :, :], new_labs[i, :, :]
                    )

        elif expand_labels and i > 0:
            # fill in z labels with z-1 labels
            z_filled = segmentation.watershed(
                np.zeros_like(new_labs[i, :, :]),
                markers=new_labs[i - 1, :, :],
                mask=new_labs[i, :, :],
            )
            # check if all regions in z labels are accounted for, otherwise
            # create new label
            for r in z_regions:
                if np.isin(0, z_filled[r.slice][r.image]):
                    n_new_labels += 1
                    z_filled[r.slice][r.image] = n_new_labels
            new_labs[i, :, :] = z_filled
    # re-assign label values to avoid conflict with previous labels + maintain
    # consistency
    for r in measure.regionprops(new_labs):
        if r.label == 1:
            assign_to_label(new_labs, r, None, region.label)
        else:
            assign_to_label(new_labs, r, None, n_labels + 1)

            n_labels += 1
    return new_labs, n_labels


def filter_by_area_length_ratio(labels, region, min_ratio=15):
    """Filter cylindrical-like labels from a 3D label image.

    Filters thin, long regions by comparing the (Total Area) / (Z length) ratio.

    Parameters
    ----------
    labels : numpy.ndarray
        Original 3D label image containing labels to filter.
    region : skimage.RegionProperty
        Region to possibly filter
    min_ratio : int, optional
        Minimum area / z-length ratio required to keep a label, by default 15;
        any label with a smaller ratio will be removed.
    """
    if region.area / region.image.shape[0] < min_ratio:
        logging.info(f"Filtering region {region.label} for being too cylindrical")
        assign_to_label(labels, region, None, 0)


def renumber_labels(labels):
    """Renumber labels to that N labels will be labelled from 1 ... N"""
    for i, region in enumerate(measure.regionprops(labels)):
        assign_to_label(labels, region, None, i + 1)


def fill_labels(labels, region):
    """Fill labels using binary closing.

    Parameters
    ----------
    labels : np.ndarray
        Array of object labels
    region : RegionProperties
        Region in image containing label of interest
    """
    filled = morphology.binary_closing(region.image)
    labels[region.slice][filled] = region.label


def check_and_split_separated_labels(labels):
    """Checks for and splits labels containing empty z-slices"""
    for region in measure.regionprops(labels):
        pixels_per_slice = region.image.sum(axis=1).sum(axis=1)
        if 0 in pixels_per_slice:
            logging.info(
                "Z split label %s. Separating into distinct labels.", region.label
            )
            n_slices = len(pixels_per_slice)
            n_labels = labels.max()
            zeros = list(np.where(pixels_per_slice == 0)[0])
            non_zeros = np.where(pixels_per_slice != 0)[0]
            start_stops = []
            start = non_zeros[0]
            stop = n_slices
            while start < n_slices:
                if len(zeros) > 0:
                    current_zero = zeros[0]
                    stop = (
                        non_zeros[non_zeros < current_zero][-1] + 1
                    )  # add 1 bc of non-inclusive indexing
                    start_stops.append((start, stop))
                    non_zeros = np.array([x for x in non_zeros if x > stop])
                    start = non_zeros[0]
                    zeros.pop(0)
                else:
                    start_stops.append((start, n_slices))
                    start = n_slices

            for i, (start, stop) in enumerate(start_stops):
                if i == 0:
                    to_assign = region.label
                else:
                    to_assign = n_labels + 1
                    n_labels += 1
                slc = (slice(start, stop), slice(None), slice(None))
                logging.info("Assigning %s between %d and %d", to_assign, start, stop)
                assign_to_label(labels, region, slc, to_assign)
                # labels[region.slice][start:stop, :, :][
                #     region.image[start:stop, :, :]
                # ] = to_assign


def generate_labels(
    stain,
    pmc_probs,
    p_low=0.5,
    p_high=0.8,
    selem=None,
    max_stacks=7,
    min_stacks=3,
    max_area=600,
):
    """Generate PMC labels within a confocal image.

    Parameters
    ----------
    stain : numpy.ndarray
        3D image containing PMC stain.
    pmc_probs : numpy.ndarray
        3D image containing probabilities of each pixel in `stain` containing a
        PMC.
    p_low : float, optional
        Lower probability bound for considering a pixel a PMC, by default 0.5.
        Used in `filters.apply_hysteresis_threshold()`.
    p_high : float, optional
        Higher probability bound for considering a pixel a PMC, by default 0.8.
        Used in `filters.apply_hysteresis_threshold()`.
    selem : np.ndarray, optional
        Structuring element used for morphological opening / closing. If None,
        uses `skimage` defaults.
    max_stacks : int, optional
        The maximum number of z-slices a label can occupy before assuming stacked
        PMCs, by default 7.
    min_stacks : int, optional
        The minimum number of slices a label should occupy before removal,
        by default 3.
    max_area : int, optional
        Maximum area a label may occupy before stricter re-segmentation is
        attempted, by default 600.

    Returns
    -------
    np.ndarray
        3D image containing PMC segmentations.
    """
    if selem is None:
        selem = morphology.disk(2)
    pmc_seg = filters.apply_hysteresis_threshold(pmc_probs, p_low, p_high)
    seeds = measure.label(
        np.array([morphology.binary_opening(x, footprint=selem) for x in pmc_seg])
    )
    try:
        gradients = filters.sobel(stain, axis=0)
    except TypeError:
        gradients = np.array([filters.sobel(x) for x in stain])
    labels = segmentation.watershed(
        np.abs(gradients),
        seeds,
        # np.stack() requires a sequence, not a generator
        mask=np.stack(
            [
                morphology.closing(ndi.binary_fill_holes(x, selem), None)
                for x in pmc_probs > 0.5
            ]
        ),
    )
    # close up any holes in labels
    for region in measure.regionprops(labels):
        fill_labels(labels, region)

    # check for z-separated labels created by binary filling
    check_and_split_separated_labels(labels)
    # further segment large regions of PMCs using stricter criteria
    for region in measure.regionprops(labels):
        if region.area > max_area:
            strict_pmc_prediction(region, pmc_probs, labels, 0.8)

    # find abnormally long tracks, check for local minima in area size that
    # would indicate stacked pmcs. Additionally clean up small tails in labels.
    for region in measure.regionprops(labels):
        n_stacks = np.unique(region.coords[:, 0]).size
        if n_stacks > max_stacks:
            logging.info(
                "Region %s exceeds maximum z-length, splitting by area.",
                region.label,
            )
            split_labels_by_area(labels, region)
        elif n_stacks < min_stacks:
            logging.info(
                "Region %s only spans %d z-slices: removing.",
                region.label,
                n_stacks,
            )
            assign_to_label(labels, region, None, 0)
        filter_by_area_length_ratio(labels, region, min_ratio=15)
        filter_small_ends(labels, region, min_pixels=5)

    # renumber labels from 1...N
    renumber_labels(labels)

    return labels


def strict_pmc_prediction(region, pmc_probs, labels, threshold):
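    """Attempt to split an oversized label using progressively stricter thresholds.

    Summary (inferred from the code below): raise the PMC-probability threshold
    in steps of 0.025 until the region separates into multiple connected
    components, then watershed the split labels back over the original region,
    modifying `labels` in place.
    """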
    logging.info(
        f"Label {region.label} exceeds area with A={region.area}, attempting to split."
    )
    # smooth probabilities due to strict thresholding
    smoothed = np.array([filters.gaussian(x) for x in pmc_probs[region.slice]])
    split = False
    t = threshold
    n_labels = labels.max()
    while t < 1 and not split:
        split_labels = measure.label(
            np.array([ndi.binary_opening(x) for x in smoothed > t])
        )

        if split_labels.max() > 1:
            split = True
            logging.info(
                f"Label {region.label} split into {split_labels.max()} regions with t={t}."
            )
            # bounding box may include non-label object, remove from new labels
            split_labels[~region.image] = 0

            # fill in new labels to original area
            flooded = segmentation.watershed(
                np.zeros_like(region.image), split_labels, mask=region.image
            )
            # re-assign label values to avoid conflict with previous labels + maintain
            # consistency
            for r in measure.regionprops(flooded):
                if r.label == 1:
                    assign_to_label(flooded, r, None, region.label)
                else:
                    assign_to_label(flooded, r, None, n_labels + 1)
                    n_labels += 1
                if r.area > 600:
                    logging.debug("area threshold exceeded in split label")
            # assign newly split regions back to original labels
            labels[region.slice][region.image] = flooded[region.image]
        t += 0.025


def find_pmcs(
    stain,
    pmc_probs,
    max_stacks=7,
    min_stacks=3,
    p_low=0.45,
    p_high=0.5,
    selem=None,
    min_area=55,
    max_area=600,
    area_w_diameter=500,
    d_threshold=15,
    strict_threshold=0.8,
):
    """
    Segment each PMC in a 3D confocal image.

    Segments each PMC by combining "loose" and "strict" PMC segmentations.

    Parameters
    ----------
    stain : numpy.ndarray
        3D image containing PMC stain.
    pmc_probs : numpy.ndarray
        3D image containing probabilities of each pixel in `stain` containing a
        PMC.
    max_stacks : int, optional
        The maximum number of z-slices a label can occupy before assuming stacked
        PMCs, by default 7.
    min_stacks : int, optional
        The minimum number of slices a label should occupy before removing,
        by default 3
    p_low : float, optional
        Lower probability bound for considering a pixel a PMC during loose
        segmentation. Used in `filters.apply_hysteresis_threshold()`,
        by default 0.45.
    p_high : float, optional
        Higher probability bound for considering a pixel a PMC during loose
        segmentation. Used in `filters.apply_hysteresis_threshold()`,
        by default 0.5.
    selem : np.ndarray, optional
        Structuring element for morphological operations. Default is None, and
        skimage/scipy.ndimage defaults will be used.
    min_area : float, optional
        Minimum area for a single label. Any label with an area below the
        threshold will be dropped. This occurs *prior* to any strict
        thresholding of large labels. The default is 55.
    max_area : float, optional
        Maximum area for a single label. Any label exceeding the threshold will
        be attempted to be split into separate labels using stricter thresholding
        and segmentation. Default is 600.
    area_w_diameter : float, optional
        Maximum area of label allowed when the diameter also exceed a specified
        value. Default is 500. Labels that exceed both criteria will be further
        segmented.
    d_threshold : float, optional
        Maximum diameter allowed for larger labels. Default is 15, and labels
        that exceed both `d_threshold` and `area_w_diameter` criteria will be
        further segmented.
    strict_threshold : float, optional
        Higher probability bound for considering a pixel a PMC during stricter
        segmentation. Used if the area of a region exceeds `max_area`, by
        default 0.8.

    Returns
    -------
    np.ndarray
        Integer array with the same size of `stain` and `pmc_probs` where each
        numbered region represents a unique PMC.
    """

    labels = generate_labels(
        stain,
        pmc_probs,
        p_low=p_low,
        p_high=p_high,
        selem=selem,
        max_stacks=max_stacks,
        min_stacks=min_stacks,
    )
    # further segment large label blocks using strict segmentation values
    for region in measure.regionprops(labels):
        if region.area > max_area:
            strict_pmc_prediction(region, pmc_probs, labels, strict_threshold)
        elif region.area > area_w_diameter and region.feret_diameter_max > d_threshold:
            logging.info(
                "Label %s exceeds combined area and diameter threshold: A=%d, D=%0.2f",
                region.label,
                region.area,
                region.feret_diameter_max,
            )
            strict_pmc_prediction(region, pmc_probs, labels, strict_threshold)
    # remove labels deemed too small to be PMCs
    for region in measure.regionprops(labels):
        if region.area < min_area:
            logging.info(
                "Label %s does not meet area threshold. Removing.", region.label
            )
            assign_to_label(labels, region, slc=None, new_label=0)
    renumber_labels(labels)
    return labels


def labels_to_hdf5(image, filename):
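    """Write a label image to HDF5 under the dataset name "image"."""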
    with h5py.File(filename, "w") as f:
        f.create_dataset("image", image.shape, h5py.h5t.NATIVE_INT16, data=image)


if __name__ == "__main__":
    try:
        snakemake
    except NameError:
        snakemake = None
    if snakemake is not None:
        logging.basicConfig(filename=snakemake.log[0], level=logging.INFO)
        pmc_probs = np.array(h5py.File(snakemake.input["probs"], "r")["exported_data"])[
            :, :, :, 1
        ]
        pmc_stain = np.array(h5py.File(snakemake.input["stain"], "r")["image"])
        pmc_segmentation = find_pmcs(
            pmc_stain,
            pmc_probs,
            p_low=0.45,
            p_high=0.5,
            selem=None,
            max_stacks=7,
            min_stacks=2,
            max_area=600,
            area_w_diameter=500,
            d_threshold=15,
            strict_threshold=0.8,
        )
        labels_to_hdf5(pmc_segmentation, snakemake.output["labels"])
    else:
        import os
        import napari

        logging.getLogger().setLevel(logging.INFO)
        # wildcard naming: an .nd2 path such as
        # "MK886/MK_2_0/replicate3/18hpf_MK2uM_R3_emb9.nd2" maps to
        # "MK886_MK-2-0_replicate3_18hpf-MK2uM-R3-emb9"
        wc = "MK886_MK-1-5_groupB_18-MK-1-5-emb1002"
        pmc_probs = np.array(
            h5py.File(os.path.join("data", "pmc_probs", f"{wc}.h5"), "r")[
                "exported_data"
            ]
        )[:, :, :, 1]
        pmc_stain = np.array(
            h5py.File(os.path.join("data", "pmc_norm", f"{wc}.h5"))["image"]
        )
        pmc_segmentation = generate_labels(
            pmc_stain,
            pmc_probs,
            p_low=0.45,
            p_high=0.5,
            selem=None,
            max_stacks=7,
            min_stacks=2,
        )
        labels2 = find_pmcs(pmc_stain, pmc_probs, strict_threshold=0.8)
        with napari.gui_qt():
            viewer = napari.Viewer()
            viewer.add_image(pmc_stain, name="pmc", scale=[3.1, 1, 1])
            viewer.add_labels(pmc_segmentation, scale=[3.1, 1, 1], name="labels")
            viewer.add_labels(pmc_segmentation == 55, scale=[3.1, 1, 1], name="55")
            try:
                viewer.add_labels(labels2, scale=[3.1, 1, 1], name="labels-adjusted")
            except NameError:
                pass
            viewer.add_image(
                pmc_probs,
                scale=[3.1, 1, 1],
                colormap="red",
                blending="additive",
                name="probs",
            )
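After segmentation, the labels written by labels_to_hdf5 can be read back for inspection; a minimal sketch (the file path is hypothetical):

import h5py
import numpy as np

with h5py.File("labels/embryo1.h5", "r") as f:
    labels = np.array(f["image"])

n_pmcs = len(np.unique(labels)) - 1  # exclude background label 0
print(f"{n_pmcs} PMCs detected")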
import sys
import os

import h5py
import numpy as np
from skimage import exposure

try:
    from intensipy import Intensify
except ImportError:
    pass

sys.path.append(os.path.dirname(__file__))  # dirname (not basename) so the sibling utils module is importable
import utils


def get_channel_index(channels, channel):
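    """Return the index of `channel` in the semicolon-delimited `channels` string (case-insensitive)."""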
    channel_index = [
        i for i, x in enumerate(channels.split(";")) if x.lower() == channel.lower()
    ][0]
    return channel_index


def preprocess_slice(img, upper_percentile=99.99, new_min=0, new_max=1):
    """Preprocess a Z-slice by scaling and equalizing intensities.

    Parameters
    ----------
    img : np.ndarray
        Image slice to scale and equalize.
    upper_percentile : float, optional
        Upper bound to clip intensities for scaling, by default 99.99.
    new_min : int, optional
        New minimum intensity value, by default 0.
    new_max : int, optional
        New maximum intensity value, by default 1.

    Returns
    -------
    np.ndarray
        Scaled and equalized image slice.
    """
    lb, ub = np.percentile(img, (0, upper_percentile))
    out = exposure.equalize_adapthist(
        exposure.rescale_intensity(img, in_range=(lb, ub), out_range=(new_min, new_max))
    )
    return out


if __name__ == "__main__":
    try:
        snakemake
    except NameError:
        snakemake = None
    if snakemake is not None:
        img, __ = utils.read_image_file(snakemake.input["image"])
        channel = get_channel_index(
            snakemake.params["channels"], snakemake.params["channel_name"]
        )
        z_start = int(snakemake.params["z_start"])
        z_stop = int(snakemake.params["z_end"])
        pmc = img[channel, z_start:z_stop, :, :]
        if snakemake.params["intensipy"]:
            model = Intensify(xy_norm=False, dy=29, dx=29)
            pmc = model.normalize(pmc.astype(float))
        else:
            pmc = np.array(
                [
                    preprocess_slice(x, upper_percentile=100, new_min=0, new_max=1)
                    for x in pmc
                ]
            )
        utils.to_hdf5(pmc, snakemake.output["h5"])
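A hedged usage sketch of preprocess_slice on a synthetic 12-bit slice (importing the snippet as `normalize_pmc_stain` is an assumption; only numpy and scikit-image are needed):

import numpy as np

from normalize_pmc_stain import preprocess_slice

rng = np.random.default_rng(0)
z_slice = rng.integers(0, 2 ** 12, size=(256, 256)).astype(np.uint16)
normed = preprocess_slice(z_slice, upper_percentile=99.99)
assert 0.0 <= normed.min() and normed.max() <= 1.0  # output is scaled to [0, 1]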
script:
    "scripts/normalize_pmc_stain.py"
shell:
    "({params.ilastik_loc} --headless "
    "--project={input.model} "
    "--output_format=hdf5 "
    "--output_filename_format={output} "
    "{input.image}) 2> {log}"
script:
    "scripts/label_pmcs.py"
script:
    "scripts/count_spots.py"
From line 108 of master/Snakefile
script:
    "scripts/combine_counts.py"
From line 119 of master/Snakefile


Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/BradhamLab/pmc_detection
Name: pmc_detection
Version: v0.1.0
Copyright: Public Domain
License: BSD 3-Clause "New" or "Revised" License
