Neuroinformatics Workflow for Clinical Atlas Registration and Data Analysis


clinical-atlasreg

Inputs:

  • participants.tsv with target subject IDs

  • bids folder

    • other folder containing bids-like processed data

Singularity containers required:

  • khanlab/neuroglia-core:latest

Authors

  • Your name here @yourgithubid

Usage

If you use this workflow in a paper, don't forget to give credit to the authors by citing the URL of this (original) repository and, if available, its DOI (see above).

Step 1: Obtain a copy of this workflow

  1. Create a new GitHub repository using this workflow as a template.

  2. Clone the newly created repository to your local system, into the place where you want to perform the data analysis.

Step 2: Configure workflow

Configure the workflow according to your needs by editing the files in the config/ folder. Adjust config.yml to configure the workflow execution, and participants.tsv to specify your subjects.
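
For example, a minimal participants.tsv could look like the following (a single participant_id column, as in the BIDS convention; the subject IDs below are placeholders):

participant_id
sub-001
sub-002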

Step 3: Install Snakemake

Install Snakemake using conda:

conda create -c bioconda -c conda-forge -n snakemake snakemake

For installation details, see the instructions in the Snakemake documentation.

Step 4: Execute workflow

Activate the conda environment:

conda activate snakemake

Test your configuration by performing a dry-run via

snakemake --use-singularity -n

Execute the workflow locally via

snakemake --use-singularity --cores $N

using $N cores or run it in a cluster environment via

snakemake --use-singularity --cluster qsub --jobs 100

or

snakemake --use-singularity --drmaa --jobs 100
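
In practice, the cluster submission string usually needs resource arguments as well; for example (the qsub flags below are scheduler-specific and purely illustrative):

snakemake --use-singularity --cluster "qsub -l nodes=1:ppn={threads}" --jobs 100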

If you are using Compute Canada, you can use the cc-slurm profile, which submits jobs and takes care of requesting the correct resources per job (including GPUs). Once it is set up with cookiecutter, run:

snakemake --profile cc-slurm
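
If the profile has not been created yet, it is set up once with cookiecutter; assuming the profile template lives in the khanlab/cc-slurm repository, that looks like:

cookiecutter gh:khanlab/cc-slurm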

Or, with neuroglia-helpers, you can get an 8-core, 32 GB node and run locally on it. First, get a node (default: 8 cores, 32 GB, 3-hour limit):

regularInteractive 

Then, run:

snakemake --use-singularity --cores 8 --resources mem=32000

See the Snakemake documentation for further details.

Step 5: Investigate results

After successful execution, you can create a self-contained interactive HTML report with all results via:

snakemake --report report.html

This report can, e.g., be forwarded to your collaborators. An example (using some trivial test data) can be seen here.

Step 6: Commit changes

Whenever you change something, don't forget to commit the changes back to your GitHub copy of the repository:

git commit -a
git push

Step 7: Obtain updates from upstream

Whenever you want to synchronize your workflow copy with new developments from upstream, do the following.

  1. Register the upstream repository in your local copy once: git remote add -f upstream git@github.com:snakemake-workflows/{{cookiecutter.repo_name}}.git, or git remote add -f upstream https://github.com/snakemake-workflows/{{cookiecutter.repo_name}}.git if you have not set up SSH keys.

  2. Update the upstream version: git fetch upstream.

  3. Create a diff with the current version: git diff HEAD upstream/master workflow > upstream-changes.diff.

  4. Investigate the changes: vim upstream-changes.diff.

  5. Apply the modified diff via: git apply upstream-changes.diff.

  6. Carefully check whether you need to update the config files: git diff HEAD upstream/master config. If so, do it manually, and only where necessary, since you would otherwise likely overwrite your settings and sample sheets.

Step 8: Contribute back

In case you have also changed or added steps, please consider contributing them back to the original repository:

  1. Fork the original repo to a personal or lab account.

  2. Clone the fork to your local system, to a different place than where you ran your analysis.

  3. Copy the modified files from your analysis to the clone of your fork, e.g., cp -r workflow path/to/fork. Make sure not to accidentally copy config file contents or sample sheets. Instead, manually update the example config files if necessary.

  4. Commit and push your changes to your fork.

  5. Create a pull request against the original repository.

Testing

TODO: create some test datasets
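
Until test datasets exist, a dry run against your own configuration serves as a basic smoke test (this just reuses the command from Step 4):

snakemake --use-singularity -n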

Code Snippets

The following excerpts are taken from the workflow's rule files (rules/*.smk) and helper scripts (scripts/*.py).

script: '../scripts/working/elec_labels_coords.py'
script: '../scripts/label_electrodes_atlas.py'
run:
    # fcsv landmark files are comma-separated with a two-line header
    df = pd.read_table(input.fcsv, sep=',', header=2)
    coords = df[['x','y','z']].to_numpy()
    with open(output.txt, 'w') as fid:
        for i in range(len(coords)):
            # write each landmark as "x y z 1", coordinates rounded to 3 decimals
            fid.write(' '.join(str(v) for v in np.r_[np.round(coords[i,:], 3), int(1)]) + "\n")
shell:
    'c3d {input.ct} -scale 0 -landmarks-to-spheres {input.txt} 1 -o {output.mask}'
        shell:
            "export FASTSURFER_HOME={params.fastsurfer_run} && PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:4096 {params.fastsurfer_run}/run_fastsurfer.sh \
--t1 {input.t1} --sd {params.fastsurfer_out} --sid {params.subjid} --order {params.order} --py {params.py} --run_viewagg_on cpu --fsaparc --parallel --surfreg"
        shell:
            "export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:4096 && export FASTSURFER_HOME={params.fastsurfer_run} && {params.fastsurfer_run}/run_fastsurfer.sh \
--t1 {input.t1} --sd {output.fastsurfer_out} --sid {params.sid} --order {params.order} --py {params.py} --threads {params.threads} --batch {params.batch} --run_viewagg_on cpu --fsaparc --parallel --surfreg"
script: '../scripts/vis_electrodes_native.py'
shell:
    'export SINGULARITYENV_FS_LICENSE=$HOME/.freesurfer.txt && \
    singularity run --cleanenv \
    --bind {params.bids_dir}:/tmp/input \
    --bind {params.out_dir}:/tmp/output \
    --bind {params.license}:/tmp/{params.license_name} \
    {params.fmriprep_img} /tmp/input /tmp/output participant --skip_bids_validation \
    --participant_label {params.sub} --anat-only \
    --fs-license-file /tmp/{params.license_name} \
    --bids-filter-file {params.bids_filter}'
shell:
    "mkdir -p {params.in_dir} && mkdir -p {params.in_dir_sub} && mkdir -p {params.in_dir_sub}/anat && cp {input.in_t1w} {output.out_t1w}"
shell:
    "singularity run -e {params.hippunfold_container} {input.in_dir} {params.hippunfold_out} participant --force_output --participant_label {params.participant_label}\
    --modality {params.modality} --cores 4"
shell: "echo {input} &&cp {input} {output}"
shell: "echo {input} &&cp {input} {output}"
shell:
    'cp {input} {output}'
shell:
    'cp {input} {output}'
shell:
    'reg_aladin -flo {input.flo} -ref {input.ref} {params.dof} -interp 0 -res {output.warped_subj} -aff {output.xfm_ras} -speeeeed'
script: 
    '../scripts/convert_xfm_tfm.py'
shell:
    'cp {input} {output}'
shell: 'cp {input} {output}'
shell:
    'reg_aladin -flo {input.flo} -ref {input.ref} {params.dof} -interp 0 -res {output.warped_subj} -aff {output.xfm_ras} -speeeeed'
script: 
    '../scripts/convert_xfm_tfm.py'
shell: 'cp {input} {output}'
shell:
    'reg_aladin -flo {input.flo} -ref {input.ref} {params.dof} -res {output.warped_subj} -aff {output.xfm_ras} -speeeeed'
script: 
    '../scripts/convert_xfm_tfm.py'
shell:
    'reg_aladin -flo {input.flo} -ref {input.ref} -res {output.warped_subj} -aff {output.affine_xfm_ras} -speeeeed'
shell:
    'greedy -d 3 -threads 4 -a -ia-image-centers -m MI -i {input.ref} {input.flo} -o {output.affine_xfm_ras} -n {params.n_iterations_affine} && '
    'greedy -d 3 -threads 4 -m MI -i {input.ref} {input.flo} -it {output.affine_xfm_ras} -o {output.xfm_deform} -oinv {output.xfm_deform_inv} -n {params.n_iterations_deform} -s {params.grad_sigma} {params.warp_sigma} && '
    'greedy -d 3 -threads 4 -rf {input.ref} -rm {input.flo} {output.warped_subj_affine} -r {output.affine_xfm_ras} && '
    'greedy -d 3 -threads 4 -rf {input.ref} -rm {input.flo} {output.warped_subj_greedy} -r {output.xfm_deform} {output.affine_xfm_ras}'
script: 
    '../scripts/convert_xfm_tfm.py'
shell:
    'c3d_affine_tool {input.xfm}  -oitk {output}'
    shell: 'antsApplyTransforms -d 3 --interpolation NearestNeighbor -i {input.mask} -o {output.mask} -r {input.ref} '
            ' -t [{input.xfm},1] ' #use inverse xfm (going from template to subject)

rule warp_tissue_probseg_from_template_affine:
    input: 
        probseg = config['template_tissue_probseg'],
        ref = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='T1w.nii.gz'),
        xfm = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='xfm.txt',from_='subject',to='{template}',desc='{desc}',type_='itk'),
    output:
        probseg = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='probseg.nii.gz',label='{tissue}',from_='{template}',reg='{desc}'),
    #container: config['singularity']['neuroglia']
    group: 'preproc'
    shell: 
        'ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} '
        'antsApplyTransforms -d 3 --interpolation Linear -i {input.probseg} -o {output.probseg} -r {input.ref} '
            ' -t [{input.xfm},1]' #use inverse xfm (going from template to subject)

rule n4biasfield:
    input: 
        t1 = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='T1w.nii.gz'),
    output:
        t1 = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,desc='n4', suffix='T1w.nii.gz'),
    threads: 8
    #container: config['singularity']['neuroglia']
    group: 'preproc'
shell:
    'ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} '
    'N4BiasFieldCorrection -d 3 -i {input.t1} -o {output}'
shell:
    'fslmaths {input.t1} -mas {input.mask} {output}'
shell: 'fslmaths {input.t1} -mas {input.mask} {output.t1_masked} &&\
        reg_aladin -flo {params.template} -ref {input.t1} -res {output.template_to_t1} -aff {output.template_to_t1_matrix} -interp 0 -speeeeed &&\
        c3d_affine_tool {output.template_to_t1_matrix}  -oitk {output.template_to_t1_itk} &&\
        antsApplyTransforms -d 3 -i {params.facemask} -o {output.warped_mask} -r {input.t1} -t [{output.template_to_t1_itk},0] &&\
        fslmaths {input.t1} -mas {output.warped_mask} {output.defaced_t1} &&\
        fslmaths {input.ct} -mas {output.warped_mask} {output.defaced_ct}'
shell:
    'fslmaths {input.t1} -mas {input.mask} {output}'
shell:
    'fslmaths {input.ct} -mas {input.mask} {output.ct}'
shell: 
    'ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} '
    'antsRegistration {params.base_opts} {params.intensity_opts} '
    '{params.init_transform} ' #initial xfm  -- rely on this for affine
#    '-t Rigid[0.1] {params.linear_metric} {params.linear_multires} ' # rigid registration
#    '-t Affine[0.1] {params.linear_metric} {params.linear_multires} ' # affine registration
    '{params.deform_model} {params.deform_metric} {params.deform_multires} '  # deformable registration
    '-o [{params.out_prefix},{output.warped_flo}]'
    shell: 
        'ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} '
        'antsApplyTransforms -d 3 --interpolation NearestNeighbor -i {input.dseg} -o {output.dseg} -r {input.ref} '
            ' -t {input.inv_composite} ' #use inverse xfm (going from template to subject)

rule warp_tissue_probseg_from_template:
    input: 
        probseg = config['template_tissue_probseg'],
        ref = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='T1w.nii.gz'),
        inv_composite = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='InverseComposite.h5',from_='subject',to='{template}'),
    output:
        probseg = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='probseg.nii.gz',label='{tissue}',from_='{template}',reg='SyN'),
    #container: config['singularity']['neuroglia']
    group: 'preproc'
    shell: 
        'ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} '
        'antsApplyTransforms -d 3 --interpolation Linear -i {input.probseg} -o {output.probseg} -r {input.ref} '
            ' -t {input.inv_composite} ' #use inverse xfm (going from template to subject)

rule warp_brainmask_from_template:
    input: 
        mask = config['template_mask'],
        ref = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='T1w.nii.gz'),
        inv_composite = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='InverseComposite.h5',from_='subject',to='{template}'),
    output:
        mask = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='mask.nii.gz',from_='{template}',reg='SyN',desc='brain'),
    #container: config['singularity']['neuroglia']
    group: 'preproc'
    shell: 
        'ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} '
        'antsApplyTransforms -d 3 --interpolation NearestNeighbor -i {input.mask} -o {output.mask} -r {input.ref} '
            ' -t {input.inv_composite} ' #use inverse xfm (going from template to subject)

rule dilate_brainmask:
    input:
        mask = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='mask.nii.gz',from_='{template}',reg='{desc}',desc='brain'),
    output:
        mask = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='mask.nii.gz',from_='{template}',reg='{desc}',desc='braindilated'),
    #container: config['singularity']['neuroglia']
    group: 'preproc'
shell:
    'fslmaths {input} -dilD {output}'
shell:
    'fslmaths {input} {params.dil_opt} {output}'
    shell:
        'Atropos -d 3 -a {input.t1} -i KMeans[{params.k}] -m {params.m} -c {params.c} -x {input.mask} -o [{output.seg},{params.posterior_fmt}] && '
        'fslmerge -t {output.posteriors} {params.posterior_glob} ' #merge posteriors into a 4d file (intermediate files will be removed b/c shadow)

rule map_channels_to_tissue:
    input:
        tissue_priors = expand(bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='probseg.nii.gz',label='{tissue}',from_='{template}'.format(template=config['template']),reg='affine'),
                            tissue=config['tissue_labels'],allow_missing=True),
        seg_channels_4d = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='probseg.nii.gz',desc='atroposKseg'),
    output:
        mapping_json = bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='mapping.json',desc='atropos3seg'),
        tissue_segs = expand(bids(root=join(config['out_dir'], 'derivatives', 'atlasreg'),subject=subject_id,suffix='probseg.nii.gz',label='{tissue}',desc='atropos3seg'),
                            tissue=config['tissue_labels'],allow_missing=True),
    group: 'preproc'
script: '../scripts/map_channels_to_tissue.py'
shell:
    'fslmerge -t {output} {input}'
script: '../scripts/vis_regqc.py'
script: '../scripts/vis_regqc.py'
script: '../scripts/vis_regqc.py'
script: '../scripts/vis_regqc.py'
script: '../scripts/vis_qc_probseg.py'
From line 130 of rules/visqc.smk
script: '../scripts/vis_qc_dseg.py'
From line 150 of rules/visqc.smk
script: '../scripts/vis_qc_dseg.py'
From line 170 of rules/visqc.smk
script: '../scripts/vis_qc_tissue_seg.py'
From line 187 of rules/visqc.smk
script: '../scripts/vis_contacts.py'
From line 204 of rules/visqc.smk
script: '../scripts/vis_electrodes.py'
From line 223 of rules/visqc.smk
From scripts/convert_xfm_tfm.py:
import numpy as np
import pandas as pd
from scipy.io import loadmat


# filen=r'/home/greydon/Documents/data/SEEG_peds/derivatives/atlasreg/sub-P010/sub-P010_acq-noncontrast_desc-rigid_from-noncontrast_to-contrast_type-ras_xfm.mat'
# filen_out=r'/home/greydon/Documents/data/SEEG_peds/derivatives/atlasreg/sub-P010/sub-P010_acq-noncontrast_desc-rigid_from-noncontrast_to-contrast_type-ras_xfm_1.txt'
# filen=r'/home/greydon/Documents/data/SEEG/derivatives/atlasreg/sub-P097/sub-P097_desc-rigid_from-ct_to-T1w_type-ras_xfm.txt'


transformMatrix = np.loadtxt(snakemake.input.xfm)
# ITK .tfm files store transforms in LPS coordinates, so convert the RAS affine to LPS
lps2ras = np.diag([-1, -1, 1, 1])
ras2lps = np.diag([-1, -1, 1, 1])
transform_lps = np.dot(ras2lps, np.dot(transformMatrix, lps2ras))

# flatten to the 12 parameters of an ITK AffineTransform: 3x3 matrix first, then translation
Parameters = " ".join([str(x) for x in np.concatenate((transform_lps[0:3, 0:3].reshape(9), transform_lps[0:3, 3]))])
#output_matrix_txt = filen.split('.txt')[0] + '.tfm'

with open(snakemake.output.tfm, 'w') as fid:
	fid.write("#Insight Transform File V1.0\n")
	fid.write("#Transform 0\n")
	fid.write("Transform: AffineTransform_double_3_3\n")
	fid.write("Parameters: " + Parameters + "\n")
	fid.write("FixedParameters: 0 0 0\n")
From scripts/label_electrodes_atlas.py:
import pandas as pd
import numpy as np
import nibabel as nib
import os

#read fcsv electrodes file
df_elec = pd.read_table(snakemake.input.fcsv, sep=',', header=2)
df_atlas = pd.read_table(snakemake.input.dseg_tsv)

#load up tissue probability, warped from template
tissue_prob_vol = dict()
tissue_prob_elec = dict()

for label,nii in zip(snakemake.config['tissue_labels'], snakemake.input.tissue_seg):
	print(label)
	print(nii)
	tissue_prob_vol[label] = nib.load(nii).get_fdata()
	tissue_prob_elec[label] = list()

#load dseg nii (as integer)
dseg_nii = nib.load(snakemake.input.dseg_nii)
dseg_vol = dseg_nii.get_fdata().astype('int')

#get affine from image, so we can go from RAS coords to array indices
dseg_affine = dseg_nii.affine

#get coords from fcsv
coords = df_elec[['x','y','z']].to_numpy()


labelnames = []

for i in range(len(coords)):

	vec = np.hstack([coords[i,:],1])

	#dseg_affine is used to xfm indices to RAS coords, 
	# so we use the inverse to go the other way
	tvec = np.linalg.inv(dseg_affine) @ vec.T   
	inds = np.round(tvec[:3]).astype('int')
	labelnum = dseg_vol[inds[0],inds[1],inds[2]]


	if labelnum >0:
		labelnames.append(df_atlas.loc[df_atlas['label']==labelnum,'name'].to_list()[0])
	else:
		labelnames.append('None')

	for label in snakemake.config['tissue_labels']:
		tissue_prob_elec[label].append(tissue_prob_vol[label][inds[0],inds[1],inds[2]])

#add new columns to existing dataframe
df_elec['atlas_label'] = labelnames
for label in snakemake.config['tissue_labels']:
	df_elec[label] = tissue_prob_elec[label]

#create new dataframe with selected variables and save it
out_df = df_elec[['label','atlas_label'] + snakemake.config['tissue_labels'] + ['x','y','z']]
out_df.to_csv(snakemake.output.tsv,sep='\t',float_format='%.3f',index=False)
out_df.to_excel(os.path.splitext(snakemake.output.tsv)[0]+'.xlsx',float_format='%.3f',index=False)

From scripts/map_channels_to_tissue.py:
import numpy as np
import nibabel as nib
import json

#load up tissue probability, warped from template
tissue_prob_vol = dict()

for label,nii in zip(snakemake.config['tissue_labels'], snakemake.input.tissue_priors):
	tissue_prob_vol[label] = nib.load(nii).get_fdata()


#load up k-class tissue segmentation
tissue_k_seg = nib.load(snakemake.input.seg_channels_4d)

sim_prior_k = np.zeros([len(snakemake.config['tissue_labels']),tissue_k_seg.shape[3]])

#for each prior, need to find the channel that best fits
for i,label in enumerate(snakemake.config['tissue_labels']):
	for k in range(tissue_k_seg.shape[3]):

		print(f'Computing overlap of {label} prior and channel {k}... ')
		#compute intersection over union
		s1 = tissue_prob_vol[label] >0.5
		s2 = tissue_k_seg.slicer[:,:,:,k].get_fdata() >0.5
		sim_prior_k[i,k] = np.sum(np.logical_and(s1,s2).flat) / np.sum(np.logical_or(s1,s2).flat) 

label_to_k_dict = dict()

for i,label in enumerate(snakemake.config['tissue_labels']):
	label_to_k_dict[label] = int(np.argmax(sim_prior_k[i,:]))
	#write nii to file
	print('writing image at channel {} to output file: {}'.format(label_to_k_dict[label], \
													snakemake.output.tissue_segs[i]))
	nib.save(tissue_k_seg.slicer[:,:,:,label_to_k_dict[label]],\
					snakemake.output.tissue_segs[i])


with open(snakemake.output.mapping_json, 'w') as outfile:
	json.dump(label_to_k_dict, outfile,indent=4)
From scripts/vis_contacts.py:
import matplotlib
matplotlib.use('Agg')
import ants
from nilearn import plotting,image
import nibabel as nib
import numpy as np


template = ants.image_read(ants.get_ants_data('mni'))

ct_img=nib.load(snakemake.input.ct)
if (np.isnan(ct_img.get_fdata())).any():
	ct_img=nib.Nifti1Image(np.nan_to_num(ct_img.get_fdata()), header=ct_img.header, affine=ct_img.affine)
	nib.save(ct_img,snakemake.input.ct)

ct_ants = ants.image_read(snakemake.input.ct)
mask_ants = ants.image_read(snakemake.input.mask)

ct_ants_reg = ants.registration(template, ct_ants, type_of_transform='QuickRigid')
ct_ants_reg_applied=ants.apply_transforms(template, ct_ants, transformlist=ct_ants_reg['fwdtransforms'])
ct_resample = ants.to_nibabel(ct_ants_reg_applied)

mask_ants_reg_applied = ants.apply_transforms(ct_ants, mask_ants, transformlist=ct_ants_reg['fwdtransforms'])
mask_resample = ants.to_nibabel(mask_ants_reg_applied)

mask_params = {
			'symmetric_cmap': True,
			'cut_coords':[0,0,0],
			'dim': 1,
			'cmap':'viridis',
			'opacity':0.7
			}

html_view = plotting.view_img(stat_map_img=mask_resample,bg_img=ct_resample,**mask_params)
html_view.save_as_html(snakemake.output.html)
From scripts/vis_electrodes_native.py:
import os
import nibabel as nb
import numpy as np
import pandas as pd
import regex as re
import matplotlib.pyplot as plt
from nilearn.plotting.displays import PlotlySurfaceFigure
import plotly.graph_objs as go
from mne.transforms import apply_trans


AXIS_CONFIG = {
    "showgrid": False,
    "showline": False,
    "ticks": "",
    "title": "",
    "showticklabels": False,
    "zeroline": False,
    "showspikes": False,
    "spikesides": False,
    "showbackground": False,
}

LAYOUT = {
	"scene": {f"{dim}axis": AXIS_CONFIG for dim in ("x", "y", "z")},
	"paper_bgcolor": "#fff",
	"hovermode": False,
	"showlegend":True,
	"legend":{
		"itemsizing": "constant",
		"groupclick":"togglegroup",
		"yanchor":"top",
		"y":0.8,
		"xanchor":"left",
		"x":0.05,
		"title_font_family":"Times New Roman",
		"font":{
			"size":20
		},
		"bordercolor":"Black",
		"borderwidth":1
	},
	"margin": {"l": 0, "r": 0, "b": 0, "t": 0, "pad": 0},
}

CAMERAS = {
    "left": {
        "eye": {"x": -1.5, "y": 0, "z": 0},
        "up": {"x": 0, "y": 0, "z": 1},
        "center": {"x": 0, "y": 0, "z": 0},
    },
    "right": {
        "eye": {"x": 1.5, "y": 0, "z": 0},
        "up": {"x": 0, "y": 0, "z": 1},
        "center": {"x": 0, "y": 0, "z": 0},
    },
    "dorsal": {
        "eye": {"x": 0, "y": 0, "z": 1.5},
        "up": {"x": 0, "y": 1, "z": 0},
        "center": {"x": 0, "y": 0, "z": 0},
    },
    "ventral": {
        "eye": {"x": 0, "y": 0, "z": -1.5},
        "up": {"x": 0, "y": 1, "z": 0},
        "center": {"x": 0, "y": 0, "z": 0},
    },
    "anterior": {
        "eye": {"x": 0, "y": 1.5, "z": 0},
        "up": {"x": 0, "y": 0, "z": 1},
        "center": {"x": 0, "y": 0, "z": 0},
    },
    "posterior": {
        "eye": {"x": 0, "y": -1.5, "z": 0},
        "up": {"x": 0, "y": 0, "z": 1},
        "center": {"x": 0, "y": 0, "z": 0},
    },
}

lighting_effects = dict(ambient=0.4, diffuse=0.5, roughness = 0.9, specular=0.6, fresnel=0.2)

def determine_groups(iterable, numbered_labels=False):
	values = []
	for item in iterable:
		temp=None
		if re.findall(r"([a-zA-Z]+)([0-9]+)([a-zA-Z]+)", item):
			temp = "".join(list(re.findall(r"([a-zA-Z]+)([0-9]+)([a-zA-Z]+)", item)[0]))
		elif '-' in item:
			temp=item.split('-')[0]
		else:
			if numbered_labels:
				temp=''.join([x for x in item if not x.isdigit()])
				for sub in ("T1","T2"):
					if sub in item:
						temp=item.split(sub)[0] + sub
			else:
				temp=item
		if temp is None:
			temp=item

		values.append(temp)

	vals,indexes,count = np.unique(values, return_index=True, return_counts=True)
	vals=vals[indexes.argsort()]
	count=count[indexes.argsort()]
	return vals,count
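
# Example with hypothetical labels (not from the workflow's data):
#   determine_groups(['LA01', 'LA02', 'RB01'], numbered_labels=True)
# strips the trailing digits, yielding groups ['LA', 'RB'] with counts [2, 1].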


hemi = ["lh", "rh"]
surf_suffix = ["pial", "white", "inflated"]

def readRegMatrix(trsfPath):
	with open(trsfPath) as (f):
		return np.loadtxt(f.readlines())

#%%


debug = False
if debug:
	class dotdict(dict):
		"""dot.notation access to dictionary attributes"""
		__getattr__ = dict.get
		__setattr__ = dict.__setitem__
		__delattr__ = dict.__delitem__

	class Namespace:
		def __init__(self, **kwargs):
			self.__dict__.update(kwargs)

	isub="098"
	datap=r'/home/greydon/Documents/data/SEEG/derivatives'

	input=dotdict({
		't1_fname':datap+f'/fastsurfer/sub-P{isub}/mri/orig.mgz',
		'fcsv':datap+ f'/seega_coordinates/sub-P{isub}/sub-P{isub}_space-native_SEEGA.tsv',
		'xfm_noncontrast':datap+f'/atlasreg/sub-P{isub}/sub-P{isub}_desc-rigid_from-noncontrast_to-contrast_type-ras_xfm.txt',
	})

	output=dotdict({
		'html':datap+f'/atlasreg/sub-P{isub}/sub-P{isub}_space-native_electrodes.html',
	})

	params=dotdict({
		'lh_pial':datap+f'/fastsurfer/sub-P{isub}/surf/lh.pial',
		'rh_pial':datap+f'/fastsurfer/sub-P{isub}/surf/rh.pial',
		'lh_sulc':datap+f'/fastsurfer/sub-P{isub}/surf/lh.sulc',
		'rh_sulc':datap+f'/fastsurfer/sub-P{isub}/surf/rh.sulc',
	})

	snakemake = Namespace(output=output, input=input,params=params)

t1_obj = nb.load(snakemake.input.t1_fname)
Torig = t1_obj.header.get_vox2ras_tkr()
#fs_transform=(t1_obj.affine-Torig)+np.eye(4)
fs_transform=np.dot(t1_obj.affine, np.linalg.inv(Torig))

verl,facel=nb.freesurfer.read_geometry(snakemake.params.lh_pial)
verr,facer=nb.freesurfer.read_geometry(snakemake.params.rh_pial)

all_ver = np.concatenate([verl, verr], axis=0)
all_face = np.concatenate([facel, facer+verl.shape[0]], axis=0)
surf_mesh = [all_ver, all_face]

all_ver_shift=(apply_trans(fs_transform, all_ver))

if len(snakemake.input.xfm_noncontrast)>0:
	if os.path.exists(snakemake.input.xfm_noncontrast):
		t1_transform=readRegMatrix(snakemake.input.xfm_noncontrast)
		all_ver_shift=(apply_trans(np.linalg.inv(t1_transform), all_ver_shift))
		#all_ver_shift=(apply_trans(t1_transform, all_ver_shift))


lh_sulc_data = nb.freesurfer.read_morph_data(snakemake.params.lh_sulc)
rh_sulc_data = nb.freesurfer.read_morph_data(snakemake.params.rh_sulc)
bg_map = np.concatenate((lh_sulc_data, rh_sulc_data))


mesh_3d = go.Mesh3d(x=all_ver_shift[:,0], y=all_ver_shift[:,1], z=all_ver_shift[:,2], i=all_face[:,0], j=all_face[:,1], k=all_face[:,2],opacity=.1,color='grey',alphahull=-10)

value=[np.round(x,2) for x in np.arange(.1,.6-.05,.05)]

df = pd.read_table(os.path.splitext(snakemake.input.fcsv)[0]+".tsv",sep='\t',header=0)
groups,n_members=determine_groups(df['label'].tolist(), True)
df['group']=np.repeat(groups,n_members)

cmap = plt.get_cmap('rainbow')
color_maps=cmap(np.linspace(0, 1, len(groups))).tolist()
res = dict(zip(groups, color_maps))

colors=[]
for igroup in df['group']:
	colors.append(res[igroup])

colors=np.vstack(colors)

data=[mesh_3d]
for igroup in groups:
	idx = [i for i,x in enumerate(df['label'].tolist()) if igroup in x]
	data.append(go.Scatter3d(
		x = df['x'][idx].values,
		y = df['y'][idx].values,
		z = df['z'][idx].values,
		name=igroup,
		mode = "markers+text",
		text=df['label'][idx].tolist(),
		textfont=dict(
			family="sans serif",
			size=16,
			color="black"
		),
		textposition = "middle left",
		marker=dict(
			size=5,
			line=dict(
				width=1,
			),
			color=['rgb({},{},{})'.format(int(r*255),int(g*255),int(b*255)) for r,g,b,a in colors[idx]],  # scale by 255, not 256, to keep channels in range
			opacity=1
			)))

fig = go.Figure(data=data)
fig.update_layout(scene_camera=CAMERAS['left'],
				  legend_title_text="Electrodes",
				  **LAYOUT)

steps = []
for i in range(len(value)):
	step = dict(
		label = str(f"{value[i]:.2f}"),
		method="restyle",
		args=[{'opacity': [value[i]]+(len(data)-1)*[1],
			 'alphahull': [-10]+(len(data)-1)*[1]
		 }]
	)
	steps.append(step)

sliders = [dict(
	currentvalue={"visible": True,"prefix": "Opacity: ","font":{"size":16}},
	active=0,
	steps=steps,
	x=.35,y=.1,len=.3,
	pad={"t": 8},
)]

fig.update_layout(sliders=sliders)
fig.write_html(snakemake.output.html)
From scripts/vis_electrodes.py:
import re

import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is used
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from nilearn import plotting

debug = False

if debug:
	class dotdict(dict):
		"""dot.notation access to dictionary attributes"""
		__getattr__ = dict.get
		__setattr__ = dict.__setitem__
		__delattr__ = dict.__delitem__

	class Namespace:
		def __init__(self, **kwargs):
			self.__dict__.update(kwargs)

	isub="P097"
	data_dir=r'/home/greydon/Documents/data/SEEG/derivatives'

	input=dotdict({'fcsv':f'{data_dir}/seega_coordinates/' + f'sub-{isub}/sub-{isub}_space-native_SEEGA.fcsv',
				'xfm_ras':f'{data_dir}/atlasreg/' + f'sub-{isub}/sub-{isub}_desc-rigid_from-subject_to-MNI152NLin2009cSym_type-ras_xfm.txt'
				})

	output=dotdict({'html':f'{data_dir}/atlasreg/' + f'sub-{isub}/qc/sub-{isub}_space-MNI152NLin2009cSym_desc-affine_electrodes.html',
				'png':f'{data_dir}/atlasreg/' + f'sub-{isub}/qc/sub-{isub}_space-MNI152NLin2009cSym_desc-affine_electrodevis.png'
				})

	snakemake = Namespace(output=output, input=input)

def determine_groups(iterable, numbered_labels=False):
	values = []
	for item in iterable:
		temp=None
		if re.findall(r"([a-zA-Z]+)([0-9]+)([a-zA-Z]+)", item):
			temp = "".join(list(re.findall(r"([a-zA-Z]+)([0-9]+)([a-zA-Z]+)", item)[0]))
		elif '-' in item:
			temp=item.split('-')[0]
		else:
			if numbered_labels:
				temp=''.join([x for x in item if not x.isdigit()])
				for sub in ("T1","T2"):
					if sub in item:
						temp=item.split(sub)[0] + sub
			else:
				temp=item
		if temp is None:
			temp=item

		values.append(temp)

	vals,indexes,count = np.unique(values, return_index=True, return_counts=True)
	vals=vals[indexes.argsort()]
	count=count[indexes.argsort()]
	return vals,count

#read fcsv electrodes file
df = pd.read_table(snakemake.input.fcsv,sep=',',header=2)

groups,n_members=determine_groups(df['label'].tolist(),numbered_labels=True)
df['group']=np.repeat(groups,n_members)

cmap = plt.get_cmap('rainbow')
color_maps=cmap(np.linspace(0, 1, len(groups))).tolist()
res = dict(zip(groups, color_maps))

colors=[]
for igroup in df['group']:
	colors.append(res[igroup])

colors=np.vstack(colors)

labels=[str(x) for x in  range(colors.shape[0])]

#load transform from subj to template
sub2template= np.loadtxt(snakemake.input.xfm_ras)

#plot electrodes transformed (affine) to MNI space, with MNI glass brain
coords = df[['x','y','z']].to_numpy()

#to plot in mni space, need to transform coords
tcoords = np.zeros(coords.shape)
for i in range(len(coords)):

    vec = np.hstack([coords[i,:],1])
    tvec = np.linalg.inv(sub2template) @ vec.T
    tcoords[i,:] = tvec[:3]

html_view = plotting.view_markers(tcoords, marker_size=4.0, marker_color=colors, marker_labels=df['label'].tolist())
#html_view.open_in_browser()
html_view.save_as_html(snakemake.output.html)

#plot subject native space electrodes with glass brain
adjacency_matrix = np.zeros([len(coords),len(coords)])

node_label=np.repeat(groups, n_members, axis=0)


_, idx = np.unique(colors, return_index=True, axis=0)

label_dict=dict(zip(groups,colors[np.sort(idx)].tolist()))

display = plotting.plot_connectome(adjacency_matrix, tcoords, node_color=colors, node_size=3)
display.savefig(snakemake.output.png,dpi=300)
display.close()
From scripts/vis_qc_dseg.py:
import base64
import os
import re
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from matplotlib import gridspec
from nibabel.affines import apply_affine
from nilearn import plotting, image
from nilearn.datasets import load_mni152_template
from scipy import ndimage
from svgutils.compose import Unit
from svgutils.transform import SVGFigure, GroupElement, fromstring

np.set_printoptions(precision=6, suppress=True)


def svg2str(display_object, dpi):
	"""Serialize a nilearn display object to string."""
	from io import StringIO

	image_buf = StringIO()
	display_object.frame_axes.figure.savefig(
		image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k"
	)
	return image_buf.getvalue()

def extract_svg(display_object, dpi=300):
	"""Remove the preamble of the svg files generated with nilearn."""
	image_svg = svg2str(display_object, dpi)

	image_svg = re.sub(' height="[0-9]+[a-z]*"', "", image_svg, count=1)
	image_svg = re.sub(' width="[0-9]+[a-z]*"', "", image_svg, count=1)
	image_svg = re.sub(
		" viewBox", ' preserveAspectRatio="xMidYMid meet" viewBox', image_svg, count=1
	)
	start_tag = "<svg "
	start_idx = image_svg.find(start_tag)
	end_tag = "</svg>"
	end_idx = image_svg.rfind(end_tag)

	# rfind gives the start index of the substr. We want this substr
	# included in our return value so we add its length to the index.
	end_idx += len(end_tag)
	return image_svg[start_idx:end_idx]

def clean_svg(fg_svgs, bg_svgs, ref=0):
	# Find and replace the figure_1 id.
	svgs = bg_svgs+fg_svgs
	roots = [f.getroot() for f in svgs]

	sizes = []
	for f in svgs:
		viewbox = [float(v) for v in f.root.get("viewBox").split(" ")]
		width = int(viewbox[2])
		height = int(viewbox[3])
		sizes.append((width, height))
	nsvgs = len(bg_svgs)

	sizes = np.array(sizes)

	# Calculate the scale to fit all widths
	width = sizes[ref, 0]
	scales = width / sizes[:, 0]
	heights = sizes[:, 1] * scales

	# Compose the views panel: total size is the width of
	# any element (used the first here) and the sum of heights
	fig = SVGFigure(Unit(f"{width}px"), Unit(f"{heights[:nsvgs].sum()}px"))

	yoffset = 0
	for i, r in enumerate(roots):
		r.moveto(0, yoffset, scale_x=scales[i])
		if i == (nsvgs - 1):
			yoffset = 0
		else:
			yoffset += heights[i]

	# Group background and foreground panels in two groups
	if fg_svgs:
		newroots = [
			GroupElement(roots[:nsvgs], {"class": "background-svg"}),
			GroupElement(roots[nsvgs:], {"class": "foreground-svg"}),
		]
	else:
		newroots = roots

	fig.append(newroots)
	fig.root.attrib.pop("width", None)
	fig.root.attrib.pop("height", None)
	fig.root.set("preserveAspectRatio", "xMidYMid meet")

	with TemporaryDirectory() as tmpdirname:
		out_file = Path(tmpdirname) / "tmp.svg"
		fig.save(str(out_file))
		# Post processing
		svg = out_file.read_text().splitlines()

	# Remove <?xml... line
	if svg[0].startswith("<?xml"):
		svg = svg[1:]

	# Add styles for the flicker animation
	if fg_svgs:
		svg.insert(
			2,
			"""\
<style type="text/css">
@keyframes flickerAnimation%s { 0%% {opacity: 1;} 100%% { opacity:0; }}
.foreground-svg { animation: 1s ease-in-out 0s alternate none infinite running flickerAnimation%s;}
.foreground-svg:hover { animation-play-state: running;}
</style>"""
			% tuple([uuid4()] * 2),
		)

	return svg

debug = False
if debug:
	class dotdict(dict):
		"""dot.notation access to dictionary attributes"""
		__getattr__ = dict.get
		__setattr__ = dict.__setitem__
		__delattr__ = dict.__delitem__

	class Namespace:
		def __init__(self, **kwargs):
			self.__dict__.update(kwargs)

	isub="070"
	datap=r'/media/veracrypt6/projects/SEEG/derivatives/atlasreg/'

	input=dotdict({
		'img':datap+f'sub-P{isub}/sub-P{isub}_desc-masked_from-atropos3seg_T1w.nii.gz',
		'seg':datap+f'sub-P{isub}/sub-P{isub}_atlas-CerebrA_from-MNI152NLin2009cSym_reg-SyN_dseg.nii.gz',
	})

	output=dotdict({
		#'html':'/home/greydon/Downloads/' + f'sub-P{isub}_from-contrast_to-noncontrast_regqc.html',
		'html':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.html',
		#'png':'/home/greydon/Downloads/' + f'sub-P{isub}_from-contrast_to-noncontrast_regqc.png'
		'png':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.png'
	})

	snakemake = Namespace(output=output, input=input)

	title = 'sub-P001'


#%%

ref_img=nib.load(snakemake.input.img)
ref_resamp = nib.nifti1.Nifti1Image(ref_img.get_fdata(), affine=ref_img.affine,header=ref_img.header)
#ref_resamp = image.resample_img(ref_img, target_affine=np.eye(3), interpolation='continuous')

flo_img=nib.load(snakemake.input.seg)
flo_resamp = nib.nifti1.Nifti1Image(flo_img.get_fdata(), affine=flo_img.affine,header=flo_img.header)
#flo_resamp = image.resample_img(flo_img, target_affine=np.eye(3), interpolation='continuous')


display = plotting.plot_anat(ref_resamp, display_mode='ortho', draw_cross=False,dim=-1,bg_img=None, cut_coords=[0,0,30])
fg_svgs = [fromstring(extract_svg(display,450))]
display.close()

display = plotting.plot_anat(flo_resamp, display_mode='ortho', draw_cross=False,bg_img=None,black_bg=True,cmap=plt.cm.cubehelix, cut_coords=[0,0,30])
bg_svgs = [fromstring(extract_svg(display,450))]
display.close()

final_svg="\n".join(clean_svg(fg_svgs, bg_svgs))

# overlay the atlas segmentation on the reference image
display = plotting.plot_roi(roi_img=flo_resamp, bg_img=ref_resamp, display_mode='ortho', draw_cross=False, cut_coords=[0,0,30])
display.savefig(snakemake.output.png,dpi=300)

# reuse the same display for the embedded PNG; close it only after both saves
tmpfile_ref = BytesIO()
display.savefig(tmpfile_ref,dpi=300)
display.close()
tmpfile_ref.seek(0)
data_uri = base64.b64encode(tmpfile_ref.getvalue()).decode('utf-8')
img_tag = '<center><img src="data:image/png;base64,{0}"/></center>'.format(data_uri)

htmlbase='<!DOCTYPE html> <html lang="en"> <head> <title>Slice viewer</title>  <meta charset="UTF-8" /> </head> <body>'
htmlend='</body> </html>'

htmlfull=htmlbase + final_svg + img_tag + htmlend

# Write HTML String to file.html
with open(snakemake.output.html, "w") as file:
	file.write(htmlfull)
From scripts/vis_qc_probseg.py:
import json

import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from nilearn import plotting, image


debug = False
if debug:
	class dotdict(dict):
		"""dot.notation access to dictionary attributes"""
		__getattr__ = dict.get
		__setattr__ = dict.__setitem__
		__delattr__ = dict.__delitem__

	class Namespace:
		def __init__(self, **kwargs):
			self.__dict__.update(kwargs)

	isub="070"
	datap=r'/media/veracrypt6/projects/SEEG/derivatives/atlasreg/'

	input=dotdict({
		'img':datap+f'sub-P{isub}/sub-P{isub}_desc-masked_from-atropos3seg_T1w.nii.gz',
		'seg4d':datap+f'sub-P{isub}/sub-P{isub}_desc-atropos3seg_probseg.nii.gz',
		'mapping':datap+f'sub-P{isub}/sub-P{isub}_desc-atropos3seg_mapping.json',
	})

	output=dotdict({
		#'html':'/home/greydon/Downloads/' + f'sub-P{isub}_from-contrast_to-noncontrast_regqc.html',
		'html':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.html',
		#'png':'/home/greydon/Downloads/' + f'sub-P{isub}_from-contrast_to-noncontrast_regqc.png'
		'png':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.png'
	})

	snakemake = Namespace(output=output, input=input)

	title = 'sub-P001'

#html_view = plotting.view_img(stat_map_img=snakemake.input.seg,bg_img=snakemake.input.img,
#                              opacity=0.5,cmap='viridis',dim=-1,threshold=0.5,
#                              symmetric_cmap=False,title='sub-{subject}'.format(**snakemake.wildcards))
#
#html_view.save_as_html(snakemake.output.html)


ref_img = nib.load(snakemake.input.img)
ref_resamp = image.resample_img(ref_img, target_affine=np.eye(3), interpolation='continuous')

flo_img = nib.load(snakemake.input.seg4d)
flo_resamp = image.resample_img(flo_img, target_affine=np.eye(3), interpolation='continuous')


with open(snakemake.input.mapping, "r") as fid:
	mapping_data = json.load(fid)

mapping_data = {str(y):x for x,y in mapping_data.items()}
mapping_data['3']=mapping_data['0']
del mapping_data['0']

coords = plotting.find_xyz_cut_coords(ref_resamp)

colors_dict=[(102,204,238),(34,136,51),(238,102,119),(170,51,119),(204,51,17),(222,143,5),(213,94,0)]
colors_map=[]
for imap in range(len(mapping_data)):
	colors_map.append(colors_dict[imap])

colors_map=[list(np.array(x)/255) for x in colors_map]


display = plotting.plot_prob_atlas(bg_img=ref_resamp,maps_img=flo_resamp, view_type='continuous',display_mode='ortho', draw_cross=False, alpha=.2,cmap=LinearSegmentedColormap.from_list("",colors_map),vmin=0,vmax=4,cut_coords=coords,colorbar=True)

new_yticks=[]
for j, lab in enumerate(mapping_data.values()):
	new_yticks.append(matplotlib.text.Text(0, (j + 1.5), text=lab))

display._cbar.ax.set_yticklabels(new_yticks)
display.savefig(snakemake.output.png,dpi=300)
display.close()
From scripts/vis_qc_tissue_seg.py:
import base64
import os
import re
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from matplotlib import gridspec
from matplotlib.colors import LinearSegmentedColormap
from nibabel.affines import apply_affine
from nilearn import plotting, image
from nilearn.datasets import load_mni152_template
from scipy import ndimage
from svgutils.compose import Unit
from svgutils.transform import SVGFigure, GroupElement, fromstring

np.set_printoptions(precision=6, suppress=True)


def svg2str(display_object, dpi):
	"""Serialize a nilearn display object to string."""
	from io import StringIO

	image_buf = StringIO()
	display_object.frame_axes.figure.savefig(
		image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k"
	)
	return image_buf.getvalue()

def extract_svg(display_object, dpi=300):
	"""Remove the preamble of the svg files generated with nilearn."""
	image_svg = svg2str(display_object, dpi)

	image_svg = re.sub(' height="[0-9]+[a-z]*"', "", image_svg, count=1)
	image_svg = re.sub(' width="[0-9]+[a-z]*"', "", image_svg, count=1)
	image_svg = re.sub(
		" viewBox", ' preserveAspectRatio="xMidYMid meet" viewBox', image_svg, count=1
	)
	start_tag = "<svg "
	start_idx = image_svg.find(start_tag)
	end_tag = "</svg>"
	end_idx = image_svg.rfind(end_tag)

	# rfind gives the start index of the substr. We want this substr
	# included in our return value so we add its length to the index.
	end_idx += len(end_tag)
	return image_svg[start_idx:end_idx]

def clean_svg(fg_svgs, bg_svgs, ref=0):
	# Find and replace the figure_1 id.
	svgs = bg_svgs+fg_svgs
	roots = [f.getroot() for f in svgs]

	sizes = []
	for f in svgs:
		viewbox = [float(v) for v in f.root.get("viewBox").split(" ")]
		width = int(viewbox[2])
		height = int(viewbox[3])
		sizes.append((width, height))
	nsvgs = len(bg_svgs)

	sizes = np.array(sizes)

	# Calculate the scale to fit all widths
	width = sizes[ref, 0]
	scales = width / sizes[:, 0]
	heights = sizes[:, 1] * scales

	# Compose the views panel: total size is the width of
	# any element (used the first here) and the sum of heights
	fig = SVGFigure(Unit(f"{width}px"), Unit(f"{heights[:nsvgs].sum()}px"))

	yoffset = 0
	for i, r in enumerate(roots):
		r.moveto(0, yoffset, scale_x=scales[i])
		if i == (nsvgs - 1):
			yoffset = 0
		else:
			yoffset += heights[i]

	# Group background and foreground panels in two groups
	if fg_svgs:
		newroots = [
			GroupElement(roots[:nsvgs], {"class": "background-svg"}),
			GroupElement(roots[nsvgs:], {"class": "foreground-svg"}),
		]
	else:
		newroots = roots

	fig.append(newroots)
	fig.root.attrib.pop("width", None)
	fig.root.attrib.pop("height", None)
	fig.root.set("preserveAspectRatio", "xMidYMid meet")

	with TemporaryDirectory() as tmpdirname:
		out_file = Path(tmpdirname) / "tmp.svg"
		fig.save(str(out_file))
		# Post processing
		svg = out_file.read_text().splitlines()

	# Remove <?xml... line
	if svg[0].startswith("<?xml"):
		svg = svg[1:]

	# Add styles for the flicker animation
	if fg_svgs:
		svg.insert(
			2,
			"""\
<style type="text/css">
@keyframes flickerAnimation%s { 0%% {opacity: 1;} 100%% { opacity:0; }}
.foreground-svg { animation: 1s ease-in-out 0s alternate none infinite running flickerAnimation%s;}
.foreground-svg:hover { animation-play-state: running;}
</style>"""
			% tuple([uuid4()] * 2),
		)

	return svg

debug = False
if debug:
	class dotdict(dict):
		"""dot.notation access to dictionary attributes"""
		__getattr__ = dict.get
		__setattr__ = dict.__setitem__
		__delattr__ = dict.__delitem__

	class Namespace:
		def __init__(self, **kwargs):
			self.__dict__.update(kwargs)

	isub="070"
	datap=r'/media/veracrypt6/projects/SEEG/derivatives/atlasreg/'

	input=dotdict({
		'img':datap+f'sub-P{isub}/sub-P{isub}_desc-masked_from-atropos3seg_T1w.nii.gz',
		'wm':datap+f'sub-P{isub}/sub-P{isub}_label-WM_desc-atropos3seg_probseg.nii.gz',
		'gm':datap+f'sub-P{isub}/sub-P{isub}_label-GM_desc-atropos3seg_probseg.nii.gz',
		'csf':datap+f'sub-P{isub}/sub-P{isub}_label-CSF_desc-atropos3seg_probseg.nii.gz',
	})

	output=dotdict({
		#'html':'/home/greydon/Downloads/' + f'sub-P{isub}_from-contrast_to-noncontrast_regqc.html',
		'html':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.html',
		#'png':'/home/greydon/Downloads/' + f'sub-P{isub}_from-contrast_to-noncontrast_regqc.png'
		'png':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.png'
	})

	snakemake = Namespace(output=output, input=input)

	title = 'sub-P001'


#%%

ref_img = nib.load(snakemake.input.img)
ref_resamp = image.resample_img(ref_img, target_affine=np.eye(3), interpolation='continuous')

coords = plotting.find_xyz_cut_coords(ref_resamp)

wm_seg = nib.load(snakemake.input.wm)
gm_seg = nib.load(snakemake.input.gm)
csf_seg = nib.load(snakemake.input.csf)


fig, axes = plt.subplots(3, 1,figsize=(16,12))
fig.tight_layout(pad=2)

# plot each tissue probability map over the reference image
plotting.plot_roi(roi_img=csf_seg, bg_img=ref_resamp, display_mode='ortho', draw_cross=False, cut_coords=coords,cmap=LinearSegmentedColormap.from_list("",['black','red']),axes=axes[0])
plotting.plot_roi(roi_img=wm_seg, bg_img=ref_resamp, display_mode='ortho', draw_cross=False, cut_coords=coords,cmap=LinearSegmentedColormap.from_list("",['black','yellow']),axes=axes[1])
plotting.plot_roi(roi_img=gm_seg, bg_img=ref_resamp, display_mode='ortho', draw_cross=False, cut_coords=coords,cmap=LinearSegmentedColormap.from_list("",['black','green']),axes=axes[2])

axes[0].set_title('CSF', fontdict={'fontsize': 20, 'fontweight': 'bold'})
axes[1].set_title('White Matter', fontdict={'fontsize': 20, 'fontweight': 'bold'})
axes[2].set_title('Gray Matter', fontdict={'fontsize': 20, 'fontweight': 'bold'})

fig.savefig(snakemake.output.png,dpi=300)
plt.close(fig)
From scripts/vis_regqc.py:
import base64
import os
import re
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from matplotlib import gridspec
from nilearn import plotting, image
from svgutils.compose import Unit
from svgutils.transform import SVGFigure, GroupElement, fromstring


def svg2str(display_object, dpi):
	"""Serialize a nilearn display object to string."""
	from io import StringIO

	image_buf = StringIO()
	display_object.frame_axes.figure.savefig(
		image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k"
	)
	return image_buf.getvalue()

def extract_svg(display_object, dpi=300):
	"""Remove the preamble of the svg files generated with nilearn."""
	image_svg = svg2str(display_object, dpi)

	image_svg = re.sub(' height="[0-9]+[a-z]*"', "", image_svg, count=1)
	image_svg = re.sub(' width="[0-9]+[a-z]*"', "", image_svg, count=1)
	image_svg = re.sub(
		" viewBox", ' preserveAspectRatio="xMidYMid meet" viewBox', image_svg, count=1
	)
	start_tag = "<svg "
	start_idx = image_svg.find(start_tag)
	end_tag = "</svg>"
	end_idx = image_svg.rfind(end_tag)

	# rfind gives the start index of the substr. We want this substr
	# included in our return value so we add its length to the index.
	end_idx += len(end_tag)
	return image_svg[start_idx:end_idx]

def clean_svg(fg_svgs, bg_svgs, ref=0):
	# Find and replace the figure_1 id.
	svgs = bg_svgs+fg_svgs
	roots = [f.getroot() for f in svgs]

	sizes = []
	for f in svgs:
		viewbox = [float(v) for v in f.root.get("viewBox").split(" ")]
		width = int(viewbox[2])
		height = int(viewbox[3])
		sizes.append((width, height))
	nsvgs = len(bg_svgs)

	sizes = np.array(sizes)

	# Calculate the scale to fit all widths
	width = sizes[ref, 0]
	scales = width / sizes[:, 0]
	heights = sizes[:, 1] * scales

	# Compose the views panel: total size is the width of
	# any element (used the first here) and the sum of heights
	fig = SVGFigure(Unit(f"{width}px"), Unit(f"{heights[:nsvgs].sum()}px"))

	yoffset = 0
	for i, r in enumerate(roots):
		r.moveto(0, yoffset, scale_x=scales[i])
		if i == (nsvgs - 1):
			yoffset = 0
		else:
			yoffset += heights[i]

	# Group background and foreground panels in two groups
	if fg_svgs:
		newroots = [
			GroupElement(roots[:nsvgs], {"class": "background-svg"}),
			GroupElement(roots[nsvgs:], {"class": "foreground-svg"}),
		]
	else:
		newroots = roots

	fig.append(newroots)
	fig.root.attrib.pop("width", None)
	fig.root.attrib.pop("height", None)
	fig.root.set("preserveAspectRatio", "xMidYMid meet")

	with TemporaryDirectory() as tmpdirname:
		out_file = Path(tmpdirname) / "tmp.svg"
		fig.save(str(out_file))
		# Post processing
		svg = out_file.read_text().splitlines()

	# Remove <?xml... line
	if svg[0].startswith("<?xml"):
		svg = svg[1:]

	# Add styles for the flicker animation
	if fg_svgs:
		svg.insert(
			2,
			"""\
<style type="text/css">
@keyframes flickerAnimation%s { 0%% {opacity: 1;} 100%% { opacity:0; }}
.foreground-svg { animation: 1s ease-in-out 0s alternate none infinite running flickerAnimation%s;}
.foreground-svg:hover { animation-play-state: running;}
</style>"""
			% tuple([uuid4()] * 2),
		)

	return svg

debug = False
if debug:
	class dotdict(dict):
		"""dot.notation access to dictionary attributes"""
		__getattr__ = dict.get
		__setattr__ = dict.__setitem__
		__delattr__ = dict.__delitem__

	class Namespace:
		def __init__(self, **kwargs):
			self.__dict__.update(kwargs)

	isub="047"
	datap=r'/home/greydon/Documents/data/clinical/derivatives/atlasreg/'

	input=dotdict({
		'flo':datap+f'sub-P{isub}/sub-P{isub}_space-T1w_desc-rigid_ct.nii.gz',
		#'flo':datap+f'sub-P{isub}/sub-P{isub}_space-MNI152NLin2009cSym_desc-affine_T1w.nii.gz',
		'ref':datap+f'sub-P{isub}/sub-P{isub}_acq-contrast_T1w.nii.gz'
		#'ref':'/home/greydon/Documents/GitHub/seeg2bids-pipeline/resources/tpl-MNI152NLin2009cSym/tpl-MNI152NLin2009cAsym_res-1_T1w.nii.gz'
	})

	output=dotdict({
		'html':'/home/greydon/Downloads/' + f'sub-P{isub}_from-ct_to-noncontrast_regqc.html',
		#'html':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.html',
		'png':'/home/greydon/Downloads/' + f'sub-P{isub}_from-ct_to-noncontrast_regqc.png'
		#'png':'/home/greydon/Downloads/' + f'sub-P{isub}_desc-affine_from-subject_to-MNI152NLin2009cSym_regqc.png'
	})

	snakemake = Namespace(output=output, input=input)

	title = 'sub-P001'


#%%


# load the reference and floating volumes, rounding intensities and casting to float32
ref_img = nib.load(snakemake.input.ref)
ref_img_data = np.round(ref_img.get_fdata()).astype(np.float32)
ref_img = nib.Nifti1Image(ref_img_data, header=ref_img.header, affine=ref_img.affine)
ref_img.header.set_data_dtype('float32')

flo_img = nib.load(snakemake.input.flo)
flo_img_data = np.round(flo_img.get_fdata()).astype(np.float32)
flo_img = nib.Nifti1Image(flo_img_data, header=flo_img.header, affine=flo_img.affine)
flo_img.header.set_data_dtype('float32')


# the check below distinguishes template-space QC from within-subject QC; note the
# one-element tuple: a bare string here would be iterated character by character,
# making any() almost always true
ref_resamp = image.resample_img(ref_img, target_affine=np.eye(3), interpolation='continuous')
if not any(x in os.path.basename(snakemake.output.png) for x in ('from-subject_to-',)):
	# within-subject QC: resample both images to a 1 mm isotropic grid
	flo_resamp = image.resample_img(flo_img, target_affine=np.eye(3), interpolation='continuous')
else:
	# template-space QC: sample the floating image onto the reference grid
	flo_resamp = image.resample_to_img(flo_img, ref_resamp, interpolation='continuous')



# nilearn's `dim` argument controls background dimming (negative values brighten);
# again the membership tests need one-element tuples, not bare strings
plot_args_ref = {'dim': -1}
if any(x in os.path.basename(snakemake.output.png) for x in ('from-subject_to-',)):
	plot_args_ref = {'dim': 1}

plot_args_flo = {'dim': -1}
if any(x in os.path.basename(snakemake.output.png) for x in ('from-ct',)):
	plot_args_flo = {'dim': 0}


# render each volume once and convert the displays to SVG panels for the flicker view
display = plotting.plot_anat(ref_resamp, display_mode='ortho', draw_cross=False, cut_coords=[0,0,40], **plot_args_ref)
fg_svgs = [fromstring(extract_svg(display, 450))]
display.close()

display = plotting.plot_anat(flo_resamp, display_mode='ortho', draw_cross=False, cut_coords=[0,0,40], **plot_args_flo)
bg_svgs = [fromstring(extract_svg(display, 450))]
display.close()

final_svg = "\n".join(clean_svg(fg_svgs, bg_svgs))


# static overlay: contours of the floating image drawn on the reference
display = plotting.plot_anat(ref_resamp, display_mode='ortho', draw_cross=False, cut_coords=[0,0,40], **plot_args_ref)
display.add_contours(flo_resamp, alpha=0.6, colors='r', linewidths=0.5)
display.savefig(snakemake.output.png, dpi=300)

tmpfile_ref = BytesIO()
display.savefig(tmpfile_ref,dpi=300)
display.close()
tmpfile_ref.seek(0)
data_uri = base64.b64encode(tmpfile_ref.getvalue()).decode('utf-8')
img_tag = '<center><img src="data:image/png;base64,{0}"/></center>'.format(data_uri)


htmlbase='<!DOCTYPE html> <html lang="en"> <head> <title>Slice viewer</title>  <meta charset="UTF-8" /> </head> <body>'
htmlend='</body> </html>'

htmlfull=htmlbase + final_svg + img_tag + htmlend

# Write HTML String to file.html
with open(snakemake.output.html, "w") as file:
	file.write(htmlfull)
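
The report written here is a single self-contained HTML file: the flicker SVG for interactive comparison of the two volumes, followed by the base64-embedded PNG of the contour overlay.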
import os
import pandas as pd
import numpy as np
import re
import csv
from bids.layout import BIDSLayout
import shutil
from collections import OrderedDict

# maps raw electrode-group labels to their shortened clinical forms
chan_label_dic = {
	'LAntSSMA': 'LASSMA',
	'RAntSSMA': 'RASSMA',
	'LPostSSMA': 'LPSSMA',
	'RPostSSMA': 'RPSSMA',
	'LMidIn': 'LMIn',
	'RMidIn': 'RMIn',
	'LOFr': 'LOF',
	'ROFr': 'ROF',
	'LAntOF': 'LAOF',
	'RAntOF': 'RAOF',
	'LPostOF': 'LPOF',
	'RPostOF': 'RPOF',
	'LFACing': 'LACg',
	'RFACing': 'RACg',
	'LAMesFr': 'LAMeFr',
	'RAMesFr': 'RAMeFr',
	'LPMesFr': 'LPMeFr',
	'RPMesFr': 'RPMeFr',
	'LAmy': 'LAm',
	'RAmy': 'RAm',
	'LTAmy': 'LAm',
	'RTAmy': 'RAm',
	'LTPole': 'LTeP',
	'RTPole': 'RTeP',
	'LTAHc': 'LAHc',
	'RTAHc': 'RAHc',
	'LTPHc': 'LPHc',
	'RTPHc': 'RPHc',
	'LPost_Central': 'LPCe',
	'RPost_Central': 'RPCe',
	'LFr_Convex': 'LFrC',
	'RFr_Convex': 'RFrC',
	'LPost_to_Les': 'LPLs',
	'RPost_to_Les': 'RPLs',
	'LFAnt_to_Les': 'LALs',
	'RFAnt_to_Les': 'RALs',
	'LTPO_PostCing': 'LPCg',
	'RTPO_PostCing': 'RPCg',
	'LTPOPostCing': 'LPCg',
	'RTPOPostCing': 'RPCg',
	'LAntCing': 'LACg',
	'RAntCing': 'RACg',
	'LPostTe_MedOcc': 'LPTMOc',
	'RPostTe_MedOcc': 'RPTMOc',
	'LPreCent_Face': 'LPrCeP',
	'RPreCent_Face': 'RPrCeP',
	'LSensory_Cx_Leg': 'LSleg',
	'RSensory_Cx_Leg': 'RSleg',
	'LMotor_Cx_Leg': 'LMleg',
	'RMotor_Cx_Leg': 'RMleg',
	'LLesion_SUP': 'LSLs',
	'RLesion_SUP': 'RSLs',
	'LLesion_POST': 'LPLs',
	'RLesion_POST': 'RPLs',
	'LLesion_INF': 'LILs',
	'RLesion_INF': 'RILs',
	'LLesion_ANT': 'LALs',
	'RLesion_ANT': 'RALs',
	'LTFusiform': 'LFGy',
	'RTFusiform': 'RFGy',
	'LFOF': 'LOFr',
	'RFOF': 'ROFr',
	'LSupMargPCing': 'LPCg',
	'RSupMargPCing': 'RPCg',
	'LTHeschl': 'LHs',
	'RTHeschl': 'RHs',
}
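
As a rough illustration of how this map is applied further down (the label here is hypothetical), the contact number is stripped before the lookup:

raw = 'LTAHc03'                                        # hypothetical contact label
base = ''.join(ch for ch in raw if not ch.isdigit())   # -> 'LTAHc'
print(chan_label_dic.get(base, base))                  # -> 'LAHc'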

def determine_groups(iterable, numbered_labels=False):
	"""Collapse contact labels to their electrode-group prefix and count members per group."""
	values = []
	for item in iterable:
		temp = None
		if re.findall(r"([a-zA-Z]+)([0-9]+)([a-zA-Z]+)", item):
			temp = "".join(list(re.findall(r"([a-zA-Z]+)([0-9]+)([a-zA-Z]+)", item)[0]))
		elif '-' in item:
			temp = item.split('-')[0]
		else:
			if numbered_labels:
				temp = ''.join([x for x in item if not x.isdigit()])
				for sub in ("T1", "T2"):
					if sub in item:
						temp = item.split(sub)[0] + sub
			else:
				temp = item
		if temp is None:
			temp = item

		values.append(temp)

	vals, indexes, count = np.unique(values, return_index=True, return_counts=True)
	values_unique = [values[index] for index in sorted(indexes)]

	return values_unique, count
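
For instance, assuming contacts are labelled with a group prefix plus a contact number (the labels below are made up):

labels = np.array(['LAHc01', 'LAHc02', 'RAm-1', 'RAm-2'])
groups, counts = determine_groups(labels, numbered_labels=True)
# groups -> ['LAHc', 'RAm'];  counts -> array([2, 2])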


def levenshtein_ratio_and_distance(s, t, ratio_calc=False):
	"""Calculate the Levenshtein distance between two strings.

	If ratio_calc is True, the function instead computes the Levenshtein
	ratio of similarity between the two strings. For all i and j,
	distance[i, j] holds the Levenshtein distance between the first i
	characters of s and the first j characters of t.
	"""
	# Initialize matrix of zeros
	rows = len(s) + 1
	cols = len(t) + 1
	distance = np.zeros((rows, cols), dtype=int)

	# Populate the first row and column with the indices of each character of both strings
	for i in range(1, rows):
		for k in range(1, cols):
			distance[i][0] = i
			distance[0][k] = k

	# Iterate over the matrix to compute the cost of deletions, insertions and/or substitutions
	for col in range(1, cols):
		for row in range(1, rows):
			if s[row-1] == t[col-1]:
				cost = 0  # characters match at this position, so no cost
			else:
				# To align the results with those of the python-Levenshtein package,
				# a substitution costs 2 when computing the ratio and 1 when
				# computing the plain distance.
				cost = 2 if ratio_calc else 1
			distance[row][col] = min(distance[row-1][col] + 1,      # cost of deletions
								 distance[row][col-1] + 1,          # cost of insertions
								 distance[row-1][col-1] + cost)     # cost of substitutions
	if ratio_calc:
		# Computation of the Levenshtein distance ratio
		ratio = ((len(s) + len(t)) - distance[row][col]) / (len(s) + len(t))
		return ratio
	else:
		# This is the minimum number of edits needed to convert string s into string t
		return "The strings are {} edits away".format(distance[row][col])

def make_bids_filename(subject_id, space_id, desc_id, suffix, prefix):
	# assemble a BIDS-style name from whichever entities are present
	order = OrderedDict([('space', space_id),
						 ('desc', desc_id)])

	filename = []
	if subject_id is not None:
		filename.append(subject_id)
	for key, val in order.items():
		if val is not None:
			filename.append('%s-%s' % (key, val))

	if isinstance(suffix, str):
		filename.append(suffix)

	filename = '_'.join(filename)
	if isinstance(prefix, str):
		filename = os.path.join(prefix, filename)

	return filename
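
For example (the paths are illustrative):

make_bids_filename('sub-P001', 'native', None, 'SEEGA.tsv', '/out/coords')
# -> '/out/coords/sub-P001_space-native_SEEGA.tsv'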

def determineFCSVCoordSystem(input_fcsv):
	# need to determine if the file is in RAS or LPS
	# loop through the header to find the coordinate system
	coordFlag = re.compile('# CoordinateSystem')
	coord_sys = None
	with open(input_fcsv, 'r+') as fid:
		rdr = csv.DictReader(filter(lambda row: row[0]=='#', fid))
		for row in rdr:
			cleaned_dict = {k: v for k, v in row.items() if k is not None}
			if any(coordFlag.match(x) for x in list(cleaned_dict.values())):
				coordString = list(filter(coordFlag.match, list(cleaned_dict.values())))
				assert len(coordString) == 1
				coord_sys = coordString[0].split('=')[-1].strip()

	if coord_sys is None:
		raise ValueError(f"No '# CoordinateSystem' header found in {input_fcsv}")

	if any(x in coord_sys for x in {'LPS','1'}):
		df = pd.read_csv(input_fcsv, skiprows=3, header=None)
		df[1] = -1 * df[1]  # flip orientation in x
		df[2] = -1 * df[2]  # flip orientation in y

		with open(input_fcsv, 'w') as fid:
			fid.write("# Markups fiducial file version = 4.11\n")
			fid.write("# CoordinateSystem = 0\n")
			fid.write("# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n")

		df.rename(columns={0:'node_id', 1:'x', 2:'y', 3:'z', 4:'ow', 5:'ox',
							6:'oy', 7:'oz', 8:'vis', 9:'sel', 10:'lock',
							11:'label', 12:'description', 13:'associatedNodeID'}, inplace=True)

		df['associatedNodeID'] = pd.Series(np.repeat('', df.shape[0]))
		# note: `line_terminator` was renamed `lineterminator` in pandas 2.0
		df.round(3).to_csv(input_fcsv, sep=',', index=False, line_terminator="", mode='a', header=False)

		print(f"Converted LPS to RAS: {input_fcsv}")


debug = False

if debug:
	class dotdict(dict):
		"""dot.notation access to dictionary attributes"""
		__getattr__ = dict.get
		__setattr__ = dict.__setitem__
		__delattr__ = dict.__delitem__

	class Namespace:
		def __init__(self, **kwargs):
			self.__dict__.update(kwargs)

	sub='P009'

	config=dotdict({'out_dir':'/home/greydon/Documents/data/SEEG_peds'})
	#config=dotdict({'out_dir':'/media/stereotaxy/3E7CE0407CDFF11F/data/SEEG/imaging/clinical'})

	params=dotdict({'sub':sub})
	input=dotdict({'seega_scene':f'/home/greydon/Documents/data/SEEG_peds/derivatives/seega_scenes/sub-{sub}'})
	#input=dotdict({'seega_scene':f'/home/greydon/Documents/data/SEEG/derivatives/seega_scenes/sub-{sub}'})

	snakemake = Namespace(params=params, input=input,config=config)

#%%

isub='sub-'+snakemake.params.sub

patient_output = os.path.join(snakemake.config['out_dir'], 'derivatives', 'seega_coordinates', isub)
os.makedirs(patient_output, exist_ok=True)

patient_files = []
for dirpath, subdirs, subfiles in os.walk(os.path.dirname(snakemake.input.seega_scene[0])):
	for x in subfiles:
		if x.endswith(".fcsv") and not x.startswith('coords'):
			patient_files.append(os.path.join(dirpath, x))


acpc_file = [x for x in patient_files if os.path.splitext(x)[0].lower().endswith('acpc')]
patient_files = [x for x in patient_files if any(os.path.splitext(x)[0].lower().endswith(y) for y in ('seega','planned','actual'))]

if acpc_file:

	# determine the coordinate system of the FCSV
	determineFCSVCoordSystem(acpc_file[0])

	acpc_data = pd.read_csv(acpc_file[0], skiprows=3, header=None)
	acpc_data.rename(columns={0:'node_id', 1:'x', 2:'y', 3:'z', 4:'ow', 5:'ox',
						6:'oy', 7:'oz', 8:'vis', 9:'sel', 10:'lock',
						11:'label', 12:'description', 13:'associatedNodeID'}, inplace=True)
	ac_point = acpc_data.loc[acpc_data['label'] =='ac', 'x':'z'].values[0]
	pc_point = acpc_data.loc[acpc_data['label'] =='pc', 'x':'z'].values[0]
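	# the mid-commissural point (MCP) is the midpoint of the AC-PC line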
	mcp_point = [(ac_point[0]+pc_point[0])/2, (ac_point[1]+pc_point[1])/2, (ac_point[2]+pc_point[2])/2]
	output_matrix_txt = make_bids_filename(isub, 'T1w', None, 'mcp.tfm', patient_output)
	with open(output_matrix_txt, 'w') as fid:
		fid.write("#Insight Transform File V1.0\n")
		fid.write("#Transform 0\n")
		fid.write("Transform: AffineTransform_double_3_3\n")
		fid.write("Parameters: 1 0 0 0 1 0 0 0 1 {} {} {}\n".format(1*(round(mcp_point[0],3)), 1*(round(mcp_point[1],3)), -1*(round(mcp_point[2],3))))
		fid.write("FixedParameters: 0 0 0\n")

for ifile in patient_files:

	# determine the coordinate system of the FCSV
	determineFCSVCoordSystem(ifile)

	data_table_full = pd.read_csv(ifile, skiprows=3, header=None)
	data_table_full.rename(columns={0:'node_id', 1:'x', 2:'y', 3:'z', 4:'ow', 5:'ox',
					6:'oy', 7:'oz', 8:'vis', 9:'sel', 10:'lock',
					11:'label', 12:'description', 13:'associatedNodeID'}, inplace=True)

	#data_table_full['label'] = data_table_full['label'].str.replace('-','')
	data_table_full['type'] = np.repeat(ifile.split(os.sep)[-1].split('.fcsv')[0], data_table_full.shape[0])

	if os.path.splitext(ifile.split(os.sep)[-1])[0].lower().endswith('seega'):
		groups,n_members = determine_groups(np.array(data_table_full['label'].values), True)

		group_pair = []
		new_label = []
		new_group = []
		for ichan in data_table_full['label'].values:
			# match each contact to its electrode-group prefix
			group_pair.append([x for x in groups if ichan.startswith(x)][0])
			if '_' in group_pair[-1]:
				# strip digits from the first segment of underscore-delimited groups
				group_pair[-1] = "_".join(["".join(x for x in group_pair[-1].split('_')[0] if not x.isdigit())] + group_pair[-1].split('_')[1:])

			if "".join(x for x in group_pair[-1] if not x.isdigit()) in list(chan_label_dic):
				# known group: swap in the shortened label
				temp = "".join(x for x in group_pair[-1] if not x.isdigit())
				new_group.append(chan_label_dic[temp])
				new_label.append(ichan.replace(temp, chan_label_dic[temp]))
			else:
				# unknown group: keep the original prefix
				new_group.append("".join(x for x in group_pair[-1]))
				new_label.append(new_group[-1] + ichan.split(group_pair[-1])[-1])

		data_table_full.insert(data_table_full.shape[1],'orig_group',group_pair)
		data_table_full.insert(data_table_full.shape[1],'new_label',new_label)
		data_table_full.insert(data_table_full.shape[1],'new_group',new_group)

	if acpc_file:
		data_table_full['x_mcp'] = data_table_full['x'] - mcp_point[0]
		data_table_full['y_mcp'] = data_table_full['y'] - mcp_point[1]
		data_table_full['z_mcp'] = data_table_full['z'] - mcp_point[2]

		#### Write MCP Based Coords TSV file
		output_fname = make_bids_filename(isub, 'acpc', None, ifile.split(os.sep)[-1].split('.fcsv')[0] + '.tsv', patient_output)
		if 'seega' in ifile.split(os.sep)[-1].split('.fcsv')[0].lower():
			head = ['type','label','x_mcp','y_mcp','z_mcp','orig_group','new_label','new_group']
		else:
			head = ['type','label','x_mcp','y_mcp','z_mcp']

		data_table_full.round(3).to_csv(output_fname, sep='\t', index=False, na_rep='n/a', line_terminator="", columns = head)

		#### Write MCP Based Coords FCSV file
		output_fname = make_bids_filename(isub, 'acpc', None, ifile.split(os.sep)[-1], patient_output)
		with open(output_fname, 'w') as fid:
			fid.write("# Markups fiducial file version = 4.11\n")
			fid.write("# CoordinateSystem = 0\n")
			fid.write("# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n")

		head = ['node_id', 'x_mcp', 'y_mcp', 'z_mcp', 'ow', 'ox', 'oy', 'oz', 'vis','sel', 'lock', 'label', 'description', 'associatedNodeID']
		data_table_full['node_id'] = ['vtkMRMLMarkupsFiducialNode_' + str(x) for x in range(data_table_full.shape[0])]
		data_table_full['associatedNodeID'] = np.repeat('',data_table_full.shape[0])
		data_table_full.round(3).to_csv(output_fname, sep=',', index=False, line_terminator="", columns = head, mode='a', header=False)

	#### Write Native Coords TSV file
	output_fname = make_bids_filename(isub, 'native', None, ifile.split(os.sep)[-1].split('.fcsv')[0] + '.tsv', patient_output)
	if os.path.splitext(ifile.split(os.sep)[-1])[0].lower().endswith('seega'):
		head=['type','label','x','y','z','orig_group','new_label','new_group']
	else:
		head=['type','label','x','y','z']

	data_table_full.round(3).to_csv(output_fname, sep='\t', index=False, na_rep='n/a', line_terminator="", columns = head)

	#### Write Native Coords FCSV file
	output_fname = make_bids_filename(isub, 'native', None, ifile.split(os.sep)[-1], patient_output)
	with open(output_fname, 'w') as fid:
		fid.write("# Markups fiducial file version = 4.11\n")
		fid.write("# CoordinateSystem = 0\n")
		fid.write("# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n")

	head = ['node_id', 'x', 'y', 'z', 'ow', 'ox', 'oy', 'oz', 'vis','sel', 'lock', 'label', 'description', 'associatedNodeID']
	del data_table_full['node_id']
	del data_table_full['associatedNodeID']
	data_table_full.insert(data_table_full.shape[1],'node_id',pd.Series(['vtkMRMLMarkupsFiducialNode_' + str(x) for x in range(data_table_full.shape[0])]))
	data_table_full.insert(data_table_full.shape[1],'associatedNodeID', pd.Series(np.repeat('',data_table_full.shape[0])))
	data_table_full.round(3).to_csv(output_fname, sep=',', index=False, line_terminator="", columns = head, mode='a', header=False)


coords_fname = make_bids_filename(isub, 'native', None, 'SEEGA.tsv', patient_output)
coords_table = pd.read_csv(coords_fname, sep='\t', header=0)

indexes = np.unique(coords_table['new_group'], return_index=True)[1]
slicer_chans_groups = [coords_table['new_group'][index] for index in sorted(indexes)]

indexes = np.unique(coords_table['orig_group'], return_index=True)[1]
slicer_chans_groups_orig = [coords_table['orig_group'][index] for index in sorted(indexes)]

coords_pairs = {}
coords_pairs['ieeg_labels'] = list(np.repeat(np.nan,len(slicer_chans_groups)))
coords_pairs['combined_labels'] = slicer_chans_groups
coords_pairs['seega_labels'] = slicer_chans_groups_orig

coords_pairs = pd.DataFrame(coords_pairs)
output_fname = make_bids_filename(isub, None, None, 'mapping.tsv', patient_output)
coords_pairs.to_csv(output_fname, sep='\t', index=False, na_rep='n/a', line_terminator="")
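
The resulting mapping.tsv has one row per electrode group, pairing each shortened (new) group label with its original Slicer label; the ieeg_labels column is written as n/a, presumably to be filled in later when matching against the iEEG recording channel names.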

Created: 1yr ago
Updated: 1yr ago
Maintainers: public
URL: https://github.com/greydongilmore/seeg2bids-pipeline
Name: seeg2bids-pipeline
Version: 1
Copyright: Public Domain
License: MIT License