A pipeline for the GWAS Variant-to-Gene-to-Program (V2G2P) approach


Variant-to-Gene-to-Program (V2G2P) Approach

Description

A pipeline to connect GWAS variants to genes to disease-associated gene programs. The pipeline uses Snakemake and cNMF from Kotliar et al. The V2G2P approach can be applied to any GWAS study, provided data from the appropriate cell type(s) are available.

Overview of the V2G2P pipeline

The V2G2P approach has three components: V2G, G2P, and the V2G2P enrichment test. Each works as a stand-alone pipeline, and together they make up the full V2G2P analysis. Below is an overview of how the three steps relate to one another:

V2G2P overview

The V2G pipeline is linked here. Author: Rosa Ma.

Details about G2P and V2G2P enrichment test

Below is a figure showing the different modules and features within the G2P pipeline:

G2P and V2G2P enrichment

Usage

Step 1: Clone this GitHub repository

Step 2: Install the conda environments

Install Snakemake and the conda environments using conda:

    conda env create -f conda_env/cnmf_env.yml
    conda env create -f conda_env/cnmf_analysis_R.yml

cnmf_env contains Snakemake. If you do not have Snakemake installed already, you can activate the environment via conda activate cnmf_env, then run the pipeline.
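Optionally, you can verify that Snakemake is available inside the environment before launching the workflow (a quick sanity check, not a required step):

    conda activate cnmf_env
    snakemake --version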

Step 3: Gather all input data

Necessary inputs:

Config file slots:

| field | meaning |
| --- | --- |
| total workers | Number of processes to run in parallel |
| seed | A number to set the seed for reproducibility |
| num runs | Number of NMF runs (recommend 100 for the actual data analysis, and 10 for testing the pipeline) |
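For reference, a minimal config sketch is shown below. The directory fields (analysisDir, figDir, dataDir) are the ones referenced in the Outputs section; the exact spellings of the other keys (total_workers, seed, num_runs) are illustrative assumptions, so check them against the example config shipped with this repository.

    {
      "total_workers": 4,
      "seed": 2021,
      "num_runs": 10,
      "analysisDir": "/path/to/analysis/",
      "figDir": "/path/to/figures/",
      "dataDir": "/path/to/data/"
    }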

Step 4: Run the pipeline

    conda activate cnmf_env
    snakemake -n --configfile /path/to/config.json --quiet  ## always recommend doing a dry run

Execute the workflow locally via:

    snakemake --configfile /path/to/config.json

Please see the log.sh file in this GitHub repository for more examples.
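As an illustration, a typical local run and a cluster submission might look like the sketch below. --cores, --jobs, and --cluster are standard Snakemake options; the sbatch command string is an assumption for a SLURM system (it reuses the partition/time/mem_gb params defined in the rules), so adapt it to your scheduler.

    ## local execution with 4 cores
    snakemake --configfile /path/to/config.json --cores 4

    ## example SLURM submission (adjust to your system)
    snakemake --configfile /path/to/config.json --jobs 100 \
        --cluster "sbatch -p {params.partition} -t {params.time} --mem={params.mem_gb}G"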

For more Snakemake usage and configuration options, please visit the Snakemake documentation page.

Outputs

The output files are written to the folders specified by the analysisDir and figDir fields in the config file.

Analysis files (in analysisDir)

Summary output for choosing the number of components

Output can be found in config['analysisDir']/{cNMF_gene_selection}/{sampleName}/acrossK/

Outputs for each model (each choice of the number of components, K)

Output can be found in config['analysisDir']/{cNMF_gene_selection}/{sampleName}/K*/threshold_*/
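As an illustration (all path components here are placeholders; substitute the gene selection, sample name, K, and threshold used in your own run), the two output locations can be inspected with:

    ls /path/to/analysisDir/top2000VariableGenes/mySample/acrossK/
    ls /path/to/analysisDir/top2000VariableGenes/mySample/K60/threshold_0_2/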

Code Snippets

Lines 678-692:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/perturbationAnalysis.R \
	--sampleName {wildcards.sample} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir}/ \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--cell.count.thr {wildcards.min_cell_per_guide} \
	--guide.count.thr {wildcards.min_guide_per_ptb} \
	--recompute F \
	--motif.enhancer.background /oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv \
	--motif.promoter.background /oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv \
	' "
Lines 930-937:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/MAST_DE_Topics_scatter.R \
	--barcode.names {params.barcode_names} \
	--sampleName {wildcards.sample} \
	--num.genes.per.MAST.runGroup {params.num_genes_per_MAST_runGroup} \
	--scatteroutput {params.scatteroutput} ' "
Lines 955-965:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/MAST_DE_Topics_preparation.R \
	--barcode.names {params.barcode_names} \
	--outdirsample {params.outdirsample} \
	--scatteroutput {params.scatteroutput} \
	--numCtrl {params.numCtrl} \
	--sampleName {wildcards.sample} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold}  ' "
Lines 984-996:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/MAST_DE_Topics_runGroups.R \
	--barcode.names {params.barcode_names} \
	--outdirsample {params.outdirsample} \
	--scatteroutput {params.scatteroutput} \
	--gene.group.list {input.MAST_gene_groups} \
	--scatter.gene.group {wildcards.MAST_run_index} \
	--sampleName {wildcards.sample} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--scriptdir {params.perturbAnalysis_scriptdir} ' "
Lines 1014-1025:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/MAST_DE_Topics_gatherGroups.R \
	--barcode.names {params.barcode_names} \
	--outdirsample {params.outdirsample} \
	--scatteroutput {params.scatteroutput} \
	--total.scatter.gene.group {params.MAST_num_runs} \
	--sampleName {wildcards.sample} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--scriptdir {params.perturbAnalysis_scriptdir} ' "
Lines 1045-1069:
	shell:
		"bash -c ' source $HOME/.bashrc; \
		conda activate cnmf_analysis_R; \
		Rscript workflow/scripts/aggregate_across_K_perturb-seq.R \
		--figdir {params.figdir} \
		--outdir {params.analysisdir} \
		--datadir {params.datadir} \
		--sampleName {wildcards.sample} \
		--K.list {params.klist_comma} \
		--K.table {params.K_spectra_threshold_table} ' " ## how to create this automatically?


rule findK_plot_perturb_seq:
	input:
		toplot = os.path.join(config["analysisDir"], "{folder}/{sample}/acrossK/aggregated.outputs.findK.perturb-seq.RData")
	output:
		percent_batch_topics_plot = os.path.join(config["figDir"], "{folder}/{sample}/acrossK/percent.batch.topics.pdf")

	params:
		time = "3:00:00",
		mem_gb = "64",
		figdir = os.path.join(config["figDir"], "{folder}/{sample}/acrossK/"),
		analysisdir = os.path.join(config["analysisDir"], "{folder}/{sample}/acrossK/"),
		GO_threshold = 0.1,
		partition = "owners,normal"
Lines 1070-1079:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/cNMF_findK_plots_perturb-seq.R \
	--figdir {params.figdir} \
	--outdir {params.analysisdir} \
	--sampleName {wildcards.sample} \
	--p.adj.threshold 0.1 \
	--aggregated.data {input.toplot} \
	' "
Lines 150-159:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/seurat_to_h5ad.R \
	--sampleName {params.sampleName} \
	--inputSeuratObject {input.seurat_object} \
	--output_h5ad {output.h5ad_mtx} \
	--output_gene_name_txt {output.gene_name_txt} \
	--minUMIsPerCell {params.min_UMIs_per_cell} \
	--minUniqueGenesPerCell {params.min_unique_genes_per_cell} ' "
Lines 283-295:
shell:
	" bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.outdir}/{wildcards.sample}; \
	python workflow/scripts/cNMF/cnmf.py prepare \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	-c {input.h5ad_mtx} \
	-k {wildcards.k} \
	--n-iter {params.run_per_worker} \
	--total-workers {params.run_per_worker} \
	--seed {params.seed} \
	--numgenes {wildcards.num_genes} ' "
Lines 323-328:
shell:
	" bash -c 'source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.todir}/{wildcards.sample}; \
	cp -r {params.fromdir}/{wildcards.sample}/cnmf_tmp {params.todir}/{wildcards.sample}/; \
	cp {params.fromdir}/{wildcards.sample}/{wildcards.sample}.overdispersed_genes.txt {params.todir}/{wildcards.sample}/ ' "
Lines 422-434:
shell:
	" bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.outdir}/{wildcards.sample}; \
	python workflow/scripts/cNMF/cnmf.py prepare \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	-c {input.h5ad_mtx} \
	-k {wildcards.k} \
	--n-iter {params.run_per_worker} \
	--total-workers {params.total_workers} \
	--seed {params.seed} \
	--genes-file {input.genes} ' "
Lines 461-466:
shell:
	" bash -c 'source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.todir}/{wildcards.sample}; \
	cp -r {params.fromdir}/{wildcards.sample}/cnmf_tmp {params.todir}/{wildcards.sample}/; \
	cp {params.fromdir}/{wildcards.sample}/{wildcards.sample}.overdispersed_genes.txt {params.todir}/{wildcards.sample}/ ' "
Lines 530-535:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/cNMF/cnmf.py factorize \
	--output-dir {params.outdir} \
	--name {wildcards.sample} ' "
Lines 554-563:
run:
	cmd = "mkdir -p " + os.path.join(params.inputdir, wildcards.sample, "cnmf_tmp")
	shell(cmd)
	for worker in range(params.num_workers):
		for run in range(params.run_per_worker):
			from_file = os.path.join(params.inputdir, "worker" + str(worker), wildcards.sample, "cnmf_tmp/" + wildcards.sample + ".spectra.k_" + wildcards.k + ".iter_" + str(run) + ".df.npz")
			index_here = worker * params.run_per_worker + run  ### give the runs a new index
			to_file = os.path.join(params.outdir, wildcards.sample, "cnmf_tmp/" + wildcards.sample + ".spectra.k_" + wildcards.k + ".iter_" + str(index_here) + ".df.npz")
			cmd = "cp " + from_file + " " + to_file
			shell(cmd)
Lines 588-600:
shell:
	" bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.outdir}/K{wildcards.k}/; \
	python workflow/scripts/cNMF/cnmf.py prepare \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	-c {input.h5ad_mtx} \
	-k {wildcards.k} \
	--n-iter {params.num_runs} \
	--total-workers 1 \
	--seed {params.seed} \
	--numgenes {wildcards.num_genes} ' "
Lines 627-639:
shell:
	" bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.outdir}/K{wildcards.k}/; \
	python workflow/scripts/cNMF/cnmf.py prepare \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	-c {input.h5ad_mtx} \
	-k {wildcards.k} \
	--n-iter {params.num_runs} \
	--total-workers 1 \
	--seed {params.seed} \
	--genes-file {input.genes} ' "
Lines 663-668:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/cNMF/cnmf.py combine \
	--output-dir {params.outdir} \
	--name {wildcards.sample} ' "
Lines 685-689:
run:
	cmd = "mkdir -p " + os.path.join(params.outdir, wildcards.folder + "_acrossK", wildcards.sample, "cnmf_tmp")
	shell(cmd)
	cmd = "cp " + input.merged_result + " " + output.merged_copied_result
	shell(cmd)
Lines 717-729:
shell:
	" bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.outdir}; \
	python workflow/scripts/cNMF/cnmf.py prepare \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	-c {input.h5ad_mtx} \
	-k {params.klist} \
	--n-iter {params.num_runs} \
	--total-workers 1 \
	--seed {params.seed} \
	--numgenes {wildcards.num_genes} ' "
Lines 758-770:
shell:
	" bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	mkdir -p {params.outdir}; \
	python workflow/scripts/cNMF/cnmf.py prepare \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	-c {input.h5ad_mtx} \
	-k {params.klist} \
	--n-iter {params.num_runs} \
	--total-workers 1 \
	--seed {params.seed} \
	--genes-file {input.genes} ' "
Lines 815-819:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/cNMF/cnmf.modified.py k_selection_plot --output-dir {params.outdir} --name {wildcards.sample}; \
	cp {output.plot} {output.plot_new_location} ' "
Lines 892-900:
run:
	threshold_here = wildcards.threshold.replace("_",".")
	shell("bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/cNMF/cnmf.modified.py consensus \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	--components {wildcards.k} \
	--local-density-threshold {threshold_here} ' ") # --show-clustering 
Lines 924-934:
run:
	threshold_here = wildcards.threshold.replace("_",".")
	shell("bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/cNMF/cnmf.modified.py consensus \
	--output-dir {params.outdir} \
	--name {wildcards.sample} \
	--components {wildcards.k} \
	--local-density-threshold {threshold_here} \
	--show-clustering; \
	cp {params.outdir}/{wildcards.sample}/{wildcards.sample}.clustering.k_{wildcards.k}.dt_{wildcards.threshold}.png {params.figdir}/{wildcards.folder}/{wildcards.sample}/K{wildcards.k}/ ' ")
Lines 974-1000:
	shell:
		"bash -c ' source $HOME/.bashrc; \
		conda activate cnmf_analysis_R; \
		Rscript workflow/scripts/calcUMAP.only.R \
			--outdir {params.outdir}/ \
			--inputSeuratObject {input.input_seurat_object} \
			--sampleName {wildcards.sample} \
			--maxMt 50 \
			--maxCount 25000 \
			--minUniqueGenes 0 \
			--UMAP.resolution 0.6' " ## can make this a wildcard

rule plot_UMAP:
	input:
		seurat_object_withUMAP = os.path.join(config["analysisDir"], "data/{sample}.withUMAP_SeuratObject.RDS"),
		cNMF_Results = os.path.join(config["analysisDir"], "{folder}/{sample}/K{k}/threshold_{threshold}/cNMF_results.k_{k}.dt_{threshold}.RData")
	output:
		factor_expression_UMAP = os.path.join(config["figDir"], "{folder}/{sample}/K{k}/{sample}_K{k}_dt_{threshold}_Factor.Expression.UMAP.pdf")
	params:
		time = "6:00:00",
		mem_gb = "200",
		datadir = config["dataDir"],
		outdir = os.path.join(config["analysisDir"], "{folder}_acrossK/{sample}"),
		figdir = os.path.join(config["figDir"], "{folder}"), 
		analysisdir = os.path.join(config["analysisDir"], "{folder}"), # K{k}/threshold_{threshold}
		threshold = get_cNMF_filter_threshold_double,
		partition = "owners,normal"
Lines 1001-1011:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/cNMF_UMAP_plot.R \
	--sampleName {wildcards.sample} \
	--inputSeuratObject {input.seurat_object_withUMAP} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--recompute F ' "
Lines 1039-1048:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/variance_explained_v2.py \
	--path_to_topics {params.path_to_topics} \
	--topic_sampleName {wildcards.sample} \
	--X_normalized {params.X_normalized_path} \
	--outdir {params.outdir} \
	--k {wildcards.k} \
	--density_threshold {params.threshold} ' "
Lines 1085-1101:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	conda info --env; \
	Rscript workflow/scripts/cNMF_analysis.R \
	--topic.model.result.dir {params.outdir}/ \
	--sampleName {wildcards.sample} \
	--barcode.names {params.barcode_names} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir}/ \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--recompute F \
	--organism {params.organism} \
	--motif.enhancer.background /oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv \
	--motif.promoter.background /oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv \
	' "
Lines 1121-1130:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/cNMF_analysis_topic_plot.R \
	--sampleName {wildcards.sample} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir}/ \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--recompute F ' "
Lines 1148-1158:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/batch.topic.correlation.R \
	--figdir {params.figdir} \
	--outdir {params.analysisdir} \
	--sampleName {wildcards.sample} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--barcode.names {input.barcode_names} \
	--recompute F ' "
Lines 1183-1200:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/motif_enrichment.R \
	--sampleName {wildcards.sample} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--recompute F \
	--ep.type {wildcards.ep_type} \
	--organism {params.organism} \
	--motif.match.thr.str {wildcards.motif_match_thr} \
	--motif.enhancer.background {input.fimo_formatted} \
	--motif.promoter.background {input.fimo_formatted} '" ## to do
	# --motif.enhancer.background {input.enhancer_fimo_formatted} \
	# --motif.promoter.background {input.promoter_fimo_formatted} \
	# ' "
Lines 1245-1256:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/cNMF_analysis_motif.enrichment_plot.R \
	--sampleName {wildcards.sample} \
	--ep.type {wildcards.ep_type} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--motif.match.thr.str {wildcards.motif_match_thr} \
	--recompute F ' "
Lines 1302-1315:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/cNMF_analysis_gsea_clusterProfiler.R \
	--topic.model.result.dir {params.outdir} \
	--sampleName {wildcards.sample} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--ranking.type {wildcards.ranking_type} \
	--GSEA.type {wildcards.GSEA_type} \
	--organism {params.organism} \
	--recompute F ' "
Lines 1330-1340:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/plot_gsea_clusterProfiler.R \
	--sampleName {wildcards.sample} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--ranking.type {wildcards.ranking_type} \
	--GSEA.type {wildcards.GSEA_type} '"
Lines 1396-1404:
shell:
	"bash -c ' source $HOME/.bashrc; \
		conda activate cnmf_analysis_R; \
		Rscript workflow/scripts/create_program_summary_table.R \
		--sampleName {wildcards.sample} \
		--outdir {params.analysisdir} \
		--K.val {wildcards.k} \
		--density.thr {params.threshold} \
		--perturbSeq {params.perturbseq} '"
Lines 1425-1434:
shell:
	"bash -c ' source $HOME/.bashrc; \
		conda activate cnmf_analysis_R; \
		Rscript workflow/scripts/create_comprehensive_program_summary_table.R \
		--sampleName {wildcards.sample} \
		--outdir {params.analysisdir} \
		--scratch.outdir {params.scratch_outdir} \
		--K.val {wildcards.k} \
		--density.thr {params.threshold} \
		--perturbSeq {params.perturbseq} '"
Lines 1463-1488:
	shell:
		"bash -c ' source $HOME/.bashrc; \
		conda activate cnmf_analysis_R; \
		Rscript workflow/scripts/aggregate_across_K.R \
		--figdir {params.figdir} \
		--outdir {params.analysisdir} \
		--sampleName {wildcards.sample} \
		--K.list {params.klist_comma} \
		--K.table {params.K_spectra_threshold_table} ' " ## how to create this automatically?


rule findK_plot:
	input:
		toplot = os.path.join(config["analysisDir"], "{folder}/{sample}/acrossK/aggregated.outputs.findK.RData")
	output:
		GSEA_plots = os.path.join(config["figDir"], "{folder}/{sample}/acrossK/All_GSEA.pdf"),
		TFMotifEnrichment_plots = os.path.join(config["figDir"], "{folder}/{sample}/acrossK/All_TFMotifEnrichment.pdf"),
		topic_clustering_plot = os.path.join(config["figDir"], "{folder}/{sample}/acrossK/cluster.topic.zscore.by.Pearson.corr.pdf"),
		variance_explained_plot = os.path.join(config["figDir"], "{folder}/{sample}/acrossK/variance.explained.by.model.pdf")
	params:
		time = "3:00:00",
		mem_gb = "64",
		figdir = os.path.join(config["figDir"], "{folder}/{sample}/acrossK/"),
		analysisdir = os.path.join(config["analysisDir"], "{folder}/{sample}/acrossK/"),
		GO_threshold = 0.1,
		partition = "owners,normal"
Lines 1489-1498:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/cNMF_findK_plots.R \
	--figdir {params.figdir} \
	--outdir {params.analysisdir} \
	--sampleName {wildcards.sample} \
	--p.adj.threshold 0.1 \
	--aggregated.data {input.toplot} \
	' "
Lines 1520-1530:
	shell:
		"bash -c ' source $HOME/.bashrc; \
		conda activate cnmf_env; \
		mkdir -p {output.munged_features} {params.raw_features_dir}; \
		cp -r {params.external_features}/* {params.raw_features_dir}/; \
		cd {params.munged_features_dir}; \
		python {params.pipelineDir}/workflow/scripts/pops/munge_feature_directory.py \
       		--gene_annot_path {params.gene_annot_path} \
       		--feature_dir {params.raw_features_dir} \
       		--save_prefix {wildcards.magma_prefix} \
       		--nan_policy zero  ' "
Lines 1549-1561:
	shell:
		"bash -c ' source $HOME/.bashrc; \
		conda activate cnmf_env; \
		rm -r {params.raw_features_dir}; \
		mkdir -p {output.munged_features} {params.raw_features_dir}; \
 		cp -r {params.external_features}/* {params.raw_features_dir}/; \
		cp {input.cNMF_ENSG_topic_zscore_scaled} {params.raw_features_dir}/; \
		cd {params.munged_features_dir}; \
		python {params.pipelineDir}/workflow/scripts/pops/munge_feature_directory.py \
       		--gene_annot_path {params.gene_annot_path} \
       		--feature_dir {params.raw_features_dir} \
       		--save_prefix {wildcards.magma_prefix}_cNMF{wildcards.k} \
       		--nan_policy zero  ' "
Lines 1608-1618:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python {params.pipelineDir}/workflow/scripts/pops/pops.py \
		--gene_annot_path {params.gene_annot_path} \
		--feature_mat_prefix {params.munged_features_dir}/{wildcards.magma_prefix}_cNMF{wildcards.k} \
		--num_feature_chunks {params.num_munged_feature_chunks} \
		--control_features_path {params.PoPS_control_features} \
		--magma_prefix {params.magma_dir}/{wildcards.magma_prefix} \
		--out_prefix {params.outdir}/{wildcards.magma_prefix}_cNMF{wildcards.k} \
		--verbose ' "
Lines 1639-1649:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python {params.pipelineDir}/workflow/scripts/pops/pops.py \
		--gene_annot_path {params.gene_annot_path} \
		--feature_mat_prefix {params.munged_features_dir}/{wildcards.magma_prefix} \
		--num_feature_chunks {params.num_munged_feature_chunks} \
		--control_features_path {params.PoPS_control_features} \
		--magma_prefix {params.magma_dir}/{wildcards.magma_prefix} \
		--out_prefix {params.outdir}/{wildcards.magma_prefix} \
		--verbose ' "
Lines 1670-1676:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	mkdir -p {params.outdir}; \
	Rscript workflow/scripts/PoPS_aggregate_features.R \
	--feature.dir {params.raw_features_dir} \
	--output {params.outdir}/ ' "
Lines 1690-1698:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	mkdir -p {params.outdir}; \
	Rscript workflow/scripts/PoPS_aggregate_features_with_cNMF.R \
	--feature.RDS {input.all_features_RDS} \
	--cNMF.features {input.cNMF_ENSG_topic_zscore_scaled} \
	--output {params.outdir}/ \
	--prefix {wildcards.magma_prefix}_cNMF{wildcards.k} ' "
Lines 1724-1740:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/PoPS.data.processing.R \
	--output {params.outdir}/ \
	--scratch.output {params.scratch_outdir}/ \
	--prefix {wildcards.magma_prefix}_cNMF{wildcards.k} \
	--external.features.metadata {params.external_features_metadata} \
	--coefs_with_cNMF {input.coefs_with_cNMF} \
	--preds_with_cNMF {input.preds_with_cNMF} \
	--marginals_with_cNMF {input.marginals_with_cNMF} \
	--coefs_without_cNMF {input.coefs_without_cNMF} \
	--preds_without_cNMF {input.preds_without_cNMF} \
	--marginals_without_cNMF {input.marginals_with_cNMF} \
	--cNMF.features {input.cNMF_ENSG_topic_zscore_scaled} \
	--all.features {input.all_features_with_cNMF_RDS} \
	--recompute F ' "
Lines 1771-1792:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/PoPS.plots.R \
	--sampleName {wildcards.sample} \
	--output {params.outdir}/ \
	--figure {params.figdir}/ \
	--scratch.output {params.scratch_outdir}/ \
	--prefix {wildcards.magma_prefix}_cNMF{wildcards.k} \
	--k.val {wildcards.k} \
	--coefs_with_cNMF {input.coefs_with_cNMF} \
	--preds_with_cNMF {input.preds_with_cNMF} \
	--marginals_with_cNMF {input.marginals_with_cNMF} \
	--coefs_without_cNMF {input.coefs_without_cNMF} \
	--preds_without_cNMF {input.preds_without_cNMF} \
	--marginals_without_cNMF {input.marginals_with_cNMF} \
	--cNMF.features {input.cNMF_ENSG_topic_zscore_scaled} \
	--all.features {input.all_features_with_cNMF_RDS} \
	--external.features.metadata {params.external_features_metadata} \
	--combined.preds {input.combined_preds} \
	--coefs.defining.top.topic.RDS {input.coefs_defining_top_topic_RDS} \
	--preds.importance.score.key.columns {input.preds_importance_score_key_columns} ' "
Lines 1812-1822:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/output_IGVF_format.R \
		--sampleName {wildcards.sample} \
		--outdir {params.analysisdir}/ \
		--K.val {wildcards.k} \
		--density.thr {params.threshold} \
		--level {params.level} \
		--cell.type {params.cell_type} \
	' "
Lines 1835-1845:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/create_cellxgene_h5ad_IGVF_format.py \
		--path_to_topics {params.cNMF_outdir} \
		--topic_sampleName {wildcards.sample} \
		--outdir {params.analysisdir} \
		--k {wildcards.k} \
		--density_threshold {params.threshold} \
		--barcode_dir {params.barcode_dir} \
	' "
Lines 21-27:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_env; \
	python workflow/scripts/filter_to_h5ad.py \
	--inputPath {input.raw_h5ad_mtx} \
	--output_h5ad {output.h5ad_mtx} \
	--output_gene_name_txt {output.gene_name_txt} ' "
Lines 12-15:
shell:
	"bash -c ' source ~/.bashrc; \
	conda activate cnmf_env; \
	bash workflow/scripts/fimo_motif_match.sh {input.coord} {input.fasta} {output.fasta} ' "
Lines 27-35:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/program_prioritization/create_input_table.R \
	--sampleName {wildcards.sample} \
	--outdir {params.analysisdir} \
	--K.val {wildcards.k} \
	--density.thr {params.threshold} \
	--perturbSeq {params.perturbseq}' " # popsdir =  "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/"
Lines 62-78:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/program_prioritization/compute_enrichment.R \
	--input.GWAS.table {input.input_GWAS_table} \
	--coding.variant.df {input.coding_variant_table} \
	--sampleName {wildcards.sample} \
	--outdirsample {params.outdirsample}/ \
	--celltype {params.celltype} \
	--figdir {params.figdir}/ \
	--outdir {params.analysisdir} \
	--K.val {wildcards.k} \
	--trait.name {wildcards.GWAS_trait} \
	--density.thr {params.threshold} \
	--cNMF.table {input.input_table_for_compute_enrichment} \
	--regulator.analysis.type {wildcards.regulator_analysis_type} \
	--perturbSeq {params.perturbseq}' "
Lines 12-18:
shell:
	"bash -c ' source $HOME/.bashrc; \
	conda activate cnmf_analysis_R; \
	Rscript workflow/scripts/seurat_to_h5ad.R \
	--inputSeuratObject {input.seurat_object} \
	--output_h5ad {output.h5ad_mtx} \
	--output_gene_name_txt {output.gene_name_txt} ' "
Lines 7-272:
packages <- c("optparse","dplyr", "data.table", "reshape2", "conflicted","ggplot2",
              "tidyr", "textshape","readxl") # , "IsoplotR"
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")

## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/", help="Figure directory"), # "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes"
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/", help="Output directory"), # "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_2min"
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"), # "/oak/stanford/groups/engreitz/Users/kangh/process_sequencing_data/210912_FT010_fresh_Telo_sortedEC/multiome_FT010_fresh_2min/outs/filtered_feature_bc_matrix"
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  # make_option("--sep", type="logical", default=F, help="Whether to separate replicates or samples"),
  make_option("--K.list", type="character", default="14,15,60", help="K values available for analysis"),
  # make_option("--K.val", type="numeric", default=14, help="K value to analyze"),
  make_option("--K.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210625_snakemake_output/analysis/2kG.library/K.spectra.threshold.table.txt", help="table for defining spectra threshold"), # opt$K.table <-"/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/all_genes/2kG.library/K.spectra.threshold.table.txt"
  make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
  make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
  # make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
  # make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
  # make_option("--raw.mtx.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.filtered.txt", help="input matrix to cNMF pipeline"),
  # make_option("--subsample.type", type="character", default="", help="Type of cells to keep. Currently only support ctrl"),

  ## fisher motif enrichment
  ## make_option("--outputTable", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputTableBinary", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.binary.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputEnrichment", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/outputs/no_IL1B/topic.top.100.zscore.gene.motif.fisher.enrichment.k_14.df_0_2.txt", help="Output directory"),
  # make_option("--motif.promoter.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv", help="All promoter's motif matches"),
  # make_option("--motif.enhancer.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv", help="All enhancer's motif matches specific to {no,plus}_IL1B"),
  make_option("--adj.p.value.thr", type="numeric", default=0.05, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute")

)
opt <- parse_args(OptionParser(option_list=option.list))


## ## 2n1.99x
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/figures/all_genes"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/"
## opt$sampleName <- "Perturb_2kG_dup4"

## scratch sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/"
## ## opt$datadir <- "/oak/stanford/groups/engreitz/Users/kangh/process_sequencing_data/210912_FT010_fresh_Telo_sortedEC/multiome_FT010_fresh_2min/outs/filtered_feature_bc_matrix"
## opt$sampleName <- "FT010_fresh_3min"

mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold"))


## ## all genes (for interactive sessions)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K.spectra.threshold.table.txt"

## control only perturb-seq (for sdev)


## ## for testing findK_plots for control only cells
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/"
## opt$sampleName <- "2kG.library.ctrl.only"



SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
# STATIC.SAMPLE=c("Telo_no_IL1B_T200_1", "Telo_no_IL1B_T200_2", "Telo_plus_IL1B_T200_1", "Telo_plus_IL1B_T200_2", "no_IL1B", "plus_IL1B",  "pooled")
DATADIR=opt$datadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
OUTDIR.ACROSS.K=paste0(OUTDIR,"/",SAMPLE,"/acrossK/")
## OUTDIR.ACROSS.K=paste0(OUTDIR,"/",SAMPLE,"/acrossK/threshold_", DENSITY.THRESHOLD, "/")
# SEP=opt$sep
K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
## k <- opt$K.val
FIGDIR=opt$figdir



## adjusted p-value threshold
p.value.thr <- opt$adj.p.value.thr

## ## directories for factor motif enrichment
## FILENAME=opt$filename



## # create dir if not already
## if(SEP) check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE, ".sep/"), paste0(FIGDIR,SAMPLE, ".sep/K",k,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE, paste0(OUTDIR,SAMPLE, ".sep/"),FGSEADIR, FGSEAFIG, OUTDIR.ACROSS.K) else 
##   check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE,"/"), paste0(FIGDIR,SAMPLE,"/K",k,"/"), paste0(OUTDIR,SAMPLE,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE, FGSEADIR, FGSEAFIG, OUTDIR.ACROSS.K)
check.dir <- c(OUTDIR, FIGDIR, OUTDIR.ACROSS.K)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x) }))


## palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)
## selected.gene <- c("EDN1", "NOS3", "TP53", "GOSR2", "CDKN1A")
## # ABC genes
## gene.set <- c("INPP5B", "SF3A3", "SERPINH1", "NR2C1", "FGD6", "VEZT", "SMAD3", "AAGAB", "GOSR2", "ATP5G1", "ANGPTL4", "SRBD1", "PRKCE", "DAGLB") # ABC_0.015_CAD_pp.1_genes
## # cell cycle genes
## gene.list.three.groups <- read.delim(paste0(DATADIR,"/ptbd.genes_three.groups.txt"), header=T, stringsAsFactors=F)
## enhancer.set <- gene.list.three.groups$Gene[grep("E_at_", gene.list.three.groups$Gene)]
## CAD.focus.gene.set <- gene.list.three.groups %>% subset(Group=="CAD_focus") %>% pull(Gene) %>% append(enhancer.set)
## EC.pos.ctrl.gene.set <- gene.list.three.groups %>% subset(Group=="EC_pos._ctrls") %>% pull(Gene)

## cell.count.thr <- opt$cell.count.thr # greater than this number, filter to keep the guides with greater than this number of cells
## guide.count.thr <- opt$guide.count.thr # greater than this number, filter to keep the perturbations with greater than this number of guides

## guide.design = read.delim(file=paste0(opt$datadir, "/200607_ECPerturbSeqMiniPool.design.txt"), header=T, stringsAsFactors = F)


## ## add GO pathway log2FC
## GO <- read.delim(file=paste0("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/GO.Pathway.table.brief.txt"), header=T, check.names=FALSE)
## GO.list <- read.delim(file="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/GO.Pathway.list.brief.txt", header=T, check.names=F)
## colnames(GO)[1] <- "Gene"
## colnames(GO.list)[1] <- "Gene"
## ## load all sample, K, topic's top 100 genes (by TopFeatures() KL-score measure)
## ## allGeneKtopic100 <- read.delim(paste0(TMDIR, "no.plus.pooled.top100.topicStats.txt"), header=T)
## # load non-expressed control gene list
## non.expressed.genes <- read.delim(file="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/non.expressed.ctrl.genes.txt", header=F, stringsAsFactors=F) %>% unlist %>% as.character() %>% sort()

## # perturbation type list
## gene.set.type.df <- data.frame(Gene=guide.design %>% pull(guideSet) %>% unique(),
##                                type=rep("other", guide.design %>% pull(guideSet) %>% unique() %>% length())) 
## gene.set.type.df$Gene <- gene.set.type.df$Gene %>% as.character()
## gene.set.type.df$type <- gene.set.type.df$type %>% as.character()
## gene.set.type.df$type[which(gene.set.type.df$Gene %in% non.expressed.genes)] <- "non-expressed"
## gene.set.type.df$type[which(gene.set.type.df$Gene %in% CAD.focus.gene.set)] <- "CAD focus"
## gene.set.type.df$type[grepl("^safe|^negative", gene.set.type.df$Gene)] <- "negative-control"
## gene.set.type.df$Gene[which(gene.set.type.df$Gene == "negative_control")] <- "negative-control"
## gene.set.type.df$Gene[which(gene.set.type.df$Gene == "safe_targeting")] <- "safe-targeting"
## # gene.set.type.df$type[which(gene.set.type.df$Gene %in% gene.set)] <- "ABC"

## # convert enhancer SNP rs number to enhancer target gene name
## enh.snp.to.gene <- read.delim(paste0(DATADIR, "/enhancer.SNP.to.gene.name.txt"), header=T, stringsAsFactors = F) %>% mutate(Enhancer_name=gsub("_","-", Enhancer_name))

## # gene corresponding pathway
## gene.def.pathways <- read_excel(paste0(DATADIR,"topic.gene.definition.pathways.xlsx"), sheet="Gene_Pathway")

# K spectra threshold table
if(file.exists(opt$K.table)) {
    K.spectra.threshold <- read.table(file=paste0(opt$K.table), header=T, stringsAsFactors=F)
} else {
    K.spectra.threshold <- data.frame(K = K.list, density.threshold=rep(0.2, length(K.list)))  ## assume 0.2 is the best threshold for filtering out outlier topics
}

## load cNMF pipeline aggregated results (?)

## initialize storage variables
# promoter.fisher.df.list <- enhancer.fisher.df.list <- fgsea.results <- all.test.df.list <- all.fdr.df.list <- count.by.GWAS.list <- count.by.GWAS.withTopic.list <- theta.zscore.list <- theta.raw.list <- all.enhancer.fisher.df.list <- all.promoter.fisher.df.list <- all.enhancer.fisher.df.10en6.list <- promoter.wide.10en6.list <- promoter.wide.binary.10en6.list <- all.promoter.fisher.df.10en6.list <- all.promoter.ttest.df.list <- all.promoter.ttest.df.10en6.list <- all.enhancer.ttest.df.list <- all.enhancer.ttest.df.10en6.list <- vector("list", nrow(K.spectra.threshold))
batch.percent.df.list <- all.test.df.list <- MAST.df.list <- vector("list", nrow(K.spectra.threshold))
## loop over all values of K and aggregate results
for (n in 1:nrow(K.spectra.threshold)) {
    k <- K.spectra.threshold[n,"K"]
    DENSITY.THRESHOLD <- K.spectra.threshold[n,"density.threshold"] %>% gsub("\\.","_",.)

    ## subscript for files
    SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
    SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

    ## directories
    # FIGDIRSAMPLE=ifelse(SEP, paste0(FIGDIR,SAMPLE,".sep/K",k,"/"), paste0(FIGDIR, SAMPLE, "/K",k,"/"))
    # FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
    OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k, "/threshold_", DENSITY.THRESHOLD)
    # FGSEADIR=paste0(OUTDIRSAMPLE,"/fgsea/")
    # FGSEAFIG=paste0(FIGDIRSAMPLE,"/fgsea/")

    ## batch topic correlation file
    batch.correlation.file.name <- paste0(OUTDIRSAMPLE, "/batch.correlation.RDS")
    if(file.exists(batch.correlation.file.name)) {
        print(paste0("loading batch correlation file from: ", batch.correlation.file.name))
        load(batch.correlation.file.name)
        batch.percent.df.list[[n]] <- batch.percent.df
    } else {
        print(paste0("file ", batch.correlation.file.name, " not found"))
    }

    # ## load motif enrichment results
    # file.name <- paste0(OUTDIRSAMPLE,"/cNMFAnalysis.factorMotifEnrichment.",SUBSCRIPT.SHORT,".RData")
    # print(file.name)
    # if(file.exists((file.name))) { 
    #     load(file.name)
    #     print(paste0("loading ", file.name))
    # }
    # motif.enrichment.variables <- c("all.enhancer.fisher.df", "all.promoter.fisher.df", 
    #                                 "promoter.wide", "enhancer.wide", "promoter.wide.binary", "enhancer.wide.binary",
    #                                 "enhancer.wide.10en6", "enhancer.wide.binary.10en6", "all.enhancer.fisher.df.10en6",
    #                                 "promoter.wide.10en6", "promoter.wide.binary.10en6", "all.promoter.fisher.df.10en6",
    #                                 "all.promoter.ttest.df", "all.promoter.ttest.df.10en6", "all.enhancer.ttest.df", "all.enhancer.ttest.df.10en6")
    # motif.enrichment.variables.missing <- (!(motif.enrichment.variables %in% ls())) %>% as.numeric %>% sum 
    # if ( motif.enrichment.variables.missing > 0 ) {
    #     warning(paste0(motif.enrichment.variables[!(motif.enrichment.variables %in% ls())], " not available"))
    # } else {
    #     promoter.fisher.df.list[[n]] <- all.promoter.fisher.df %>% mutate(K = k)
    #     enhancer.fisher.df.list[[n]] <- all.enhancer.fisher.df %>% mutate(K = k)
    #     all.promoter.ttest.df.list[[n]] <- all.promoter.ttest.df %>% mutate(K = k)
    #     all.promoter.ttest.df.10en6.list[[n]] <- all.promoter.ttest.df.10en6 %>% mutate(K = k)
    #     all.enhancer.ttest.df.list[[n]] <- all.enhancer.ttest.df %>% mutate(K = k)
    #     all.enhancer.ttest.df.10en6.list[[n]] <- all.enhancer.ttest.df.10en6 %>% mutate(K = k)
    #     all.promoter.fisher.df.list[[n]] <- all.promoter.fisher.df
    #     all.enhancer.fisher.df.list[[n]] <- all.enhancer.fisher.df
    # }



    ## all Wilcoxon statistical tests
    file.name <- paste0(OUTDIRSAMPLE, "/all.test.", SUBSCRIPT, ".txt")
    print(file.name)
    if(file.exists(file.name)) {
        print(paste0("loading ", file.name))
        all.test.df.list[[n]] <- read.table(file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    } else {
        print("all.test file not found")
    }

    ## MAST statistical test
    file.name <- paste0(OUTDIRSAMPLE, "/", SAMPLE, "_MAST_DEtopics.txt")
    print(file.name)
    if(file.exists(file.name)) {
        print(paste0("loading ", file.name))
        MAST.df.list[[n]] <- read.table(file.name, header=T, stringsAsFactors=F, fill=T, check.names=F) %>% mutate(K = k)
    } else {
        print("MAST result file not found")
    }

    # file.name <- paste0(OUTDIRSAMPLE, "/all.expressed.genes.pval.fdr.", SUBSCRIPT, ".txt")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print(paste0("loading ", file.name))
    #     all.fdr.df.list[[n]] <- read.table(file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    # } else {
    #     print("file not found")
    # }

    # # load count.by.GWAS 
    # file.name <- paste0(OUTDIRSAMPLE,"/count.by.GWAS.classes_p.adj.",p.value.thr %>% as.character,"_",SUBSCRIPT,".txt")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print(paste0("loading ", file.name))
    #     count.by.GWAS.list[[n]] <- read.delim(file=file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    # } else {
    #     print("file not found")
    # }

    # # load count.by.GWAS.with.topic
    # file.name <- paste0(OUTDIRSAMPLE,"/count.by.GWAS.classes.withTopic_p.adj.",p.value.thr %>% as.character,"_",SUBSCRIPT,".txt")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print(paste0("loading ", file.name))
    #     count.by.GWAS.withTopic.list[[n]] <- read.delim(file=file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    # } else {
    #     print("file not found")
    # }


}

batch.percent.df <- do.call(rbind, batch.percent.df.list)
MAST.df <- do.call(rbind, MAST.df.list)
all.test.df <- do.call(rbind, all.test.df.list)

file.name <- paste0(OUTDIR.ACROSS.K, "/aggregated.outputs.findK.perturb-seq.RData")
save(batch.percent.df, all.test.df, MAST.df, file = file.name)
Lines 7-336:
packages <- c("optparse","dplyr", "data.table", "reshape2", "conflicted","ggplot2",
              "tidyr", "textshape","readxl") # , "IsoplotR"
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")

## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/", help="Figure directory"), # "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes"
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/", help="Output directory"), # "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_2min"
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"), # "/oak/stanford/groups/engreitz/Users/kangh/process_sequencing_data/210912_FT010_fresh_Telo_sortedEC/multiome_FT010_fresh_2min/outs/filtered_feature_bc_matrix"
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  # make_option("--sep", type="logical", default=F, help="Whether to separate replicates or samples"),
  make_option("--K.list", type="character", default="14,30,60", help="K values available for analysis"),
  # make_option("--K.val", type="numeric", default=14, help="K value to analyze"),
  make_option("--K.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K.spectra.threshold.table.txt", help="table for defining spectra threshold"), # opt$K.table <-"/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/all_genes/2kG.library/K.spectra.threshold.table.txt"
  # make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
  # make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
  # make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
  # make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
  # make_option("--raw.mtx.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.filtered.txt", help="input matrix to cNMF pipeline"),
  # make_option("--subsample.type", type="character", default="", help="Type of cells to keep. Currently only support ctrl"),

  ## fisher motif enrichment
  ## make_option("--outputTable", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputTableBinary", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.binary.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputEnrichment", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/outputs/no_IL1B/topic.top.100.zscore.gene.motif.fisher.enrichment.k_14.df_0_2.txt", help="Output directory"),
  # make_option("--motif.promoter.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv", help="All promoter's motif matches"),
  # make_option("--motif.enhancer.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv", help="All enhancer's motif matches specific to {no,plus}_IL1B"),
  make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute")

)
opt <- parse_args(OptionParser(option_list=option.list))

## sdev for 2n1.99x singlets
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/figures/all_genes"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/"
## ## opt$datadir <- "/oak/stanford/groups/engreitz/Users/kangh/process_sequencing_data/210912_FT010_fresh_Telo_sortedEC/multiome_FT010_fresh_2min/outs/filtered_feature_bc_matrix"
## opt$sampleName <- "Perturb_2kG_dup4"

mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold"))


## ## all genes (for interactive sessions)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K.spectra.threshold.table.txt"

## ## overdispersed Genes
## opt$figdir <- "//oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/figures/top2000VariableGenes/"
## opt$outdir <- "//oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/analysis/top2000VariableGenes/"
## opt$K.table <- ""
## opt$K.list <- "3,5,7,12,14,19,21,23,25,27,29,31,35,40,45,50,60,70,80,90,100,120"
## opt$sampleName <- "2kG.library_overdispersedGenes"

## ## sdev K562 gwps 2k most dispersed genes cNMF
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes"
## opt$sampleName <- "WeissmanK562gwps"
## opt$K.list <- "3,5,10,15,20,25,30,35,40,45,50,55,60,70,80,90,100,110,120"
## opt$K.table <- "/to/use/for/specifying/spectra/cut/off/threshold//default/0.2"

SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
# STATIC.SAMPLE=c("Telo_no_IL1B_T200_1", "Telo_no_IL1B_T200_2", "Telo_plus_IL1B_T200_1", "Telo_plus_IL1B_T200_2", "no_IL1B", "plus_IL1B",  "pooled")
DATADIR=opt$datadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
OUTDIR.ACROSS.K=paste0(OUTDIR,"/",SAMPLE,"/acrossK/")
K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
FIGDIR=opt$figdir



## adjusted p-value threshold
p.value.thr <- opt$adj.p.value.thr



## # create dir if not already
check.dir <- c(OUTDIR, FIGDIR, OUTDIR.ACROSS.K)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x) }))


# K spectra threshold table
if(file.exists(opt$K.table)) {
    K.spectra.threshold <- read.table(file=paste0(opt$K.table), header=T, stringsAsFactors=F)
} else {
    K.spectra.threshold <- data.frame(K = K.list, density.threshold=rep(0.2, length(K.list)))  ## assume 0.2 is the best threshold for filtering out outlier topics
}
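## Illustrative format of the K spectra threshold table (values are placeholders, not
## pipeline outputs): one row per K, with the density.threshold used to filter outlier spectra.
##   K   density.threshold
##   14  0.2
##   25  0.2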



## initialize storage variables
GSEA.types <- c("GOEnrichment", "PosGenesGOEnrichment", "ByWeightGSEA", "GSEA")
for (j in 1:length(GSEA.types)) {
    GSEA.type <- GSEA.types[j]
    to.eval <- paste0("clusterProfiler.", GSEA.type, ".list <- vector(\"list\",nrow(K.spectra.threshold))")
    eval(parse(text = to.eval))
}
all.MAST.df.list <- all.test.df.list <- all.fdr.df.list <- count.by.GWAS.list <- count.by.GWAS.withTopic.list <- theta.zscore.list <- theta.raw.list <-  all.promoter.ttest.df.list <- all.enhancer.ttest.df.list <- varianceExplainedByModel.list <- varianceExplainedPerProgram.list <- vector("list", nrow(K.spectra.threshold))
## loop over all values of K and aggregate results
for (n in 1:nrow(K.spectra.threshold)) {
    k <- K.spectra.threshold[n,"K"]
    DENSITY.THRESHOLD <- K.spectra.threshold[n,"density.threshold"] %>% gsub("\\.","_",.)

    ## subscript for files
    SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
    # SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

    ## directories
    # FIGDIRSAMPLE=ifelse(SEP, paste0(FIGDIR,SAMPLE,".sep/K",k,"/"), paste0(FIGDIR, SAMPLE, "/K",k,"/"))
    # FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
    OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k, "/threshold_", DENSITY.THRESHOLD)

    ## if (SEP) {
    ##     guideCounts <- loadGuides(n, sep=T) %>% mutate(Gene=Gene.marked)
    ##     tmp.labels <- guideCounts$Gene %>% unique() %>% strsplit("-") %>% sapply("[[",2) %>% unique()
    ##     tmp.labels <- tmp.labels[!(tmp.labels %in% c("control","targeting"))]
    ##     rep1.label <- paste0("-",tmp.labels[1])
    ##     rep2.label <- paste0("-",tmp.labels[2])
    ## } else guideCounts <- loadGuides(n) %>% mutate(Gene=Gene.marked)

    ## cNMF direct output file (GEP)
    cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
    ## cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.k_", k, ".dt_", density.threshold, ".RData")
    print(cNMF.result.file)
    if(file.exists(cNMF.result.file)) {
        print("loading cNMF result file")
        load(cNMF.result.file)
    } else {
        print(paste0("file ", cNMF.result.file, " not found"))
    }

    # ## cNMF analysis results file
    # file.name <- ifelse(SEP,
    #                     paste0(OUTDIRSAMPLE,"/cNMFAnalysis.",SUBSCRIPT,".sep.RData"),
    #                     paste0(OUTDIRSAMPLE,"/cNMFAnalysis.",SUBSCRIPT,".RData"))
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print("loading the file")
    #     load(file.name)
    # } else {
    #     print("file not found")
    # }

    ## theta.zscore
    ## theta.zscore.list[[n]] <- theta.zscore %>% as.data.frame %>% `colnames<-`(paste0("factor_", colnames(.))) %>% mutate(Gene=rownames(.)) %>% melt(value.name="weight", id.vars="Gene", variable.name="Factor") %>% mutate(K=k)
    theta.zscore.list[[n]] <- theta.zscore %>% `colnames<-`(paste0("K",k,"_factor_", colnames(.)))
    ## theta.raw
    ## theta.raw.list[[n]] <- theta.raw %>% as.data.frame %>% `colnames<-`(paste0("factor_", colnames(.))) %>% mutate(Gene=rownames(.)) %>% melt(value.name="weight", id.vars="Gene", variable.name="Factor") %>% mutate(K=k)
    theta.raw.list[[n]] <- theta.raw %>% `colnames<-`(paste0("K",k,"_factor_", colnames(.)))
    ## ## theta.KL
    ## ## load KL score
    ## file.name <- paste0(OUTDIRSAMPLE, "/topic.KL.score_", SUBSCRIPT.SHORT, ".txt") %>% gsub("_k_", "_K", .)
    ## if(file.exists(file.name)) {
    ##     print(paste0("Loading ", file.name))
    ##     theta.KL.list[[n]] <- read.table(file.name, header=T, stringsAsFactors=F)
    ## } else {
    ##     print(paste0(file.name, " does not exist."))
    ## }

    # ## motif enrichment file (old.211025)
    # file.name <- paste0(OUTDIRSAMPLE,"/cNMFAnalysis.factorMotifEnrichment.",SUBSCRIPT.SHORT, ".RData")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print("loading the file")
    #     load(file.name)
    # } else {
    #     print("file not found")
    # }
    # if ("all.promoter.fisher.df" %in% ls()) {
    #     promoter.fisher.df.list[[n]] <- all.promoter.fisher.df %>% mutate(K = k)
    #     enhancer.fisher.df.list[[n]] <- all.enhancer.fisher.df %>% mutate(K = k)
    #     rm(list=c("all.promoter.fisher.df", "all.enhancer.fisher.df"))
    # } else {
    #     print("missing all.promoter.fisher.df and/or all.enhancer.fisher.df")
    # }

    ## load motif enrichment results
    for(ep.type in c("promoter", "enhancer")){
        num.top.genes <- 300
        file.name <- paste0(OUTDIRSAMPLE, "/", ep.type, ".topic.top.", num.top.genes, ".zscore.gene_motif.count.ttest.enrichment_motif.thr.pval1e-6_", SUBSCRIPT.SHORT,".txt")
        if(file.exists(file.name)) {
            eval(parse(text = paste0("all.", ep.type, ".ttest.df.list[[n]] <- read.delim(file.name, stringsAsFactors=F) %>% mutate(K = k)"))) ## store in all.{promoter, enhancer}.ttest.df.list
        } else {
            message(paste0(file.name, " does not exist"))
        }
    }


    # file.name <- paste0(OUTDIRSAMPLE,"/cNMFAnalysis.factorMotifEnrichment.",SUBSCRIPT.SHORT,".RData")
    # print(file.name)
    # if(file.exists((file.name))) { 
    #     load(file.name)
    #     print(paste0("loading ", file.name))
    # }
    # motif.enrichment.variables <- c("all.enhancer.fisher.df", "all.promoter.fisher.df", 
    #                                 "promoter.wide", "enhancer.wide", "promoter.wide.binary", "enhancer.wide.binary",
    #                                 "enhancer.wide.10en6", "enhancer.wide.binary.10en6", "all.enhancer.fisher.df.10en6",
    #                                 "promoter.wide.10en6", "promoter.wide.binary.10en6", "all.promoter.fisher.df.10en6",
    #                                 "all.promoter.ttest.df", "all.promoter.ttest.df.10en6", "all.enhancer.ttest.df", "all.enhancer.ttest.df.10en6")
    # motif.enrichment.variables.missing <- (!(motif.enrichment.variables %in% ls())) %>% as.numeric %>% sum 
    # if ( motif.enrichment.variables.missing > 0 ) {
    #     warning(paste0(motif.enrichment.variables[!(motif.enrichment.variables %in% ls())], " not available"))
    # } else {
    #     promoter.fisher.df.list[[n]] <- all.promoter.fisher.df %>% mutate(K = k)
    #     enhancer.fisher.df.list[[n]] <- all.enhancer.fisher.df %>% mutate(K = k)
    #     all.promoter.ttest.df.list[[n]] <- all.promoter.ttest.df %>% mutate(K = k)
    #     all.promoter.ttest.df.10en6.list[[n]] <- all.promoter.ttest.df.10en6 %>% mutate(K = k)
    #     all.enhancer.ttest.df.list[[n]] <- all.enhancer.ttest.df %>% mutate(K = k)
    #     all.enhancer.ttest.df.10en6.list[[n]] <- all.enhancer.ttest.df.10en6 %>% mutate(K = k)
    #     all.promoter.fisher.df.list[[n]] <- all.promoter.fisher.df
    #     all.enhancer.fisher.df.list[[n]] <- all.enhancer.fisher.df
    # }


    ## GSEA results
    ranking.types <- c("zscore", "raw", "median_spectra", "median_spectra_zscore")
    for (j in 1:length(GSEA.types)) {
        GSEA.type <- GSEA.types[j]
        to.eval <- paste0("clusterProfiler.", GSEA.type, ".list.here <- vector(\"list\",length(ranking.types))")
        eval(parse(text = to.eval))
        for (i in 1:length(ranking.types)) {
            ranking.type <- ranking.types[i]
            file.name <- paste0(OUTDIRSAMPLE,"/clusterProfiler_GeneRankingType",ranking.type,"_EnrichmentType", GSEA.type, ".txt")
            if(file.exists(file.name)) {
                message("Loading ", file.name)
                to.eval <- paste0("clusterProfiler.", GSEA.type, ".list.here[[i]] <- read.delim(file.name, header=T, stringsAsFactors = F) %>% mutate(type = ranking.type, K = k)")
                eval(parse(text = to.eval))
            } else {
                warning(paste0(file.name, " file does not exist"))
            }
        }
        to.eval <- paste0("clusterProfiler.", GSEA.type, ".list[[n]] <- do.call(rbind, clusterProfiler.", GSEA.type, ".list.here)")
        eval(parse(text = to.eval))
    }


    ## variance explained by the model
    file.name <- paste0(OUTDIRSAMPLE, "/summary.varianceExplained.df.txt")
    if(file.exists(file.name)) {
        varianceExplainedByModel.list[[n]] <- read.delim(file.name, stringsAsFactors=F) %>% mutate(K = k)
    } else {
        message(paste0(file.name, " does not exist"))
    }

    ## variance explained per program
    file.name <- paste0(OUTDIRSAMPLE, "/metrics.varianceExplained.df.txt")
    if(file.exists(file.name)) {
        varianceExplainedPerProgram.list[[n]] <- read.delim(file.name, stringsAsFactors=F)
    } else {
        message(paste0(file.name, " does not exist"))
    }        

    # ## all statistical tests
    # file.name <- paste0(OUTDIRSAMPLE, "/all.test.", SUBSCRIPT, ".txt")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print(paste0("loading ", file.name))
    #     all.test.df.list[[n]] <- read.table(file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    # } else {
    #     print("all.test file not found")
    # }

    # file.name <- paste0(OUTDIRSAMPLE, "/all.expressed.genes.pval.fdr.", SUBSCRIPT, ".txt")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print(paste0("loading ", file.name))
    #     all.fdr.df.list[[n]] <- read.table(file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    # } else {
    #     print("file not found")
    # }

    # # load count.by.GWAS 
    # file.name <- paste0(OUTDIRSAMPLE,"/count.by.GWAS.classes_p.adj.",p.value.thr %>% as.character,"_",SUBSCRIPT,".txt")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print(paste0("loading ", file.name))
    #     count.by.GWAS.list[[n]] <- read.delim(file=file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    # } else {
    #     print("file not found")
    # }

    # # load count.by.GWAS.with.topic
    # file.name <- paste0(OUTDIRSAMPLE,"/count.by.GWAS.classes.withTopic_p.adj.",p.value.thr %>% as.character,"_",SUBSCRIPT,".txt")
    # print(file.name)
    # if(file.exists(file.name)) {
    #     print(paste0("loading ", file.name))
    #     count.by.GWAS.withTopic.list[[n]] <- read.delim(file=file.name, header=T, stringsAsFactors=F) %>% mutate(K = k)
    # } else {
    #     print("file not found")
    # }


}

# promoter.fisher.df <- do.call(rbind, promoter.fisher.df.list)
# enhancer.fisher.df <- do.call(rbind, enhancer.fisher.df.list)
GSEA.types <- c("GOEnrichment", "ByWeightGSEA", "GSEA") ## use external input
clusterProfiler.GO.list.here <- clusterProfiler.GSEA.list.here <- clusterProfiler.enricher.GSEA.list.here <- vector("list",length(GSEA.types))
for (j in 1:length(GSEA.types)) {
    GSEA.type <- GSEA.types[j]
    to.eval <- paste0("clusterProfiler.", GSEA.type, ".df <- do.call(rbind, clusterProfiler.", GSEA.type, ".list)")
    eval(parse(text = to.eval))
}
## clusterProfiler.GO.df <- do.call(rbind, clusterProfiler.GO.list)
## clusterProfiler.GSEA.df <- do.call(rbind, clusterProfiler.GSEA.list)
## clusterProfiler.enricher.GSEA.df <- do.call(rbind, clusterProfiler.enricher.GSEA.list)
# all.test.df <- do.call(rbind, all.test.df.list)
# all.fdr.df <- do.call(rbind, all.fdr.df.list)
# count.by.GWAS <- do.call(rbind, count.by.GWAS.list)
# count.by.GWAS.withTopic <- do.call(rbind, count.by.GWAS.withTopic.list)
theta.zscore.df <- do.call(cbind, theta.zscore.list)
theta.raw.df <- do.call(cbind, theta.raw.list)
# theta.KL.df <- do.call(rbind, theta.KL.list)
all.promoter.ttest.df <- do.call(rbind, all.promoter.ttest.df.list)
# all.promoter.ttest.df.10en6 <- do.call(rbind, all.promoter.ttest.df.10en6.list)
all.enhancer.ttest.df <- do.call(rbind, all.enhancer.ttest.df.list)
# all.enhancer.ttest.df.10en6 <- do.call(rbind, all.enhancer.ttest.df.10en6.list)
varianceExplainedByModel.df <- do.call(rbind, varianceExplainedByModel.list)
varianceExplainedPerProgram.df <- do.call(rbind, varianceExplainedPerProgram.list)

file.name <- paste0(OUTDIR.ACROSS.K, "/aggregated.outputs.findK.RData")
save(clusterProfiler.GOEnrichment.df, clusterProfiler.ByWeightGSEA.df, clusterProfiler.GSEA.df, theta.zscore.df, theta.raw.df, all.promoter.ttest.df, all.enhancer.ttest.df, varianceExplainedByModel.df, varianceExplainedPerProgram.df,
     file=file.name)
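
## Minimal sketch (not part of the pipeline itself): downstream plotting scripts can
## reload the aggregated across-K objects saved above, e.g.
## load(paste0(OUTDIR.ACROSS.K, "/aggregated.outputs.findK.RData"))
## head(varianceExplainedByModel.df)  ## variance explained by each model, one set of rows per K
## head(all.promoter.ttest.df)        ## promoter motif enrichment t-test results across all K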
library(conflicted)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")


packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr", "grid", "gtable", "gridExtra","ggrepel","ramify",
              "ggpubr","gridExtra",
              "org.Hs.eg.db","limma","fgsea", "conflicted",
              "cluster","textshape","readxl", 
              "ggdist", "gghalves", "Seurat", "writexl", "purrr") #              "GGally","RNOmni","usedist","GSEA","clusterProfiler","IsoplotR","wesanderson",
xfun::pkg_attach(packages)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")



## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/", help="Output directory"),
  # make_option("--olddatadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"),
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/", help="Input 10x data directory"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
  make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
  make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
  make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
  make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
  make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),
  make_option("--barcode.names", type="character", default="", help="metadata CBC and sample information data table"),


  #summary plot parameters
  make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
  make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute")

)
opt <- parse_args(OptionParser(option_list=option.list))

## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.val <- 60

# ## control genes directories (for sdev)
# opt$sampleName <- "2kG.library.ctrl.only"
# opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/figures/all_genes/"
# opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/all_genes_acrossK/2kG.library.ctrl.only/"
# opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/"
# opt$K.val <- 60

## ## overdispersed gene directories (for sdev)
## opt$sampleName <- "2kG.library_overdispersedGenes"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/figures/top2000VariableGenes"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/analysis/top2000VariableGenes"
## opt$K.val <- 120

## ## K562 gwps sdev
## opt$sampleName <- "WeissmanK562gwps"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/"
## opt$K.val <- 90
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/data/K562_gwps_raw_singlecell_01_metadata.txt"

## ## ENCODE Mouse Heart data
## opt$sampleName <- "mouse_ENCODE_heart"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/figures/top2000VariableGenes"
## opt$K.val <- 55
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/collab_data/IGVF/mouse_ENCODE_heart/auxiliary_data/snrna/heart_Parse_10x_integrated_metadata.csv"

## ## teloHAEC no_IL1B 200 gene library
## opt$sampleName <- "no_IL1B"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/analysis/all_genes/"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/figures/all_genes/"
## opt$K.val <- 20
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/no_IL1B.barcodes.txt"


SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
## DATADIR=opt$olddatadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIR=opt$figdir
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")


## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr

# create dir if not already
check.dir <- c(OUTDIR, FIGDIR, OUTDIRSAMPLE, FIGDIRSAMPLE)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))


## graphing constants and helpers
palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)


##################################################
## load data
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
if(file.exists(cNMF.result.file)) {
    message(paste0("loading cNMF result file: \n", cNMF.result.file))
    load(cNMF.result.file)
} else {
	print(paste0(cNMF.result.file, " does not exist"))
}


## annotate omega to get Gene, Guide, Sample, and CBC
if(grepl("2kG.library", SAMPLE)) {
    ann.omega <- cbind(omega, barcode.names)  ## %>%
} else {
    if(grepl("[.]csv", opt$barcode.names)) barcode.names <- read.delim(opt$barcode.names, stringsAsFactors=F, sep=",") else barcode.names <- read.delim(opt$barcode.names, stringsAsFactors=F)
    ann.omega <- merge(omega, barcode.names %>% select(CBC, sample), by.x=0, by.y="CBC", all.x=T)
}


## Batch Effect QC: correlate batch binary labels with topic expression
## Batch binary label
if(grepl("2kG.library", SAMPLE)) {
    ann.omega.batch <- ann.omega %>% mutate(sample.short = gsub("scRNAseq_2kG_", "", sample) %>% gsub("_.*$","", .))
    ann.omega.batch.binary <- ann.omega.batch %>% mutate(tmp.value = 1) %>% spread(key="sample.short", fill=0, value="tmp.value")
    ann.omega.batch.binary.mtx <- ann.omega.batch.binary %>% select(-long.CBC,-Gene.full.name,-Guide,-CBC,-sample,-Gene) %>% as.matrix()
    m <- cor(ann.omega.batch.binary.mtx, method="pearson") %>% as.matrix()
    batch.correlation.mtx <- m[1:k,(k+1):(dim(m)[2])]
} 
ann.omega.sample.batch.binary <- ann.omega %>% mutate(tmp.value = 1) %>% spread(key="sample", fill=0, value="tmp.value")
ann.omega.sample.batch.binary.mtx <- ann.omega.sample.batch.binary %>% select(-Row.names) %>% as.matrix()
m <- cor(ann.omega.sample.batch.binary.mtx, method="pearson") %>% as.matrix()
sample.batch.correlation.mtx <- m[1:k, (k+1):(dim(m)[2])]


## calculate percent of topics with correlation past a threshold (0.1, 0.2, 0.4, 0.6)
correlation.threshold.list <- c(0.1, 0.2, 0.4, 0.6)
batch.passed.threshold.df <- do.call(rbind, lapply(correlation.threshold.list, function(threshold) {
    df <- sample.batch.correlation.mtx %>% apply(1, function(x) (x > threshold) %>% as.numeric %>% sum) %>% as.data.frame %>% `colnames<-`("num.batch.correlated") %>% mutate(batch.thr = threshold) %>% mutate(ProgramID = rownames(.), K = k)
}))
batch.percent.df <- batch.passed.threshold.df %>% group_by(batch.thr) %>% summarize(percent.correlated = ((num.batch.correlated > 0) %>% as.numeric %>% sum) / k) %>% mutate(K = k)

## max batch correlation per topic
max.batch.correlation.df <- sample.batch.correlation.mtx %>%
    apply(1, function(x) {
        out <- max(abs(x))
    }) %>%
    as.data.frame %>%
    `colnames<-`("maxPearsonCorrelation") %>%
    mutate(ProgramID = row.names(.)) %>%
    as.data.frame


## store batch and sample correlation matrix
if(grepl("2kG.library", SAMPLE)) write.table(batch.correlation.mtx, file=paste0(OUTDIRSAMPLE, "/batch.correction.mtx.txt"), sep="\t", quote=F)
write.table(sample.batch.correlation.mtx, file=paste0(OUTDIRSAMPLE, "/sample.batch.correction.mtx.txt"), sep="\t", quote=F)
write.table(batch.passed.threshold.df, file=paste0(OUTDIRSAMPLE, "/batch.passed.thr.df.txt"), sep="\t", quote=F, row.names=F)
write.table(batch.percent.df, file=paste0(OUTDIRSAMPLE, "/batch.percent.df.txt"), sep="\t", quote=F, row.names=F)
write.table(max.batch.correlation.df, file=paste0(OUTDIRSAMPLE, "/max.batch.correlation.df.txt"), sep="\t", quote=F, row.names=F)
if(grepl("2kG.library", SAMPLE)) {
    save(batch.correlation.mtx, sample.batch.correlation.mtx, batch.passed.threshold.df, batch.percent.df, max.batch.correlation.df,
     file=paste0(OUTDIRSAMPLE, "/batch.correlation.RDS"))
} else {
    save(sample.batch.correlation.mtx, batch.passed.threshold.df, batch.percent.df, max.batch.correlation.df,
         file=paste0(OUTDIRSAMPLE, "/batch.correlation.RDS"))
}


## Batch correlation heatmap
plotHeatmap <- function(mtx, title){
    heatmap.2(
        mtx, 
        Rowv=T, 
        Colv=T,
        trace='none',
        key=T,
        col=palette,
        labCol=colnames(mtx),
        ## margins=c(15,5), 
        cex.main=0.1, 
        cexCol=1/(nrow(mtx)^(1/7)), cexRow=1/(ncol(mtx)^(1/7)),
        main=title
    )
}


pdf(paste0(FIGDIRTOP, "batch.correlation.heatmap.pdf"),width=0.15*ncol(sample.batch.correlation.mtx)+5, height=0.1*nrow(sample.batch.correlation.mtx)+5)
if(grepl("2kG.library", SAMPLE)) plotHeatmap(batch.correlation.mtx, title=paste0(SAMPLE, ", K=", k, ", topic batch correlation"))
plotHeatmap(sample.batch.correlation.mtx, title=paste0(SAMPLE, ", K=", k, ", topic sample correlation"))
dev.off()

## automate batch topic selection
Pearson.correlation.threshold <- 0.1 ## is this a good threshold for all values of K? ## check CDF of average correlation?
batch.topic <- apply(sample.batch.correlation.mtx, 1, 
                    function(x) sum(as.numeric(abs(x) > Pearson.correlation.threshold))) %>% 
                keep(function(x) x > 0) %>% 
                names
write.table(batch.topic, file=paste0(OUTDIRSAMPLE, "batch.topics.txt"), quote=F, sep="\t", row.names=F, col.names=F)
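
## Hypothetical follow-up (an assumption, not something this script performs): the program
## names written to batch.topics.txt could be read back and excluded before interpretation, e.g.
## batch.topic <- readLines(paste0(OUTDIRSAMPLE, "batch.topics.txt"))
## omega.filtered <- omega[, !(colnames(omega) %in% batch.topic)] ## drop batch-correlated programs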
suppressPackageStartupMessages(library(optparse))


option.list <- list(
    make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2105_FT005_Analysis/outputs/", help="Output directory"),
    make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2105_FT005_Analysis/outputs/", help="Output directory"),
    make_option("--sampleName",type="character",default="FT005_gex_new_pipeline", help="Sample name"),
    make_option("--project",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2105_FT005_Analysis/",help="Project Directory"),
    make_option("--inputSeuratObject", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2105_FT005_Analysis/outputs/FT005_gex/withUMAP.SeuratObject.RDS", help="Path to the Seurat Object"),
    make_option("--compareSeuratObject", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2105_FT005_Analysis/outputs/FT005_gex_new_pipeline/withUMAP.SeruatObject.RDS", help="Path to the Seurat Object"),
    make_option("--maxMt", type="numeric", default=50, help="filter out cells with percent mitochondrial gene higher than this threhsold"),
    make_option("--maxCount", type="numeric", default=25000, help="filter out cells with UMI count more than this threshold"),
    make_option("--minUniqueGenes", type="numeric", default=0, help="filter out cells with unique gene detected less than this threshold"),
    make_option("--UMAP.resolution", type="numeric", default=0.06, help="UMAP resolution. The default is 0.06")
)
opt <- parse_args(OptionParser(option_list=option.list))

library(SeuratObject)
library(Seurat)
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(data.table))
suppressPackageStartupMessages(library(tidyr))
suppressPackageStartupMessages(library(readxl))
suppressPackageStartupMessages(library(ggrepel))

## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

mytheme <- theme_classic() + theme(axis.text = element_text(size = 13), axis.title = element_text(size = 15), plot.title = element_text(hjust = 0.5)) 


#######################################################################
## Constants
PROJECT=opt$project
OUTDIR=opt$outdir
FIGDIR=opt$figdir
SAMPLE=opt$sampleName
# OUTDIRSAMPLE=paste0(OUTDIR,"/",SAMPLE,"/")
FIGDIRSAMPLE=paste0(FIGDIR,"/",SAMPLE,"/")
palette = colorRampPalette(c("#38b4f7", "white", "red"))(n=100)
# create dir if not already
# check.dir <- c(OUTDIR, OUTDIRSAMPLE, FIGDIR, FIGDIRSAMPLE)
check.dir <- c(OUTDIR, FIGDIR, FIGDIRSAMPLE)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x) }))



## function for calculating UMAP
calcUMAP <- function(s) {
    s <- SCTransform(s)                                      ## normalize with SCTransform
    s <- RunPCA(s, verbose = FALSE)                          ## PCA on the normalized assay
    s <- FindNeighbors(s, dims = 1:10)                       ## nearest-neighbor graph on the first 10 PCs
    s <- FindClusters(s, resolution = opt$UMAP.resolution)   ## graph-based clustering
    s <- RunUMAP(s, dims = 1:10)                             ## UMAP embedding on the same PCs
    return(s)                                                ## return the annotated Seurat object
}



## load data
s <- readRDS(opt$inputSeuratObject)
s <- calcUMAP(s)

## Choose filters, use cells with Good_singlet, and filter MT/RP
s[["percent.mt"]] <- PercentageFeatureSet(s, pattern = "^MT-")
s[["percent.ribo"]] <- PercentageFeatureSet(s, pattern = "^RPS|^RPL")


## plot QC
## s.meta <- SeuratObject::FetchData(s, colnames(s[[]]))  ## This weird Seurat syntax gets the list of all metadata vars and then fetches a matrix of the data
# plotSingleCellStats(s.meta, mtMax=NULL, nCountMax=NULL, paste0(FIGDIRSAMPLE,"/QC.single.cell.stats.UMAPres.", opt$UMAP.resolution, ".pdf"))

# saveRDS(s, paste0(OUTDIRSAMPLE,"/", SAMPLE, ".withUMAP.", opt$UMAP.resolution, ".SeuratObject.RDS")) 
# saveRDS(s, paste0(OUTDIRSAMPLE,"/", SAMPLE, ".withUMAP_SeuratObject.RDS")) 
saveRDS(s, paste0(OUTDIR,"/", SAMPLE, ".withUMAP_SeuratObject.RDS")) 
packages <- c("optparse", "data.table", "reshape2", "fgsea", "conflicted", "readxl", "writexl", "org.Hs.eg.db", "tidyr", "dplyr", "clusterProfiler", "msigdbr")
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")


## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/", help="Output directory"),
  # make_option("--olddatadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"),
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/", help="Input 10x data directory"),
  make_option("--topic.model.result.dir", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210707_snakemake_maxParallel/all_genes_acrossK/2kG.library/", help="Topic model results directory"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  # make_option("--sep", type="logical", default=F, help="Whether to separate replicates or samples"),
  make_option("--K.list", type="character", default="2,3,4,5,6,7,8,9,10,11,12,13,14,15,17,19,21,23,25", help="K values available for analysis"),
  make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
  make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
  make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
  make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
  make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
  make_option("--raw.mtx.dir",type="character",default="stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.filtered.txt", help="input matrix to cNMF pipeline"),
  make_option("--raw.mtx.RDS.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.RDS", help="input matrix to cNMF pipeline"), # the first lane: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.expandedMultiTargetGuide.RDS"
  make_option("--subsample.type", type="character", default="", help="Type of cells to keep. Currently only support ctrl"),
  # make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/barcodes.tsv", help="barcodes.tsv for all cells"),
  make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),

  ## fisher motif enrichment
  ## make_option("--outputTable", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputTableBinary", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.binary.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputEnrichment", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.fisher.enrichment.k_14.df_0_2.txt", help="Output directory"),
  make_option("--motif.promoter.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv", help="All promoter's motif matches"),
  make_option("--motif.enhancer.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv", help="All enhancer's motif matches specific to {no,plus}_IL1B"),
  make_option("--enhancer.fimo.threshold", type="character", default="1.0E-4", help="Enhancer fimo motif match threshold"),

  #summary plot parameters
  make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
  make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute"),

  ## GSEA parameters
  make_option("--ranking.type", type="character", default="zscore", help="{zscore, raw} ranking for the top program genes"),
  make_option("--GSEA.type", type="character", default="GOEnrichment", help="{GOEnrichment, ByWeightGSEA, GSEA}"),
  ## make_option("--", type="", default= , help="")

  ## Organism flag
  make_option("--organism", type="character", default="human", help="Organism type, accept org.Hs.eg.db. Only support human and mouse.")

)
opt <- parse_args(OptionParser(option_list=option.list))

## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.val <- 60

## ## ## K562 gwps sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/"
## opt$K.val <- 35
## opt$sampleName <- "WeissmanK562gwps"
## opt$GSEA.type <- "ByWeightGSEA"
## opt$ranking.type <- "median_spectra_zscore"

## ## ENCODE mouse heart
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes"
## opt$K.val <- 10
## opt$sampleName <- "mouse_ENCODE_heart"
## opt$GSEA.type <- "ByWeightGSEA"
## opt$ranking.type <- "zscore"

## ## teloHAEC no_IL1B 200 gene library
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/analysis/top2000VariableGenes/"
## opt$K.val <- 20
## opt$sampleName <- "no_IL1B"
## opt$GSEA.type <- "ByWeightGSEA"
## opt$ranking.type <- "median_spectra"

## ## IGVF b01_LeftCortex
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes/"
## opt$K.val <- 20
## opt$sampleName <- "IGVF_b01_LeftCortex"
## opt$GSEA.type <- "GSEA"
## opt$ranking.type <- "median_spectra"
## opt$organism <- "mouse"

## ## RCA Pt4
## opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220124_snakemake_RCA/analysis/all_genes_acrossK/RCA"
## opt$sampleName <- "RCA"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220124_snakemake_RCA/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220124_snakemake_RCA/analysis/all_genes"
## opt$K.val <- 60
## opt$ranking.type <- "median_spectra_zscore"
## opt$GSEA.type <- "GOEnrichment"
## opt$organism <- "human"


SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
DATADIR=opt$olddatadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
TMDIR=opt$topic.model.result.dir
# SEP=opt$sep
# K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIR=opt$figdir
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
## FGSEADIR=paste0(OUTDIRSAMPLE,"/fgsea/")
## FGSEAFIG=paste0(FIGDIRSAMPLE,"/fgsea/")

## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
# SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr

db <- ifelse(grepl("mouse|org.Mm.eg.db", opt$organism), "org.Mm.eg.db", "org.Hs.eg.db")
library(db, character.only = TRUE) ## load the appropriate annotation database (org.Hs.eg.db or org.Mm.eg.db)

# create dir if not already
check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE,"/"), paste0(FIGDIR,SAMPLE,"/K",k,"/"), paste0(OUTDIR,SAMPLE,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)


## helper function to map between ENSGID and SYMBOL
map.ENSGID.SYMBOL <- function(df) {
    ## need column `Gene` to be present in df
    ## detect gene data type (e.g. ENSGID, Entrez Symbol)
    gene.type <- ifelse(nrow(df) == sum(as.numeric(grepl("^ENS", df$Gene))),
                        "ENSGID",
                        "Gene")
    if(gene.type == "ENSGID") {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "ENSEMBL", column = "SYMBOL")
        df <- df %>% mutate(ENSGID = Gene, Gene = mapped.genes)
    } else {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "SYMBOL", column = "ENSEMBL")
        df <- df %>% mutate(ENSGID = mapped.genes)
    }
    return(df)
}
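
## Illustrative usage of map.ENSGID.SYMBOL (the toy ENSEMBL IDs below are examples, not
## pipeline inputs); the helper detects the identifier type in the Gene column and adds
## the complementary ENSGID / SYMBOL mapping:
## toy.df <- data.frame(Gene = c("ENSG00000000003", "ENSG00000000005"))
## map.ENSGID.SYMBOL(toy.df)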



######################################################################
## Load topic model results    
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
}

## get list of topic defining genes
theta.rank.list <- vector("list", ncol(theta.zscore))## initialize storage list
for(i in 1:ncol(theta.zscore)) {
    topic <- paste0("topic_", colnames(theta.zscore)[i])
    theta.rank.list[[i]] <- theta.zscore %>%
        as.data.frame %>%
        select(all_of(i)) %>%
        `colnames<-`("topic.zscore") %>%
        mutate(Gene = rownames(.)) %>%
        arrange(desc(topic.zscore)) %>%
        mutate(zscore.specificity.rank = 1:n()) %>% ## add rank column
        mutate(Topic = topic) ## add topic column
}
theta.rank.df <- do.call(rbind, theta.rank.list) %>%  ## combine list to df
    `colnames<-`(c("topic.zscore", "Gene", "zscore.specificity.rank", "ProgramID")) %>%
    mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
    as.data.frame %>% map.ENSGID.SYMBOL

## get list of topic genes by raw weight
theta.raw.rank.list <- vector("list", ncol(theta.raw))## initialize storage list
for(i in 1:ncol(theta.raw)) {
    topic <- paste0("topic_", colnames(theta.raw)[i])
    theta.raw.rank.list[[i]] <- theta.raw %>%
        as.data.frame %>%
        select(all_of(i)) %>%
        `colnames<-`("topic.raw") %>%
        mutate(Gene = rownames(.)) %>%
        arrange(desc(topic.raw)) %>%
        mutate(raw.score.rank = 1:n()) %>% ## add rank column
        mutate(Topic = topic) ## add topic column
}
theta.raw.rank.df <- do.call(rbind, theta.raw.rank.list) %>%  ## combine list to df
    `colnames<-`(c("topic.raw", "Gene", "raw.score.rank", "ProgramID")) %>%
    mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
    as.data.frame %>% map.ENSGID.SYMBOL


## get list of topic genes by median spectra weight
median.spectra.rank.list <- vector("list", ncol(median.spectra))## initialize storage list
for(i in 1:ncol(median.spectra)) {
    topic <- paste0("topic_", colnames(median.spectra)[i])
    median.spectra.rank.list[[i]] <- median.spectra %>%
        as.data.frame %>%
        select(all_of(i)) %>%
        `colnames<-`("median.spectra") %>%
        mutate(Gene = rownames(.)) %>%
        arrange(desc(median.spectra)) %>%
        mutate(median.spectra.rank = 1:n()) %>% ## add rank column
        mutate(Topic = topic) ## add topic column
}
median.spectra.rank.df <- do.call(rbind, median.spectra.rank.list) %>%  ## combine list to df
    `colnames<-`(c("median.spectra", "Gene", "median.spectra.rank", "ProgramID")) %>%
    mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
    as.data.frame %>% map.ENSGID.SYMBOL
## median.spectra.zscore.df <- median.spectra.zscore.df %>% mutate(Gene = ENSGID) ## quick fix, need to add "Gene" column to this dataframe in analysis script

######################################################################
## run cluster profiler GSEA on top 300 genes

## ## map between EntrezID and Gene Symbol
## z <- org.Hs.egSYMBOL
## z_mapped_genes <- mappedkeys(z)
## entrez.to.symbol <- as.list(z[z_mapped_genes])

## ## entrez.to.symbol <- as.list(org.Hs.egSYMBOL)
## symbol.to.entrez <- as.list(org.Hs.egSYMBOL2EG)

## ## map to entrez id (function)
## symbolToEntrez <- function(df) df %>% mutate(EntrezID = symbol.to.entrez[.$gene %>% as.character] %>% sapply("[[",1) %>% as.character)

## subset to top 300 genes

ranking.type.ary <- c("zscore", "raw", "median_spectra_zscore", "median_spectra")
score.colname.ary <- c("zscore", "raw", "median.spectra.zscore", "median.spectra")
ranking.rank.colname.ary <- c("zscore.specificity.rank", "raw.score.rank", "median.spectra.zscore.rank", "median.spectra.rank")
ranking.type.varname.ary <- c("theta.rank.df", "theta.raw.rank.df", "median.spectra.zscore.df", "median.spectra.rank.df")

getData <- function(t) {
    i <- which(ranking.type.ary == opt$ranking.type)
    programID.here <- paste0("K", k, "_", t)
    ranking.type.varname.here <- ranking.type.varname.ary[i]
    if(grepl("median.spectra", ranking.type.varname.here)) {
        ranking.score.colname.here <- score.colname.ary[i] 
    } else {
        ranking.score.colname.here <- paste0("topic.", ranking.type.ary[i])
    }
    gene.df <- get(ranking.type.varname.here) %>%
        subset(ProgramID == programID.here)
    gene.type <- ifelse(nrow(gene.df) == sum(as.numeric(grepl("^ENS", gene.df$Gene))), "ENSGID", "Gene")
    mapped.genes <- mapIds(get(db),
                           keys=gene.df$Gene,
                           keytype = ifelse(gene.type == "Gene", "SYMBOL", "ENSEMBL"),
                           column = ifelse(gene.type == "Gene", "ENSEMBL", "SYMBOL"))
    mapped.entrez.genes <- mapIds(get(db),
                           keys=gene.df$Gene,
                           keytype = ifelse(gene.type == "Gene", "SYMBOL", "ENSEMBL"),
                           column = "ENTREZID")
    gene.df <- gene.df %>%
        mutate(!!gene.type := Gene,
               !!ifelse(gene.type=="ENSGID", "Gene", "ENSGID") := mapped.genes,
               EntrezID = mapped.entrez.genes) %>%
        as.data.frame


    gene.weights <- gene.df %>% pull(get(ranking.score.colname.here)) %>% `names<-`(gene.df$EntrezID)
    gene.weights[gene.weights < 0] <- 0    

    top.gene.df <- gene.df %>%
        subset(get(ranking.rank.colname.ary[i]) <= 300) %>%
        as.data.frame
    ## top.genes <- unlist(mget(top.gene.df$Gene, envir=org.Hs.egSYMBOL2EG, ifnotfound=NA)) ## old
    top.genes <- top.gene.df %>% pull(EntrezID) ## same as above

    pos.gene.df <- gene.df %>%
        subset(get(ranking.score.colname.here) > 0) %>%
        as.data.frame
    pos.genes <- pos.gene.df %>% pull(EntrezID)
    ## pos.genes <- unlist(mget(pos.gene.df$Gene, envir=org.Hs.egSYMBOL2EG, ifnotfound=NA))

    ## geneUniverse <- unlist(mget(get(ranking.type.varname.ary[i])$Gene %>% unique, envir=org.Hs.egSYMBOL2EG, ifnotfound=NA))
    geneUniverse <- gene.df$EntrezID

    return(list(top.genes = top.genes, pos.genes = pos.genes, geneUniverse = geneUniverse, gene.weights = gene.weights))
}
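
## Illustrative sketch of getData()'s return value for a single program (assumes the cNMF
## results above loaded successfully; element names are taken from the return() call):
## data.example <- getData(1)
## str(data.example$top.genes)     ## Entrez IDs of the top 300 program-defining genes
## str(data.example$gene.weights)  ## non-negative gene weights named by Entrez ID (used by ByWeightGSEA)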

m_df <- msigdbr(species = ifelse(grepl("mouse", opt$organism), "Mus musculus", "Homo sapiens"))

## save this as a txt file and read in ## for future if needed
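## Each entry below is an R expression string; the one named by opt$GSEA.type is selected
## and evaluated once per program in the loop further down.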
functionsToRun <- list(GOEnrichment = "out <- enrichGO(gene = top.genes, ont = 'ALL', OrgDb = db, universe = geneUniverse, readable=T, pvalueCutoff=1, pAdjustMethod = 'fdr') %>% as.data.frame %>% mutate(fdr.across.ont = p.adjust, ProgramID = paste0('K', k, '_', t))",
                       PosGenesGOEnrichment = "out <- enrichGO(gene = pos.genes, ont = 'ALL', OrgDb = db, universe = geneUniverse, readable=T, pvalueCutoff=1, pAdjustMethod='fdr') %>% as.data.frame %>% mutate(fdr.across.ont = p.adjust, ProgramID = paste0('K', k, '_', t))",
                       ByWeightGSEA = "out <- GSEA(gene.weights, TERM2GENE = m_df %>% select(gs_name, entrez_gene), pAdjustMethod = 'fdr', pvalueCutoff = 1) %>% as.data.frame %>% mutate(ProgramID = paste0('K', k, '_', t)) ",
                       GSEA = "out <- enricher(top.genes, TERM2GENE = m_df %>% select(gs_name, entrez_gene), universe = geneUniverse, pAdjustMethod = 'fdr', qvalueCutoff=1) %>% as.data.frame %>% mutate(ProgramID = paste0('K', k, '_', t))"
                       )

## for(i in 1:length(ranking.type.ary)) {
ranking.type.here <- opt$ranking.type
GSEA.type <- opt$GSEA.type
## ranking.type.here <- ranking.type.ary[i]
## GO enrichment analysis. Include all of:
## MF: Molecular Function
## CC: Cellular Component
## BP: Biological Process
## ans.go <- do.call(rbind, lapply(1:60, function(t) {

message("starting enrichment")

## out.list <- lapply(1:length(functionsToRun), function(j) {
out <- do.call(rbind, lapply(c(1:k) %>% rev, function(t) {
    data.here <- getData(t)
    top.genes <- data.here$top.genes
    if(sum(as.numeric(is.na(names(top.genes)))) > 0) top.genes <- top.genes[-which(is.na(names(top.genes)))] ## remove genes that don't have a matched Entrez ID
    ## print(head(top.genes))

    pos.genes <- data.here$pos.genes
    if(sum(as.numeric(is.na(names(pos.genes)))) > 0) pos.genes <- pos.genes[-which(is.na(names(pos.genes)))]
    ## print(head(pos.genes))

    geneUniverse <- data.here$geneUniverse
    ## print(head(geneUniverse))

    gene.weights <- data.here$gene.weights
    if (sum(as.numeric(is.na(names(gene.weights)))) > 0) gene.weights <- gene.weights[-which(is.na(names(gene.weights)))]
    if (which(gene.weights==0) %>% length > 0) gene.weights <- gene.weights[-which(gene.weights==0)] ## can't have zero weights?
    if (length(which(is.na(gene.weights))) > 0) gene.weights <- gene.weights[-which(is.na(gene.weights))] ## can't have NA
    ## print(head(gene.weights))

    message(paste0("Ranking type: ", ranking.type.here, ", Program ", t, ", out of ", k, ", function ", GSEA.type, ", top gene class: ", class(top.genes),
                   "\n geneUniverse class: ", class(geneUniverse), ", gene.weights class: ", class(gene.weights)))
    ## message(paste0("Function to run: \n", functionsToRun[[GSEA.type]]))
    eval(parse(text = functionsToRun[[GSEA.type]]))
    return(out)
}))
## if(j == 1) {
file.name <- paste0(OUTDIRSAMPLE, "/clusterProfiler_GeneRankingType", ranking.type.here, "_EnrichmentType", GSEA.type,".txt")
## } else if (j == 2) {
##     file.name <- paste0(OUTDIRSAMPLE, "/clusterProfiler_allGene", ranking.type.here, "_ByWeight_GSEA.txt")
## } else {
##     file.name <- paste0(OUTDIRSAMPLE, "/clusterProfiler_top300Genes", ranking.type.here, "_GSEA.txt")
## }
message(paste0("output table to ", file.name))
write.table(out, file.name, sep="\t", row.names=F, quote=F)
packages <- c("optparse","dplyr", "ggplot2", "reshape2", "ggrepel", "conflicted")
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")



option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/", help="Output directory"),
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/", help="Input 10x data directory"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
  make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),

  #summary plot parameters
  make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
  make_option("--ep.type", type="character", default="enhancer", help="motif enrichment for enhancer or promoter, specify 'enhancer' or 'promoter'"),
  make_option("--adj.p.value.thr", type="numeric", default=0.05, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute"),
  make_option("--motif.match.thr.str", type="character", default="pval1e-6", help="threshold for subsetting motif matches")

)
opt <- parse_args(OptionParser(option_list=option.list))

## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.val <- 60

mytheme <- theme_classic() + theme(axis.text = element_text(size = 7),
                                   axis.title = element_text(size = 8),
                                   plot.title = element_text(hjust = 0.5, face = "bold", size=8))

SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
# STATIC.SAMPLE=c("Telo_no_IL1B_T200_1", "Telo_no_IL1B_T200_2", "Telo_plus_IL1B_T200_1", "Telo_plus_IL1B_T200_2", "no_IL1B", "plus_IL1B",  "pooled")
# DATADIR=opt$olddatadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
## TMDIR=opt$topic.model.result.dir
## SEP=opt$sep
# K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
k <- opt$K.val
num.top.genes <- 300 ## number of top topic defining genes
ep.type <- opt$ep.type
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIR=opt$figdir
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
FGSEADIR=paste0(OUTDIRSAMPLE,"/fgsea/")
FGSEAFIG=paste0(FIGDIRSAMPLE,"/fgsea/")

## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
# SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr
motif.match.thr.str <- opt$motif.match.thr.str

## ## directories for factor motif enrichment
## FILENAME=opt$filename


## ## modify motif.enhancer.background input directory ##HERE: perhaps do a for loop for all the desired thresholds (use strsplit on enhancer.fimo.threshold)
## opt$motif.enhancer.background <- paste0(opt$motif.enhancer.background, opt$enhancer.fimo.threshold, "/fimo.formatted.tsv")


# create dir if not already
check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE,"/"), paste0(FIGDIR,SAMPLE,"/K",k,"/"), paste0(OUTDIR,SAMPLE,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE, FGSEADIR, FGSEAFIG)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)



######################################################################
## load data
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
} else {
    warning(paste0(cNMF.result.file, " does not exist"))
}

## load motif enrichment results
all.ttest.df.path <- paste0(OUTDIRSAMPLE,"/", ep.type, ".topic.top.", num.top.genes, ".zscore.gene_motif.count.ttest.enrichment_motif.thr.", motif.match.thr.str, "_", SUBSCRIPT.SHORT,".txt")
ttest.df <- read.delim(all.ttest.df.path, stringsAsFactors=F)



# file.name <- paste0(OUTDIRSAMPLE,"/cNMFAnalysis.factorMotifEnrichment.",SUBSCRIPT.SHORT,".RData")
# print(file.name)
# if(file.exists((file.name))) { 
#     load(file.name)
#     print(paste0("loading ", file.name))
# }
# motif.enrichment.variables <- c("all.enhancer.fisher.df", "all.promoter.fisher.df", 
#                                 "promoter.wide", "enhancer.wide", "promoter.wide.binary", "enhancer.wide.binary",
#                                 "enhancer.wide.10en6", "enhancer.wide.binary.10en6", "all.enhancer.fisher.df.10en6",
#                                 "promoter.wide.10en6", "promoter.wide.binary.10en6", "all.promoter.fisher.df.10en6",
#                                 "all.promoter.ttest.df", "all.promoter.ttest.df.10en6", "all.enhancer.ttest.df", "all.enhancer.ttest.df.10en6")
# motif.enrichment.variables.missing <- (!(motif.enrichment.variables %in% ls())) %>% as.numeric %>% sum 
# if ( motif.enrichment.variables.missing > 0 ) {
#     warning(paste0(motif.enrichment.variables[!(motif.enrichment.variables %in% ls())], " not available"))
# }


## End of data loading



##########################################################################
## Plots


## volcano plots
volcano.plot <- function(toplot, ep.type, ranking.type, label.type="") {
    if( label.type == "pos") {
        label <- toplot %>% subset(two.sided.p.adjust < fdr.thr & enrichment.log2fc > 0) %>% mutate(motif.toshow = gsub("HUMAN.H11MO.", "", motif))
    } else {
        label <- toplot %>% subset(two.sided.p.adjust < fdr.thr) %>% mutate(motif.toshow = gsub("HUMAN.H11MO.", "", motif))
    }
    t <- gsub("topic_", "", toplot$topic[1])
    p <- toplot %>% ggplot(aes(x=enrichment.log2fc, y=-log10(two.sided.p.adjust))) + geom_point(size=0.5) + mytheme +
        ggtitle(paste0(SAMPLE[1], " Topic ", t, " Top ", num.top.genes, " ", ranking.type,"\n", ifelse(ep.type=="promoter", "Promoter", "Enhancer"), " Motif Enrichment")) + xlab("Motif Enrichment (log2FC)") + ylab("-log10(adjusted p-value)") +
        xlim(0,max(toplot$enrichment.log2fc)) +
        geom_hline(yintercept=-log10(fdr.thr), linetype="dashed", color="gray") +
        geom_text_repel(data=label, box.padding = 0.25,
                        aes(label=motif.toshow), size=2.5,
                        max.overlaps = 15,
                        color="black")# + theme(text=element_text(size=16), axis.title=element_text(size=16), axis.text=element_text(size=16), plot.title=element_text(size=14))
    print(p)
    p <- toplot %>% ggplot(aes(x=enrichment.log2fc, y=-log10(two.sided.p.value))) + geom_point(size=0.25) + mytheme +
        ggtitle(paste0(SAMPLE[1], " Topic ", t, " Top ", num.top.genes," ", ranking.type,"\n", ifelse(ep.type=="promoter", "Promoter", "Enhancer"), " Motif Enrichment")) + xlab("Motif Enrichment (log2FC)") + ylab("-log10(p-value)") +
        xlim(0,max(toplot$enrichment.log2fc)) +
        geom_hline(yintercept=-log10(fdr.thr), linetype="dashed", color="gray") +
        geom_text_repel(data=label, box.padding = 0.25,
                        aes(label=motif.toshow), size=2.5,
                        max.overlaps = 15,
                        color="black") #+ theme(text=element_text(size=16), axis.title=element_text(size=16), axis.text=element_text(size=16), plot.title=element_text(size=14))
    return(p)
}

## function for all volcano plots
all.volcano.plots <- function(all.fisher.df, ep.type, ranking.type, label.type="") {
    for ( t in 1:k ){
        toplot <- all.fisher.df %>% subset(topic==paste0("topic_",t))
        volcano.plot(toplot, ep.type, ranking.type, label.type) %>% print()
    }
}


##########################################################################
## motif enrichment plot
pdf(file=paste0(FIGDIRTOP, "zscore.",ep.type,".motif.count.ttest.enrichment_motif.thr.", motif.match.thr.str, ".pdf"), width=3, height=3)
all.volcano.plots(ttest.df %>% subset(top.gene.mean != 0 & !grepl("X.NA.",motif)), ep.type, ranking.type="z-score")
dev.off()
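## Not part of the pipeline: a minimal sketch for inspecting a single program
## interactively. It reuses volcano.plot() defined above and assumes ttest.df
## labels programs as "topic_<n>"; the output file name below is hypothetical.
toplot.example <- ttest.df %>% subset(topic == "topic_1" & top.gene.mean != 0 & !grepl("X.NA.", motif))
p.example <- volcano.plot(toplot.example, ep.type, ranking.type = "z-score")
ggsave(paste0(FIGDIRTOP, "topic_1.", ep.type, ".motif.volcano.example.pdf"), p.example, width = 3, height = 3)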
library(conflicted)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")

packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr", "grid", "gtable", "gridExtra","ggrepel","ramify",
              "ggpubr","gridExtra","RNOmni",
              "org.Hs.eg.db","limma","fgsea", "conflicted",
              "cluster","textshape","readxl", 
              "ggdist", "gghalves", "Seurat", "writexl") #              "GGally","RNOmni","usedist","GSEA","clusterProfiler","IsoplotR","wesanderson",
xfun::pkg_attach(packages)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")



## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")
source("./workflow/scripts/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/", help="Output directory"),
  # make_option("--olddatadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"),
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/", help="Input 10x data directory"),
  make_option("--topic.model.result.dir", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210707_snakemake_maxParallel/all_genes_acrossK/2kG.library/", help="Topic model results directory"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  make_option("--sep", type="logical", default=F, help="Whether to separate replicates or samples"),
  make_option("--K.list", type="character", default="2,3,4,5,6,7,8,9,10,11,12,13,14,15,17,19,21,23,25", help="K values available for analysis"),
  make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
  make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
  make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
  make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
  make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
  make_option("--raw.mtx.dir",type="character",default="stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.filtered.txt", help="input matrix to cNMF pipeline"),
  make_option("--raw.mtx.RDS.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.RDS", help="input matrix to cNMF pipeline"), # the first lane: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.expandedMultiTargetGuide.RDS"
  make_option("--subsample.type", type="character", default="", help="Type of cells to keep. Currently only support ctrl"),
  make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/2kG.library.barcodes.tsv", help="barcodes.tsv for all cells"),
  make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),

  ## fisher motif enrichment
  ## make_option("--outputTable", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputTableBinary", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.binary.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputEnrichment", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.fisher.enrichment.k_14.df_0_2.txt", help="Output directory"),
  make_option("--motif.promoter.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv", help="All promoter's motif matches"),
  make_option("--motif.enhancer.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv", help="All enhancer's motif matches specific to {no,plus}_IL1B"),
  make_option("--enhancer.fimo.threshold", type="character", default="1.0E-4", help="Enhancer fimo motif match threshold"),

  #summary plot parameters
  make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
  make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute"),
  make_option("--perturb.seq", type="character", default="False", help="True for perturb-seq. The pipeline will perform statistical test if True."),

  ## Organism flag
  make_option("--organism", type="character", default="human", help="Organism type, accept org.Hs.eg.db. Only support human and mouse.")

)
opt <- parse_args(OptionParser(option_list=option.list))


## ## 2n dataset (for sdev)
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/figures/all_genes/"
## opt$K.val <- 60
## opt$sampleName <- "Perturb_2kG_dup4"

## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/all_genes_acrossK/all_genes_acrossK/2kG.library/"
## opt$K.val <- 60

## ## debug ctrl
## opt$topic.model.result.dir <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210810_snakemake_ctrls/all_genes_acrossK/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation"
## opt$sampleName <- "2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210810_snakemake_ctrls/figures/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210810_snakemake_ctrls/analysis/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation/all_genes/"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210806_curate_ctrl_mtx/outputs/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation.barcodes.tsv"
## opt$K.val <- 60

## ## ctrl 2nd round
## opt$sampleName <- "2kG.library.ctrl.only"
## opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes_acrossK/2kG.library.ctrl.only/"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210806_curate_ctrl_mtx/211206_ctrl_mtx_for_cNMF_pipeline/outputs/ctrl_mtx/barcodes.tsv"
## opt$subsample.type <- "ctrl"
## opt$K.val <- 8

## ## debug scRNAseq_2kG_11AMDox_1
## opt$sampleName <- "scRNAseq_2kG_11AMDox_1"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/"
## opt$K.val <- 14
## opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes_acrossK/scRNAseq_2kG_11AMDox_1"

## ## debug K562 gwps sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/"
## opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes_acrossK/WeissmanK562gwps/"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/WeissmanLab_data/K562_gwps_raw_singlecell_01_metadata.txt"
## opt$K.val <- 25
## opt$sampleName <- "WeissmanK562gwps"

## ## debug mouse ENCODE heart sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes/"
## opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes_acrossK/mouse_ENCODE_heart/"
## opt$K.val <- 45
## opt$sampleName <- "mouse_ENCODE_heart"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/collab_data/IGVF/mouse_ENCODE_adrenal/auxiliary_data/snrna/adrenal_Parse_10x_integrated_metadata.csv" ## sdev for mouse ENCODE


## ## debug IGVF b01_LeftCortex sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes/"
## opt$topic.model.result.dir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes_acrossK/IGVF_b01_LeftCortex/"
## opt$K.val <- 15
## opt$sampleName <- "IGVF_b01_LeftCortex"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_igvf_b01_LeftCortex_data/IGVF_b01_LeftCortex.barcodes.txt"
## opt$organism <- "mouse"



mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold"))

SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
DATADIR=opt$olddatadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
TMDIR=opt$topic.model.result.dir
SEP=opt$sep
# K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIR=opt$figdir
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
## FGSEADIR=paste0(OUTDIRSAMPLE,"/fgsea/")
## FGSEAFIG=paste0(FIGDIRSAMPLE,"/fgsea/")

## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr

## ## directories for factor motif enrichment
## FILENAME=opt$filename


## ## modify motif.enhancer.background input directory ##HERE: perhaps do a for loop for all the desired thresholds (use strsplit on enhancer.fimo.threshold)
## opt$motif.enhancer.background <- paste0(opt$motif.enhancer.background, opt$enhancer.fimo.threshold, "/fimo.formatted.tsv")


# create dir if not already
check.dir <- c(OUTDIR, FIGDIR, OUTDIRSAMPLE, FIGDIRSAMPLE)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

## palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)
# selected.gene <- c("EDN1", "NOS3", "TP53", "GOSR2", "CDKN1A")
# # ABC genes
# gene.set <- c("INPP5B", "SF3A3", "SERPINH1", "NR2C1", "FGD6", "VEZT", "SMAD3", "AAGAB", "GOSR2", "ATP5G1", "ANGPTL4", "SRBD1", "PRKCE", "DAGLB") # ABC_0.015_CAD_pp.1_genes #200 gene library

# # cell cycle genes
# ## need to update these for 2kG library
# gene.list.three.groups <- read.delim(paste0("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/ptbd.genes_three.groups.txt"), header=T, stringsAsFactors=F)
# enhancer.set <- gene.list.three.groups$Gene[grep("E_at_", gene.list.three.groups$Gene)]
# CAD.focus.gene.set <- gene.list.three.groups %>% subset(Group=="CAD_focus") %>% pull(Gene) %>% append(enhancer.set)
# EC.pos.ctrl.gene.set <- gene.list.three.groups %>% subset(Group=="EC_pos._ctrls") %>% pull(Gene)

cell.count.thr <- opt$cell.count.thr # greater than this number, filter to keep the guides with greater than this number of cells
guide.count.thr <- opt$guide.count.thr # greater than this number, filter to keep the perturbations with greater than this number of guides

# guide.design = read.delim(file=paste0(DATADIR, "/200607_ECPerturbSeqMiniPool.design.txt"), header=T, stringsAsFactors = F)


# ## add GO pathway log2FC
# GO <- read.delim(file=paste0("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/GO.Pathway.table.brief.txt"), header=T, check.names=FALSE)
# GO.list <- read.delim(file="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/GO.Pathway.list.brief.txt", header=T, check.names=F)
# colnames(GO)[1] <- "Gene"
# colnames(GO.list)[1] <- "Gene"
# ## load all sample, K, topic's top 100 genes (by TopFeatures() KL-score measure)
# ## allGeneKtopic100 <- read.delim(paste0(TMDIR, "no.plus.pooled.top100.topicStats.txt"), header=T)
# # load non-expressed control gene list
# non.expressed.genes <- read.delim(file="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/non.expressed.ctrl.genes.txt", header=F, stringsAsFactors=F) %>% unlist %>% as.character() %>% sort()

# # perturbation type list
# gene.set.type.df <- data.frame(Gene=guide.design %>% pull(guideSet) %>% unique(),
#                                type=rep("other", guide.design %>% pull(guideSet) %>% unique() %>% length())) 
# gene.set.type.df$Gene <- gene.set.type.df$Gene %>% as.character()
# gene.set.type.df$type <- gene.set.type.df$type %>% as.character()
# gene.set.type.df$type[which(gene.set.type.df$Gene %in% non.expressed.genes)] <- "non-expressed"
# gene.set.type.df$type[which(gene.set.type.df$Gene %in% CAD.focus.gene.set)] <- "CAD focus"
# gene.set.type.df$type[grepl("^safe|^negative", gene.set.type.df$Gene)] <- "negative-control"
# gene.set.type.df$Gene[which(gene.set.type.df$Gene == "negative_control")] <- "negative-control"
# gene.set.type.df$Gene[which(gene.set.type.df$Gene == "safe_targeting")] <- "safe-targeting"
# # gene.set.type.df$type[which(gene.set.type.df$Gene %in% gene.set)] <- "ABC"

# gene.set.type.df.200 <- gene.set.type.df

# # reference table
# ref.table <- read_xlsx(opt$reference.table, sheet="2000_gene_library_annotated") 
# gene.set.type.df <- ref.table %>% select(Symbol, `Class(es)`) %>% `colnames<-`(c("Gene", "type"))
# gene.set.type.df$type[grepl("EC_ctrls", gene.set.type.df$type)] <- "EC_ctrls"
# gene.set.type.df$type[grepl("NonExpressed", gene.set.type.df$type)] <- "non-expressed"
# gene.set.type.df$type[grepl("abc.015", gene.set.type.df$type)] <- "ABC"
# gene.set.type.df <- rbind(gene.set.type.df, c("negative-control", "negative-control"), c("safe-targeting", "safe-targeting"))
# non.expressed.genes <- gene.set.type.df %>% subset(type == "non-expressed") %>% pull(Gene)
# # ABC genes
# gene.set <- gene.set.type.df %>% subset(grepl("ABC", type)) %>% pull(Gene)

# ## add GWAS classification
# modified.ref.table <- ref.table %>% mutate(GWAS.classification="")
# CAD.index <- which(grepl("CAD_Loci",ref.table$`Class(es)`))
# EC_ctrls.index <- which(grepl("^EC_ctrls",ref.table$`Class(es)`))
# ABC_linked.index <- which(grepl("MIG_etc",ref.table$`Class(es)`))
# IBD.index <- which(grepl("Non-CAD_loci_IBD",ref.table$`Class(es)`))
# non.expressed.index <- which(grepl("NonExpressed",ref.table$`Class(es)`))
# poorly.annotated.9p21.index <- which(grepl("9p21",ref.table$`Class(es)`))
#                                         # length(CAD.index) + length(EC_ctrls.index) + length(ABC_linked.index) + length(IBD.index) + length(non.expressed.index) + length(poorly.annotated.9p21.index)
# modified.ref.table$GWAS.classification[ABC_linked.index] <- "ABC"
# modified.ref.table$GWAS.classification[IBD.index] <- "IBD"
# modified.ref.table$GWAS.classification[non.expressed.index] <- "NonExpressed"
# modified.ref.table$GWAS.classification[poorly.annotated.9p21.index] <- "9p21.poorly.annotated"
# modified.ref.table$GWAS.classification[EC_ctrls.index] <- "EC_ctrls"
# modified.ref.table$GWAS.classification[CAD.index] <- "CAD"

# modified.ref.table <- modified.ref.table %>% group_by(GWAS.classification) %>% mutate(gene.count.per.GWAS.category = n())
# ref.table <- modified.ref.table

# ## add TSS distance to SNP
# modified.ref.table <- ref.table %>% mutate(TSS.dist.to.SNP = abs(`TSS v. SNP loc`))
# not.in.SNP.index <- which(is.na(modified.ref.table$`TSS v. SNP loc`))
# modified.ref.table$TSS.dist.to.SNP[not.in.SNP.index] <- NA
# ref.table <- modified.ref.table %>% ungroup()

# ## add closest gene to top GWAS loci ranking
# modified.ref.table <- ref.table %>%
#     group_by(`Top SNP ID`) %>% # per SNP metrics
#     arrange(abs(`TSS v. SNP loc`)) %>%
#     mutate(TSS.v.SNP.ranking = 1:n(),
#            total.gene.in.this.loci = n()) %>% ungroup() %>%
#     group_by(`Top SNP ID`, GWAS.classification) %>% # per SNP per GWAS class (CAD, IBD, NonExpressed, ABC, 9p21.poorly.annotated) 
#     arrange(abs(`TSS v. SNP loc`)) %>%
#     mutate(TSS.v.SNP.ranking.in.GWAS.category = 1:n(),
#            total.gene.in.this.loci.in.GWAS.category = n()) %>% ungroup()
# not.in.SNP.index <- which(is.na(modified.ref.table$`TSS v. SNP loc`))
# modified.ref.table$TSS.v.SNP.ranking.in.GWAS.category[not.in.SNP.index] <- NA
# modified.ref.table$TSS.v.SNP.ranking[not.in.SNP.index] <- NA
# ref.table <- modified.ref.table

# ## add gene count per distance ranking per GWAS loci
# modified.ref.table <- ref.table
# modified.ref.table <- modified.ref.table %>%
#     group_by(TSS.v.SNP.ranking) %>% # per ranking, not considering which GWAS category the gene is from
#     mutate(total.TSS.v.SNP.ranking.count = n()) %>% ungroup() %>%
#     group_by(GWAS.classification, TSS.v.SNP.ranking.in.GWAS.category) %>% # per GWAS category and per ranking
#     mutate(total.TSS.v.SNP.ranking.count.per.GWAS.classification = n()) %>% ungroup()
# not.in.SNP.index <- which(is.na(modified.ref.table$TSS.v.SNP.ranking))
# modified.ref.table$total.TSS.v.SNP.ranking.count[not.in.SNP.index] <- NA
# modified.ref.table$total.TSS.v.SNP.ranking.count.per.GWAS.classification[not.in.SNP.index] <- NA
# ref.table <- modified.ref.table

# write.table(ref.table, file=paste0(opt$datadir, "/ref.table.txt"), row.names=F, quote=F, sep="\t")

# ## ref.table ranking count summary table
# ref.table.gene.to.SNP.dist.ranking.count.summary.allGWAS <- ref.table %>% select(TSS.v.SNP.ranking, total.TSS.v.SNP.ranking.count) %>% mutate(GWAS.classification="all") %>% unique()
# ref.table.gene.to.SNP.dist.ranking.count.summary.indGWAS <- ref.table %>% select(TSS.v.SNP.ranking.in.GWAS.category, total.TSS.v.SNP.ranking.count.per.GWAS.classification, GWAS.classification) %>% `colnames<-`(c("TSS.v.SNP.ranking", "total.TSS.v.SNP.ranking.count", "GWAS.classification")) %>% unique()
# ref.table.gene.to.SNP.dist.ranking.count.summary <- rbind(ref.table.gene.to.SNP.dist.ranking.count.summary.allGWAS, ref.table.gene.to.SNP.dist.ranking.count.summary.indGWAS)
# ref.table.summary.na.index <- which(is.na(ref.table.gene.to.SNP.dist.ranking.count.summary$TSS.v.SNP.ranking))
# ref.table.gene.to.SNP.dist.ranking.count.summary <- ref.table.gene.to.SNP.dist.ranking.count.summary[-ref.table.summary.na.index,]
# rm(ref.table.summary.na.index)


# # convert enhancer SNP rs number to enhancer target gene name # need 2kG library version
# enh.snp.to.gene <- read.delim(paste0(DATADIR, "/enhancer.SNP.to.gene.name.txt"), header=T, stringsAsFactors = F) %>% mutate(Enhancer_name=gsub("_","-", Enhancer_name))

# # gene corresponding pathway
# gene.def.pathways <- read_excel(paste0(DATADIR,"topic.gene.definition.pathways.xlsx"), sheet="Gene_Pathway")

# ## Gavin's new list
# gene.classes.ranked <- read.table(paste0(opt$datadir, "Gene_Classes_Ranked_for_CAD_n_EC.txt"), header=T, stringsAsFactors = F)
# summaries <- read.delim(paste0(opt$datadir, "Gene_Summaries_n_Classes.txt"), sep="\t", header=T, stringsAsFactors = F)
# gene.summaries <- read_xlsx(paste0(opt$datadir, "Gene_Summaries.xlsx"), sheet="uniprot_summaries")


# print("loaded all prerequisite data")



## for the guides that target multiple genes, we will split the gene annotation and duplicate the cell entry, so that each gene will get a cell read out.
adjust.multiTargetGuide.rownames <- function(omega) {
    ## duplicate cells with guide that targets multiple genes
    cells.with.multiTargetGuide <- rownames(omega)[grepl("and",rownames(omega))]
    if(length(cells.with.multiTargetGuide) > 0) {
        ## split index by number of guide targets
        cells.with.multiTargetGuide.index <- which(grepl("and",rownames(omega)))
        cells.with.singleTargetGuide.index <- which(!grepl("and",rownames(omega)))
        ## get multi target gene names
        multiTarget.names <- cells.with.multiTargetGuide %>% strsplit(., split=":") %>% sapply("[[",1) ## full names: GeneA-and-GeneB
        multiTarget.Guide.CBC <- cells.with.multiTargetGuide %>% strsplit(., split=":") %>% sapply( function(x) paste0(x[[2]], ":", x[[3]]) )
        multiTarget.names.1 <- multiTarget.names %>% strsplit(., split="-and-") %>% sapply ("[[",1)
        multiTarget.names.2 <- multiTarget.names %>% strsplit(., split="-and-") %>% sapply ("[[",2)
        multiTarget.names.all <- multiTarget.names.1 %>% append(multiTarget.names.2) %>% unique() ## get all the genes/enhancers that have guides targeting other gene/enhancer at the same time
        cells.with.multiTarget.gene.names.index <- which(grepl(paste0(multiTarget.names.all,collapse="|"), rownames(omega)))

        multiTarget.long.CBC.1 <- sapply(1:length(multiTarget.names), function(i) {
            paste0(multiTarget.names.1[i], "_multiTarget:", multiTarget.Guide.CBC[i])
        })
        multiTarget.long.CBC.2 <- sapply(1:length(multiTarget.names), function(i) {
            paste0(multiTarget.names.2[i], "_multiTarget:", multiTarget.Guide.CBC[i])
        })
        ## change original df's rownames
        multiTargetGuide.mtx <- omega[cells.with.multiTargetGuide.index,]
        multiTargetGuide.mtx.1 <- multiTargetGuide.mtx %>% `rownames<-`(multiTarget.long.CBC.1)
        multiTargetGuide.mtx.2 <- multiTargetGuide.mtx %>% `rownames<-`(multiTarget.long.CBC.2)
        ## pull cells with guides that has a single target, but the perturbed gene has multiTarget guide
        expanded.gene.name.df <- do.call(rbind, lapply(1:length(multiTarget.names.all), function(i) {
            gene.name.here <- multiTarget.names.all[i]
            toPaste.gene.name.here <- multiTarget.names[which(grepl(gene.name.here, multiTarget.names))] %>% gsub("-TSS2","",.) %>% unique()
            out <- do.call(rbind, lapply(1:length(toPaste.gene.name.here), function(j) {
                singleTarget.cell.index.here <- which(grepl(gene.name.here,rownames(omega)) & !grepl("and",rownames(omega)))
                singleTarget.cell.df <- omega[singleTarget.cell.index.here,]
                rownames(singleTarget.cell.df) <- gsub(gene.name.here,toPaste.gene.name.here[j],rownames(singleTarget.cell.df))
                return(singleTarget.cell.df)
            }))
            return(out)
        })) # takes two minutes
        omega <- rbind(omega[cells.with.singleTargetGuide.index,], multiTargetGuide.mtx.1, multiTargetGuide.mtx.2, expanded.gene.name.df)
    }
    return(omega)
}    
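## Illustration (hypothetical row names, following the Gene:Guide:CBC convention
## assumed by the function above): a cell labelled
##   "GeneA-and-GeneB:guide1:AAACCTG-1"
## is duplicated as "GeneA_multiTarget:guide1:AAACCTG-1" and
## "GeneB_multiTarget:guide1:AAACCTG-1", and cells carrying single-target guides
## for GeneA or GeneB are additionally copied under the combined
## "GeneA-and-GeneB" label, so every target gene of a multi-target guide
## receives a usage read-out.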


######################################################################
## Process topic model results
## for ( n in 1:length(SAMPLE) ) {

    ## if (SEP) {
    ##     guideCounts <- loadGuides(n, sep=T) %>% mutate(Gene=Gene.marked)
    ##     tmp.labels <- guideCounts$Gene %>% unique() %>% strsplit("-") %>% sapply("[[",2) %>% unique()
    ##     tmp.labels <- tmp.labels[!(tmp.labels %in% c("control","targeting"))]
    ##     rep1.label <- paste0("-",tmp.labels[1])
    ##     rep2.label <- paste0("-",tmp.labels[2])
    ## } else guideCounts <- loadGuides(n) %>% mutate(Gene=Gene.marked)

db <- ifelse(grepl("mouse|org.Mm.eg.db", opt$organism), "org.Mm.eg.db", "org.Hs.eg.db")
library(db, character.only = TRUE) ## load the annotation package chosen above

cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
    print("finished loading cNMF result file")
} else {
    theta.path <- paste0(TMDIR, "/", SAMPLE, ".gene_spectra_tpm.k_", k, ".dt_", DENSITY.THRESHOLD,".txt")
    theta.zscore.path <- paste0(TMDIR, "/", SAMPLE, ".gene_spectra_score.k_", k, ".dt_", DENSITY.THRESHOLD,".txt")
    median.spectra.path <- paste0(TMDIR, "/", SAMPLE, ".spectra.k_", k, ".dt_", DENSITY.THRESHOLD,".consensus.txt")
    print(theta.path)
    theta.raw <- read.delim(theta.path, header=T, stringsAsFactors=F, check.names=F, row.names=1)
    median.spectra <- read.delim(median.spectra.path, header=T, stringsAsFactors=F, check.names=F, row.names=1)
    ## theta.raw <- read.delim(theta.path, header=T, stringsAsFactors=F, check.names=F) %>% select(-``)
    print("finished reading raw weights for topics")
    tmp.theta <- theta.raw
    tmp.theta[tmp.theta==0] <- min(tmp.theta[tmp.theta > 0])/100
    theta <- tmp.theta %>% apply(1, function(x) x/sum(x)) %>% `colnames<-`(c(1:k))
    theta.raw <- theta.raw %>% t() %>% as.data.frame() %>% `colnames<-`(c(1:k))
    median.spectra <- median.spectra %>% t() %>% as.data.frame %>% `colnames<-`(c(1:k))
    print("loading topic z-score coefficient")
    theta.zscore <- read.delim(theta.zscore.path, header=T, stringsAsFactors=F, check.names=F, row.names=1) %>% t() %>% `colnames<-`(c(1:k)) 
    tmp <- rownames(theta) %>% strsplit(., split=":") %>% sapply("[[",1)
    tmpp <- data.frame(table(tmp)) %>% subset(Freq > 1)  # keep row names that have duplicated gene names but different ENSG names
    tmp.copy <- tmp
    tmp.copy[grepl(paste0(tmpp$tmp,collapse="|"),tmp)] <- rownames(theta)[grepl(paste0(tmpp$tmp,collapse="|"),rownames(theta))]
    rownames(theta) <- rownames(theta.raw) <- rownames(theta.zscore) <- tmp.copy

    ## median.spectra.names <- median.spectra %>% rownames %>% strsplit(split=":") %>% sapply(`[[`,1)
    ## rownames(median.spectra) <- median.spectra.names
    median.spectra.zscore <- apply(median.spectra, MARGIN=1, function(x) (x - mean(x)) / sd(x)) %>% t
    if(grepl("2kG.library", SAMPLE)) {
        median.spectra.zscore.df <- median.spectra.zscore %>%
            as.data.frame %>%
            mutate(Gene.full.name = rownames(.)) %>%
            separate(col="Gene.full.name", sep=":", remove=F, into = c("Gene", "ENSGID")) %>%
            melt(id.vars = c("Gene.full.name", "Gene", "ENSGID"), variable.name="ProgramID", value.name="median.spectra.zscore") %>%
            mutate(ProgramID = paste0("K", k, "_", ProgramID)) %>%
            as.data.frame %>%
            group_by(ProgramID) %>%
            arrange(desc(median.spectra.zscore)) %>%
            mutate(median.spectra.zscore.rank = 1:n()) %>%
            select(-Gene.full.name) %>%
            as.data.frame
        ## put median spectra zscore into ENSGID format for PoPS
        median.spectra.zscore.formatted <- median.spectra.zscore %>%
            as.data.frame %>%
            mutate(Gene.ENSGID = rownames(.)) %>%
            separate(col="Gene.ENSGID", sep=":", remove=F, into = c("Gene", "ENSGID_from_input")) %>%
            as.data.frame
        median.spectra.zscore.mappedENSGID <- mapIds(get(db), keys=median.spectra.zscore.formatted$Gene, keytype = "SYMBOL", column = "ENSEMBL")
        median.spectra.zscore.formatted <- median.spectra.zscore.formatted %>%
            mutate(ENSGID_mapped = median.spectra.zscore.mappedENSGID) %>%
            mutate(matchedENSGIDbool = ENSGID_from_input == ENSGID_mapped)

        median.spectra.zscore.formatted <- median.spectra.zscore.formatted %>% `rownames<-`(.$ENSGID_from_input)

        median.spectra.zscore.formatted <- median.spectra.zscore.formatted %>%
            select(-Gene, -ENSGID_from_input, -ENSGID_mapped, -matchedENSGIDbool, -Gene.ENSGID) %>%
            `colnames<-`(paste0("median_spectra_K", k, "_", colnames(.))) %>%
            as.data.frame

        print("save the data")
        ensembl.theta.zscore.names <- mapIds(get(db), keys = rownames(theta.zscore), keytype = "SYMBOL", column="ENSEMBL")
        ensembl.theta.zscore.names[ensembl.theta.zscore.names %>% is.na] <- rownames(theta.zscore)[ensembl.theta.zscore.names %>% is.na]
        theta.zscore.ensembl <- theta.zscore
        colnames(theta.zscore.ensembl) <- paste0("zscore_K", k, "_", colnames(theta.zscore.ensembl))
        theta.zscore.ensembl <- theta.zscore.ensembl %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("zscore_K",k,"_1")) ## anchor matches the zscore_K*_ column names set above

        ensembl.theta.raw.names <- mapIds(get(db), keys = rownames(theta.raw), keytype = "SYMBOL", column="ENSEMBL")
        ensembl.theta.raw.names[ensembl.theta.raw.names %>% is.na] <- rownames(theta.raw)[ensembl.theta.raw.names %>% is.na]
        theta.raw.ensembl <- theta.raw
        colnames(theta.raw.ensembl) <- paste0("tpm_K", k, "_", colnames(theta.raw.ensembl))
        theta.raw.ensembl <- theta.raw.ensembl %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.raw.names,.before=paste0("tpm_K",k,"_1")) ## anchor matches the tpm_K*_ column names set above

        ## normalize to zero mean + unit variance
        theta.raw.ensembl.scaled <- theta.raw.ensembl %>% select(-ENSGID) %>% apply(2, scale)  %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("tpm_K",k,"_1"))
        theta.zscore.ensembl.scaled <- theta.zscore.ensembl %>% select(-ENSGID) %>% apply(2, scale) %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("zscore_K",k,"_1"))
        median.spectra.zscore.formatted.scaled <- median.spectra.zscore.formatted %>% apply(2, scale) %>% as.data.frame %>% mutate(ENSGID = row.names(median.spectra.zscore.formatted), .before=paste0("median_spectra_K", k, "_1"))

    } else {
        ## detect gene data type (e.g. ENSGID, Entrez Symbol)
        gene.type <- ifelse(nrow(median.spectra.zscore) == sum(as.numeric(grepl("^ENS", median.spectra.zscore %>% rownames))),
                            "ENSGID",
                            "Gene")

        median.spectra.zscore.df <- median.spectra.zscore %>%
            as.data.frame %>%
            mutate(!!gene.type := rownames(.)) %>%
            melt(id.vars = c(gene.type), variable.name="ProgramID", value.name="median.spectra.zscore") %>%
            mutate(ProgramID = paste0("K", k, "_", ProgramID)) %>%
            as.data.frame %>%
            group_by(ProgramID) %>%
            arrange(desc(median.spectra.zscore)) %>%
            mutate(median.spectra.zscore.rank = 1:n()) %>%
            as.data.frame

        if(gene.type == "Gene") {
            ## put median spectra zscore into ENSGID format for PoPS
            mapped.genes <- mapIds(get(db), keys=median.spectra.zscore %>% rownames, keytype = "SYMBOL", column = "ENSEMBL")
            median.spectra.zscore.formatted <- median.spectra.zscore %>%
                as.data.frame %>%
                ## `colnames<-`(paste0("median_spectra_K", k, "_", colnames(.))) %>%
                mutate(Gene = rownames(.)) %>%
                mutate(ENSGID = mapped.genes)

            ## ensembl.theta.zscore.names <- mapIds(get(db), keys = rownames(theta.zscore), keytype = "SYMBOL", column="ENSEMBL")
            ## ensembl.theta.zscore.names[ensembl.theta.zscore.names %>% is.na] <- rownames(theta.zscore)[ensembl.theta.zscore.names %>% is.na]
            ## theta.zscore.ensembl <- theta.zscore
            ## colnames(theta.zscore.ensembl) <- paste0("zscore_K", k, "_", colnames(theta.zscore.ensembl))
            ## theta.zscore.ensembl <- theta.zscore.ensembl %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("zscore_K",k,"_1"))

            ## ensembl.theta.raw.names <- mapIds(get(db), keys = rownames(theta.raw), keytype = "SYMBOL", column="ENSEMBL")
            ## ensembl.theta.raw.names[ensembl.theta.raw.names %>% is.na] <- rownames(theta.raw)[ensembl.theta.raw.names %>% is.na]
            ## theta.raw.ensembl <- theta.raw
            ## colnames(theta.raw.ensembl) <- paste0("tpm_K", k, "_", colnames(theta.raw.ensembl))
            ## theta.raw.ensembl <- theta.raw.ensembl %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.raw.names,.before=paste0("tpm_K",k,"_1"))

            ## ## normalize to zero mean + unit variance
            ## theta.raw.ensembl.scaled <- theta.raw.ensembl %>% select(-ENSGID) %>% apply(2, scale)  %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("tpm_K",k,"_1"))
            ## theta.zscore.ensembl.scaled <- theta.zscore.ensembl %>% select(-ENSGID) %>% apply(2, scale) %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("zscore_K",k,"_1"))
            ## median.spectra.zscore.formatted.scaled <- median.spectra.zscore.formatted %>% select(-ENSGID, -Gene) %>% apply(2, scale) %>% as.data.frame %>% mutate(ENGSID = median.spectra.zscore.formatted$ENSGID, .before=paste0("median_spectra_K", k, "_1")) 

        } else {
            mapped.genes <- mapIds(get(db), keys=median.spectra.zscore %>% rownames, keytype = "ENSEMBL", column = "SYMBOL")
            median.spectra.zscore.formatted <- median.spectra.zscore %>%
                as.data.frame %>%
                ## `colnames<-`(paste0("median_spectra_K", k, "_", colnames(.))) %>%
                mutate(ENSGID = rownames(.)) %>%
                mutate(Gene = mapped.genes)
            ## median.spectra.zscore.formatted.scaled <- median.spectra.zscore.formatted %>% select(-ENSGID, -Gene) %>% apply(2, scale) %>% as.data.frame %>% mutate(ENGSID = median.spectra.zscore.formatted$ENSGID, .before=paste0("median_spectra_K", k, "_1")) 

        }

        median.spectra.zscore.ensembl.names <- median.spectra.zscore.formatted$ENSGID
        median.spectra.zscore.formatted <- median.spectra.zscore.formatted %>%
            mutate(Gene_ENSGID = paste0(Gene, ":", ENSGID)) %>%
            `rownames<-`(.$Gene_ENSGID) %>%
            select(-Gene, -ENSGID, -Gene_ENSGID) %>%
            `colnames<-`(paste0("median_spectra_K", k, "_", colnames(.))) %>%
            as.data.frame

        if(gene.type == "Gene") {
            ensembl.theta.zscore.names <- mapIds(get(db), keys = rownames(theta.zscore), keytype = "SYMBOL", column="ENSEMBL")
        } else {
            ensembl.theta.zscore.names <- theta.zscore %>% rownames
        }
        ensembl.theta.zscore.names[ensembl.theta.zscore.names %>% is.na] <- rownames(theta.zscore)[ensembl.theta.zscore.names %>% is.na]
        theta.zscore.ensembl <- theta.zscore
        colnames(theta.zscore.ensembl) <- paste0("zscore_K", k, "_", colnames(theta.zscore.ensembl))
        theta.zscore.ensembl <- theta.zscore.ensembl %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("zscore_K",k,"_1"))

        if(gene.type == "Gene") {
            ensembl.theta.raw.names <- mapIds(get(db), keys = rownames(theta.raw), keytype = "SYMBOL", column="ENSEMBL")
        } else {
            ensembl.theta.raw.names <- theta.raw %>% rownames
        }
        ensembl.theta.raw.names[ensembl.theta.raw.names %>% is.na] <- rownames(theta.raw)[ensembl.theta.raw.names %>% is.na]
        theta.raw.ensembl <- theta.raw
        colnames(theta.raw.ensembl) <- paste0("tpm_K", k, "_", colnames(theta.raw.ensembl))
        theta.raw.ensembl <- theta.raw.ensembl %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.raw.names,.before=paste0("tpm_K",k,"_1"))

        ## normalize to zero mean + unit variance
        theta.raw.ensembl.scaled <- theta.raw.ensembl %>% select(-ENSGID) %>% apply(2, scale)  %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("tpm_K",k,"_1"))
        theta.zscore.ensembl.scaled <- theta.zscore.ensembl %>% select(-ENSGID) %>% apply(2, scale) %>% as.data.frame %>% mutate(ENSGID=ensembl.theta.zscore.names,.before=paste0("zscore_K",k,"_1"))
        median.spectra.zscore.formatted.scaled <- median.spectra.zscore.formatted %>% apply(2, scale) %>% as.data.frame %>% mutate(ENSGID = median.spectra.zscore.ensembl.names, .before=paste0("median_spectra_K", k, "_1"))


    }

    ## truncate.theta.names <- function(theta) {
    ##     theta.gene.names <- rownames(theta) %>% strsplit(., split=":") %>% sapply("[[",1) # remove ENSG names
    ##     rownames(theta) <- theta.gene.names
    ##     return(theta)
    ## }
    ## theta.raw <- truncate.theta.names(theta.raw)
    ## theta.zscore <- truncate.theta.names(theta.zscore)
    omega.path <- paste0(TMDIR, "/", SAMPLE, ".usages.k_", k, ".dt_", DENSITY.THRESHOLD, ".consensus.txt")
    print(omega.path)
    omega.original <- omega <- read.delim(omega.path, header=T, stringsAsFactors=F, check.names=F, row.names = 1)  %>% apply(1, function(x) x/sum(x)) %>% t()
    colnames(omega) <- paste0("topic_",colnames(omega))
    print("finished loading omega")

    barcode.names <- read.table(opt$barcode.names, header=T, stringsAsFactors=F, sep=ifelse(grepl("csv$", opt$barcode.names), ",", "\t")) ## %>% `colnames<-`("long.CBC")
    if(grepl("2kG.library", SAMPLE)) {
        barcode.names <- read.table(opt$barcode.names, header=F, stringsAsFactors=F) ## %>% `colnames<-`("long.CBC")
        rownames(omega) <- rownames(omega.original) <- barcode.names %>% `colnames<-`("long.CBC") %>% pull(long.CBC) %>% gsub("CSNK2B-and-CSNK2B", "CSNK2B",.) %>% gsub("[(]'", "", .) %>% gsub("',[)]", "", .)
        omega <- adjust.multiTargetGuide.rownames(omega)
        barcode.names <- data.frame(long.CBC=rownames(omega)) %>%
            mutate(long.CBC = gsub("CSNK2B-and-CSNK2B", "CSNK2B", long.CBC)) %>%
            separate(col="long.CBC", into=c("Gene.full.name", "Guide", "CBC"), sep=":", remove=F) %>%
            separate(col="CBC", into=c("CBC", "sample"), sep="-scRNAseq_2kG_", remove=F) %>%
            mutate(Gene = gsub("-TSS2$", "", Gene.full.name),
                   CBC = gsub("RHOA-and-", "", CBC),
                   Guide = gsub("RHOA-and-", "", Guide)) %>%
            as.data.frame
    }
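    ## In the 2kG.library branch above, a (hypothetical) long.CBC such as
    ## "GOSR2:GOSR2-guide-1:AAACCTGAGT-scRNAseq_2kG_3" is parsed into
    ## Gene.full.name = "GOSR2", Guide = "GOSR2-guide-1", CBC = "AAACCTGAGT",
    ## sample = "3", with Gene = Gene.full.name minus any trailing "-TSS2".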



    ## helper function to map between ENSGID and SYMBOL
    map.ENSGID.SYMBOL <- function(df) {
        ## need column `Gene` to be present in df
        ## detect gene data type (e.g. ENSGID, Entrez Symbol)
        gene.type <- ifelse(nrow(df) == sum(as.numeric(grepl("^ENS", df$Gene))),
                            "ENSGID",
                            "Gene")
        if(gene.type == "ENSGID") {
            mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "ENSEMBL", column = "SYMBOL")
            df <- df %>% mutate(ENSGID = Gene, Gene = mapped.genes)
        } else {
            mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "SYMBOL", column = "ENSEMBL")
            df <- df %>% mutate(ENSGID = mapped.genes)
        }
        return(df)
    }
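    ## Illustration (hypothetical values): a data frame with
    ## Gene = c("ENSG00000111640", ...) is detected as ENSGID input and gains a
    ## symbol column (Gene = "GAPDH", ENSGID = "ENSG00000111640"); a data frame
    ## with Gene = c("GAPDH", ...) gains an ENSGID column via mapIds() instead.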


    ## get list of topic defining genes
    theta.rank.list <- vector("list", ncol(theta.zscore))## initialize storage list
    for(i in 1:ncol(theta.zscore)) {
        topic <- paste0("topic_", colnames(theta.zscore)[i])
        theta.rank.list[[i]] <- theta.zscore %>%
            as.data.frame %>%
            select(all_of(i)) %>%
            `colnames<-`("topic.zscore") %>%
            mutate(Gene = rownames(.)) %>%
            arrange(desc(topic.zscore)) %>%
            mutate(zscore.specificity.rank = 1:n()) %>% ## add rank column
            mutate(Topic = topic) ## add topic column
    }
    theta.rank.df <- do.call(rbind, theta.rank.list) %>%  ## combine list to df
        `colnames<-`(c("topic.zscore", "Gene", "zscore.specificity.rank", "ProgramID")) %>%
        mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
        as.data.frame %>% map.ENSGID.SYMBOL

    ## get list of topic genes by raw weight
    theta.raw.rank.list <- vector("list", ncol(theta.raw))## initialize storage list
    for(i in 1:ncol(theta.raw)) {
        topic <- paste0("topic_", colnames(theta.raw)[i])
        theta.raw.rank.list[[i]] <- theta.raw %>%
            as.data.frame %>%
            select(all_of(i)) %>%
            `colnames<-`("topic.raw") %>%
            mutate(Gene = rownames(.)) %>%
            arrange(desc(topic.raw)) %>%
            mutate(raw.score.rank = 1:n()) %>% ## add rank column
            mutate(Topic = topic) ## add topic column
    }
    theta.raw.rank.df <- do.call(rbind, theta.raw.rank.list) %>%  ## combine list to df
        `colnames<-`(c("topic.raw", "Gene", "raw.score.rank", "ProgramID")) %>%
        mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
        as.data.frame %>% map.ENSGID.SYMBOL


    ## get list of topic genes by median spectra weight
    median.spectra.rank.list <- vector("list", ncol(median.spectra))## initialize storage list
    for(i in 1:ncol(median.spectra)) {
        topic <- paste0("topic_", colnames(median.spectra)[i])
        median.spectra.rank.list[[i]] <- median.spectra %>%
            as.data.frame %>%
            select(all_of(i)) %>%
            `colnames<-`("median.spectra") %>%
            mutate(Gene = rownames(.)) %>%
            arrange(desc(median.spectra)) %>%
            mutate(median.spectra.rank = 1:n()) %>% ## add rank column
            mutate(Topic = topic) ## add topic column
    }
    median.spectra.rank.df <- do.call(rbind, median.spectra.rank.list) %>%  ## combine list to df
        `colnames<-`(c("median.spectra", "Gene", "median.spectra.rank", "ProgramID")) %>%
        mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
        as.data.frame %>% map.ENSGID.SYMBOL
    ## median.spectra.zscore.df <- median.spectra.zscore.df %>% mutate(Gene = ENSGID) ## quick fix, need to add "Gene" column to this dataframe in analysis script





    write.table(theta.zscore, file=paste0(OUTDIRSAMPLE, "/topic.zscore_",SUBSCRIPT.SHORT, ".txt"), row.names=T, quote=F, sep="\t")
    write.table(theta.raw, file=paste0(OUTDIRSAMPLE, "/topic.tpm.score_",SUBSCRIPT.SHORT, ".txt"), row.names=T, quote=F, sep="\t")
    write.table(theta.zscore.ensembl, file=paste0(OUTDIRSAMPLE, "/topic.zscore.ensembl_",SUBSCRIPT.SHORT, ".txt"), row.names=F, quote=F, sep="\t")
    write.table(theta.raw.ensembl, file=paste0(OUTDIRSAMPLE, "/topic.tpm.ensembl_",SUBSCRIPT.SHORT, ".txt"), row.names=F, quote=F, sep="\t")
    write.table(theta.zscore.ensembl.scaled, file=paste0(OUTDIRSAMPLE, "/topic.zscore.ensembl.scaled_", SUBSCRIPT.SHORT, ".txt"), row.names=F, quote=F, sep = "\t")
    write.table(theta.raw.ensembl.scaled, file=paste0(OUTDIRSAMPLE, "/topic.tpm.ensembl.scaled_", SUBSCRIPT.SHORT, ".txt"), row.names=F, quote=F, sep = "\t")
    write.table(median.spectra.zscore.df, file=paste0(OUTDIRSAMPLE, "/median.spectra.zscore.df_", SUBSCRIPT.SHORT, ".txt"), sep="\t", quote=F, row.names=F)
    write.table(median.spectra.zscore.formatted.scaled, file=paste0(OUTDIRSAMPLE, "/median.spectra.zscore.ensembl.scaled_", SUBSCRIPT.SHORT, ".txt"), sep="\t", quote=F, row.names=F)

    save(theta, theta.raw, theta.raw.rank.df, theta.zscore, median.spectra.zscore.df, median.spectra, median.spectra.rank.df, omega, 
         theta.path, omega.path, median.spectra.path, barcode.names,
         file=cNMF.result.file)

    print("finished writing all tables")
}
print("finished analysis script")



  # # modify GO.list if "pooled"
  # if (SEP) {
  #   tmp.plus <- GO.list %>% mutate(Gene = paste0(GO.list$Gene,rep1.label), Pathway = paste0(GO.list$Pathway,rep1.label))
  #   tmp.no <- GO.list %>% mutate(Gene = paste0(GO.list$Gene,rep2.label), Pathway = paste0(GO.list$Pathway,rep2.label))
  #   GO.list <- rbind(tmp.no, tmp.plus)
  # }
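## Not part of the script: a minimal sketch for re-loading the saved results at
## the end of a run and listing the top genes of one program. OUTDIRSAMPLE,
## SUBSCRIPT.SHORT and k are the variables defined above; "K<k>_1" is just an
## example ProgramID.
load(paste0(OUTDIRSAMPLE, "/cNMF_results.", SUBSCRIPT.SHORT, ".RData"))
theta.raw.rank.df %>%
    subset(ProgramID == paste0("K", k, "_1")) %>%
    arrange(raw.score.rank) %>%
    head(10)   ## top 10 program genes by raw (TPM) weight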
library(conflicted)
conflict_prefer("Position", "base")
packages <- c("optparse","dplyr", "ggplot2", "reshape2", "ggrepel", "conflicted", "gplots", "org.Hs.eg.db")
## library(Seurat)
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")


## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/", help="Output directory"),
  # make_option("--olddatadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"),
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/", help="Input 10x data directory"),
  # make_option("--topic.model.result.dir", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210625_snakemake_output/top3000VariableGenes_acrossK/2kG.library/", help="Topic model results directory"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  # make_option("--sep", type="logical", default=F, help="Whether to separate replicates or samples"),
  # make_option("--K.list", type="character", default="2,3,4,5,6,7,8,9,10,11,12,13,14,15,17,19,21,23,25", help="K values available for analysis"),
  make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
  # make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
  # make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
  # make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
  make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
  # make_option("--raw.mtx.dir",type="character",default="stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.filtered.txt", help="input matrix to cNMF pipeline"),
  # make_option("--raw.mtx.RDS.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.RDS", help="input matrix to cNMF pipeline"), # the first lane: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.expandedMultiTargetGuide.RDS"
  # make_option("--subsample.type", type="character", default="", help="Type of cells to keep. Currently only support ctrl"),
  # make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/barcodes.tsv", help="barcodes.tsv for all cells"),
  # make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),

  ## fisher motif enrichment
  ## make_option("--outputTable", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputTableBinary", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.binary.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputEnrichment", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.fisher.enrichment.k_14.df_0_2.txt", help="Output directory"),
  # make_option("--motif.promoter.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv", help="All promoter's motif matches"),
  # make_option("--motif.enhancer.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv", help="All enhancer's motif matches specific to {no,plus}_IL1B"),
  # make_option("--enhancer.fimo.threshold", type="character", default="1.0E-4", help="Enhancer fimo motif match threshold"),

  #summary plot parameters
  make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
  make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute")

)
opt <- parse_args(OptionParser(option_list=option.list))
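## Illustrative invocation sketch (hedged): <script>, <figdir>, and <outdir> are
## placeholders and the option values below are only examples, not required settings.
## Rscript <script>.R --sampleName 2kG.library --K.val 60 --density.thr 0.2 \
##     --figdir <figdir>/ --outdir <outdir>/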

## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.val <- 60

## ## mouse ENCODE adrenal data sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230117_snakemake_mouse_ENCODE_adrenal/figures/top2000VariableGenes/"
## opt$sampleName <- "mouse_ENCODE_adrenal"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230117_snakemake_mouse_ENCODE_adrenal/analysis/top2000VariableGenes"
## opt$K.val <- 60

## ## mouse ENCODE heart data sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/figures/top2000VariableGenes/"
## opt$sampleName <- "mouse_ENCODE_heart"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes"
## opt$K.val <- 15

## ## K562 gwps sdev
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/"
## opt$sampleName <- "WeissmanK562gwps"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/"
## opt$K.val <- 20


## mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold")) ## superseded by the more detailed theme below
mytheme <- theme_classic() + theme(axis.text = element_text(size = 7),
                                   axis.title = element_text(size = 8),
                                   plot.title = element_text(hjust = 0.5, face = "bold", size=10),
                                   axis.line = element_line(color = "black", size = 0.25),
                                   axis.ticks = element_line(color = "black", size = 0.25),
                                   legend.key.size = unit(10, units="pt"),
                                   legend.text = element_text(size=7),
                                   legend.title = element_text(size=8)
                                   )


SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
# STATIC.SAMPLE=c("Telo_no_IL1B_T200_1", "Telo_no_IL1B_T200_2", "Telo_plus_IL1B_T200_1", "Telo_plus_IL1B_T200_2", "no_IL1B", "plus_IL1B",  "pooled")
# DATADIR=opt$olddatadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
## TMDIR=opt$topic.model.result.dir
## SEP=opt$sep
# K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIR=opt$figdir
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
## FGSEADIR=paste0(OUTDIRSAMPLE,"/fgsea/")
## FGSEAFIG=paste0(FIGDIRSAMPLE,"/fgsea/")

message(FIGDIRTOP)

## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
# SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr


# create dir if not already
check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE,"/"), paste0(FIGDIR,SAMPLE,"/K",k,"/"), paste0(OUTDIR,SAMPLE,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)




######################################################################
## Process topic model results
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
} else {
    warning(paste0(cNMF.result.file, " does not exist"))
}


## End of data loading


## (patch up cNMF_analysis.R ENSGID to Gene conversion)
## gene mapping function
map.ENSGID.SYMBOL <- function(topFeatures) {
    gene.type <- ifelse(sum(as.numeric(colnames(topFeatures) %in% "Gene")) > 0,
                        ##(median.spectra.zscore.df) == sum(as.numeric(grepl("^ENS", median.spectra.zscore %>% rownames))),
                        "Gene",
                        "ENSGID")
    db <- ifelse(grepl("mouse", SAMPLE), "org.Mm.eg.db", "org.Hs.eg.db")
    library(db, character.only = TRUE) ## load the annotation package chosen above
    if(gene.type == "Gene") {
        if (nrow(topFeatures) == sum(as.numeric(grepl("^ENS", topFeatures$Gene)))) {
            ## put median spectra zscore into ENSGID format for PoPS
            mapped.genes <- mapIds(get(db), keys=topFeatures$Gene, keytype = "ENSEMBL", column = "SYMBOL")
            topFeatures <- topFeatures %>%
                mutate(ENSGID = Gene) %>%
                as.data.frame %>%
                mutate(Gene = mapped.genes)
            na.index <- which(is.na(topFeatures$Gene))
            if(length(na.index) > 0) topFeatures$Gene[na.index] <- topFeatures$ENSGID[na.index]
        }

    } else {
        mapped.genes <- mapIds(get(db), keys=topFeatures$ENSGID, keytype = "ENSEMBL", column = "SYMBOL")
        topFeatures <- topFeatures %>%
            as.data.frame %>%
            mutate(Gene = mapped.genes)

    }
    return(topFeatures)
}

## end of ENSGID to Gene conversion
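
## Illustrative usage of map.ENSGID.SYMBOL (a minimal sketch, commented out; the
## Ensembl IDs below are only examples, and the mapping assumes org.Hs.eg.db is
## installed and SAMPLE refers to a human dataset):
## toy.df <- data.frame(Gene = c("ENSG00000139618", "ENSG00000141510"))
## toy.mapped <- map.ENSGID.SYMBOL(toy.df)
## head(toy.mapped) ## expect an ENSGID column plus a Gene column of mapped symbols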



##########################################################################
## Plots


##########################################################################
## topic gene z-score list
pdf(file=paste0(FIGDIRTOP,"top50GeneInTopics.zscore.pdf"), width=2.5, height=4.5)
topFeatures.raw.weight <- theta.zscore %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene", variable.name="topic", value.name="scores") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:50) %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    toPlot <- data.frame(Gene=topFeatures.raw.weight %>% subset(topic == t) %>% pull(Gene),
                         Score=topFeatures.raw.weight %>% subset(topic == t) %>% pull(scores))
    p <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col() + theme_minimal()
    p <- p + coord_flip() + xlab("Top 50 Genes") + ylab("z-score (Specificity)") + ggtitle(paste(SAMPLE, ", Topic ", t, sep="")) + mytheme
    print(p)
}
dev.off()


##########################################################################
## Topic's top gene list, ranked by raw weight
pdf(file=paste0(FIGDIRTOP,"top50GeneInTopics.rawWeight.pdf"), width=2.5, height=4.5)
topFeatures <- theta %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:50) %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(Gene),
                         Score=topFeatures %>% subset(topic == t) %>% pull(scores))
    p <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col() + theme_minimal()
    p <- p + coord_flip() + xlab("Top 50 Genes") + ylab("Raw Score (gene's weight in topic)") + ggtitle(paste(SAMPLE, ", Topic ", t, sep="")) + mytheme
    print(p)
}
dev.off()


##########################################################################
## raw program TPM list with annotation
pdf(file=paste0(FIGDIRTOP,"top10GeneInTopics.rawWeight.pdf"), width=2.5, height=3)
topFeatures <- theta %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10) %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(Gene),
                         Score=topFeatures %>% subset(topic == t) %>% pull(scores)) # %>%
        ## merge(., gene.def.pathways, by="Gene", all.x=T)
    ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
    p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="#38b4f7") + theme_minimal()
    p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("Raw Score (gene's weight in topic)") +
        mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ", K = ", k, ", Topic ", t))
    print(p4)
}
dev.off()



##########################################################################
## raw program zscore list (top 10)  (can potentially include annotation)
pdf(file=paste0(FIGDIRTOP,"top10GeneInTopics.zscore.pdf"), width=2.5, height=3)
topFeatures <- theta.zscore %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10) %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(Gene),
                         Score=topFeatures %>% subset(topic == t) %>% pull(scores)) # %>%
        ## merge(., gene.def.pathways, by="Gene", all.x=T)
    ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
    p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="#38b4f7") + theme_minimal()
    p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("z-score") +
        mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, "\nK = ", k, ", Topic ", t))
    print(p4)
}
dev.off()



##########################################################################
## median spectra list (top 10)  (can potentially include annotation)
pdf(file=paste0(FIGDIRTOP,"top10GeneInTopics.median_spectra.zscore.pdf"), width=2.5, height=3)
topFeatures <- median.spectra.zscore.df %>% as.data.frame() %>% group_by(ProgramID) %>% arrange(desc(median.spectra.zscore)) %>% slice(1:10)  %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    ProgramID.here <- paste0("K", k, "_", t)
    toPlot <- data.frame(Gene=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(Gene),
                         Score=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(median.spectra.zscore)) # %>%
        ## merge(., gene.def.pathways, by="Gene", all.x=T)
    ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
    p7 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="#38b4f7") + theme_minimal()
    p7 <- p7 + coord_flip() + xlab("Top 10 Genes") + ylab("z-score") +
        mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ",\nK = ", k, ", Program ", t, "\nMedian Spectra"))
    print(p7)
}
dev.off()



##########################################################################
## median spectra raw list (top 10)  (can potentially include annotation)
pdf(file=paste0(FIGDIRTOP,"top10GeneInTopics.median_spectra.raw.pdf"), width=2.5, height=3)
topFeatures <- median.spectra %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene",value.name="Score", variable.name="ProgramID") %>% group_by(ProgramID) %>% arrange(desc(Score)) %>% slice(1:10) %>% mutate(ProgramID = paste0("K", k, "_", ProgramID)) %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    ProgramID.here <- paste0("K", k, "_", t)
    toPlot <- data.frame(Gene=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(Gene),
                         Score=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(Score)) # %>%
        ## merge(., gene.def.pathways, by="Gene", all.x=T)
    ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
    p8 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="#38b4f7") + theme_minimal()
    p8 <- p8 + coord_flip() + xlab("Top 10 Genes") + ylab("Weight in Program") +
        mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ",\nK = ", k, ", Program ", t, "\nMedian Spectra"))
    print(p8)
}
dev.off()



##########################################################################
## Topic's top gene list, ranked by median spectra zscore
pdf(file=paste0(FIGDIRTOP,"top50GeneInProgram.median_spectra.zscore.pdf"), width=2.5, height=4.5)
topFeatures <- median.spectra.zscore.df %>% as.data.frame() %>% group_by(ProgramID) %>% arrange(desc(median.spectra.zscore)) %>% slice(1:50) %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    ProgramID.here <- paste0("K", k, "_", t)
    toPlot <- data.frame(Gene=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(Gene),
                         Score=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(median.spectra.zscore)) # %>%
        ## merge(., gene.def.pathways, by="Gene", all.x=T)
    ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
    p9 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="gray30") + theme_minimal()
    p9 <- p9 + coord_flip() + xlab("Top 50 Genes") + ylab("z-score") +
        mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ",\nK = ", k, ", Program ", t, "\nMedian Spectra"))
    print(p9)
}
dev.off()


##########################################################################
## median spectra raw list (top 50)  (can potentially include annotation)
pdf(file=paste0(FIGDIRTOP,"top50GeneInProgram.median_spectra.raw.pdf"), width=2.5, height=4.5)
topFeatures <- median.spectra %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene",value.name="Score", variable.name="ProgramID") %>% group_by(ProgramID) %>% arrange(desc(Score)) %>% slice(1:50) %>% mutate(ProgramID = paste0("K", k, "_", ProgramID)) %>% map.ENSGID.SYMBOL
for ( t in 1:dim(theta)[2] ) {
    ProgramID.here <- paste0("K", k, "_", t)
    toPlot <- data.frame(Gene=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(Gene),
                         Score=topFeatures %>% subset(ProgramID == ProgramID.here) %>% pull(Score)) # %>%
        ## merge(., gene.def.pathways, by="Gene", all.x=T)
    ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
    p10 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="gray30") + theme_minimal()
    p10 <- p10 + coord_flip() + xlab("Top 50 Genes") + ylab("Weight in Program") +
        mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ",\nK = ", k, ", Program ", t, "\nMedian Spectra"))
    print(p10)
}
dev.off()


##########################################################################
## topic Pearson correlation heatmap
## remove NA from theta.zscore
tokeep <- (!is.na(theta.zscore)) %>% apply(1, sum) == k ## keep only genes with no NA z-score in any topic
d <- cor(theta.zscore[tokeep,], method="pearson")
m <- as.matrix(d)

## Function for plotting heatmap  # new version (adjusted font size)
plotHeatmap <- function(mtx, labCol, title, margins=c(12,6), ...) { #original
  heatmap.2(
    mtx %>% t(), 
    Rowv=T, 
    Colv=T,
    trace='none',
    key=T,
    col=palette,
    labCol=labCol,
    margins=margins, 
    cex.main=0.8, 
    cexCol=4.8/sqrt(ncol(mtx)), cexRow=4.8/sqrt(ncol(mtx)), #4.8/sqrt(nrow(mtx))
    ## cexCol=1/(ncol(mtx)^(1/3)), cexRow=1/(ncol(mtx)^(1/3)), #4.8/sqrt(nrow(mtx))
    main=title,
    ...
  )
}

pdf(file=paste0(FIGDIRTOP, "topic.Pearson.correlation.pdf"))
plotHeatmap(m, labCol=rownames(m), margins=c(3,3), title=paste0("cNMF, topic zscore clustering by Pearson Correlation"))
dev.off()
import numpy as np
import pandas as pd
import os, errno
import datetime
import uuid
import itertools
import yaml
import subprocess
import scipy.sparse as sp


from scipy.spatial.distance import squareform
from sklearn.decomposition import non_negative_factorization
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.utils import sparsefuncs


from fastcluster import linkage
from scipy.cluster.hierarchy import leaves_list

import matplotlib.pyplot as plt

import scanpy as sc

def save_df_to_npz(obj, filename):
    np.savez_compressed(filename, data=obj.values, index=obj.index.values, columns=obj.columns.values)

def save_df_to_text(obj, filename):
    obj.to_csv(filename, sep='\t')

def load_df_from_npz(filename):
    with np.load(filename, allow_pickle=True) as f:
        obj = pd.DataFrame(**f)
    return obj

def check_dir_exists(path):
    """
    Checks if directory already exists or not and creates it if it doesn't
    """
    try:
        os.makedirs(path)
    except OSError as exception:
        if exception.errno != errno.EEXIST:
            raise

def worker_filter(iterable, worker_index, total_workers):
    return (p for i,p in enumerate(iterable) if (i-worker_index)%total_workers==0)

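# Illustrative sketch (added for documentation, not used by the pipeline):
# worker_filter() deals jobs out round-robin, so each worker takes every
# total_workers-th task, offset by its worker_index.
def _example_worker_filter_usage():
    # With 3 workers, worker 0 gets tasks 0, 3, 6, 9 and worker 1 gets 1, 4, 7.
    assert list(worker_filter(range(10), worker_index=0, total_workers=3)) == [0, 3, 6, 9]
    assert list(worker_filter(range(10), worker_index=1, total_workers=3)) == [1, 4, 7]
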
def fast_euclidean(mat):
    D = mat.dot(mat.T)
    squared_norms = np.diag(D).copy()
    D *= -2.0
    D += squared_norms.reshape((-1,1))
    D += squared_norms.reshape((1,-1))
    D = np.sqrt(D)
    D[D < 0] = 0
    return squareform(D, checks=False)

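# Illustrative sketch (added for documentation): fast_euclidean() computes the
# condensed pairwise Euclidean distance vector via the Gram-matrix identity
# ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b, so it should agree with scipy's pdist
# up to floating-point error.
def _example_fast_euclidean_check():
    from scipy.spatial.distance import pdist
    rng = np.random.RandomState(0)
    X = rng.rand(5, 3)  # toy matrix of 5 spectra x 3 genes
    assert np.allclose(fast_euclidean(X), pdist(X))
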
def fast_ols_all_cols(X, Y):
    pinv = np.linalg.pinv(X)
    beta = np.dot(pinv, Y)
    return(beta)

def fast_ols_all_cols_df(X,Y):
    beta = fast_ols_all_cols(X, Y)
    beta = pd.DataFrame(beta, index=X.columns, columns=Y.columns)
    return(beta)

def var_sparse_matrix(X):
    mean = np.array(X.mean(axis=0)).reshape(-1)
    Xcopy = X.copy()
    Xcopy.data **= 2
    var = np.array(Xcopy.mean(axis=0)).reshape(-1) - (mean**2)
    return(var)


def get_highvar_genes_sparse(expression, expected_fano_threshold=None,
                       minimal_mean=0.5, numgenes=None):
    # Find high variance genes within those cells
    gene_mean = np.array(expression.mean(axis=0)).astype(float).reshape(-1)
    E2 = expression.copy(); E2.data **= 2; gene2_mean = np.array(E2.mean(axis=0)).reshape(-1)
    gene_var = pd.Series(gene2_mean - (gene_mean**2))
    del(E2)
    gene_mean = pd.Series(gene_mean)
    gene_fano = gene_var / gene_mean

    # Find parameters for expected fano line
    top_genes = gene_mean.sort_values(ascending=False)[:20].index
    A = (np.sqrt(gene_var)/gene_mean)[top_genes].min()

    w_mean_low, w_mean_high = gene_mean.quantile([0.10, 0.90])
    w_fano_low, w_fano_high = gene_fano.quantile([0.10, 0.90])
    winsor_box = ((gene_fano > w_fano_low) &
                    (gene_fano < w_fano_high) &
                    (gene_mean > w_mean_low) &
                    (gene_mean < w_mean_high))
    fano_median = gene_fano[winsor_box].median()
    B = np.sqrt(fano_median)

    gene_expected_fano = (A**2)*gene_mean + (B**2)
    fano_ratio = (gene_fano/gene_expected_fano)

    # Identify high var genes
    if numgenes is not None:
        highvargenes = fano_ratio.sort_values(ascending=False).index[:numgenes]
        high_var_genes_ind = fano_ratio.index.isin(highvargenes)
        T=None


    else:
        # Note: use the local gene_fano / gene_mean computed above; the names
        # gene_counts_fano / gene_counts_mean only exist in the dense
        # get_highvar_genes() version.
        if not expected_fano_threshold:
            T = (1. + gene_fano[winsor_box].std())
        else:
            T = expected_fano_threshold

        high_var_genes_ind = (fano_ratio > T) & (gene_mean > minimal_mean)

    gene_counts_stats = pd.DataFrame({
        'mean': gene_mean,
        'var': gene_var,
        'fano': gene_fano,
        'expected_fano': gene_expected_fano,
        'high_var': high_var_genes_ind,
        'fano_ratio': fano_ratio
        })
    gene_fano_parameters = {
            'A': A, 'B': B, 'T':T, 'minimal_mean': minimal_mean,
        }
    return(gene_counts_stats, gene_fano_parameters)



def get_highvar_genes(input_counts, expected_fano_threshold=None,
                       minimal_mean=0.5, numgenes=None):
    # Find high variance genes within those cells
    gene_counts_mean = pd.Series(input_counts.mean(axis=0).astype(float))
    gene_counts_var = pd.Series(input_counts.var(ddof=0, axis=0).astype(float))
    gene_counts_fano = pd.Series(gene_counts_var/gene_counts_mean)

    # Find parameters for expected fano line
    top_genes = gene_counts_mean.sort_values(ascending=False)[:20].index
    A = (np.sqrt(gene_counts_var)/gene_counts_mean)[top_genes].min()

    w_mean_low, w_mean_high = gene_counts_mean.quantile([0.10, 0.90])
    w_fano_low, w_fano_high = gene_counts_fano.quantile([0.10, 0.90])
    winsor_box = ((gene_counts_fano > w_fano_low) &
                    (gene_counts_fano < w_fano_high) &
                    (gene_counts_mean > w_mean_low) &
                    (gene_counts_mean < w_mean_high))
    fano_median = gene_counts_fano[winsor_box].median()
    B = np.sqrt(fano_median)

    gene_expected_fano = (A**2)*gene_counts_mean + (B**2)

    fano_ratio = (gene_counts_fano/gene_expected_fano)

    # Identify high var genes
    if numgenes is not None:
        highvargenes = fano_ratio.sort_values(ascending=False).index[:numgenes]
        high_var_genes_ind = fano_ratio.index.isin(highvargenes)
        T=None


    else:
        if not expected_fano_threshold:
            T = (1. + gene_counts_fano[winsor_box].std())
        else:
            T = expected_fano_threshold

        high_var_genes_ind = (fano_ratio > T) & (gene_counts_mean > minimal_mean)

    gene_counts_stats = pd.DataFrame({
        'mean': gene_counts_mean,
        'var': gene_counts_var,
        'fano': gene_counts_fano,
        'expected_fano': gene_expected_fano,
        'high_var': high_var_genes_ind,
        'fano_ratio': fano_ratio
        })
    gene_fano_parameters = {
            'A': A, 'B': B, 'T':T, 'minimal_mean': minimal_mean,
        }
    return(gene_counts_stats, gene_fano_parameters)


def compute_tpm(input_counts):
    """
    Default TPM normalization
    """
    tpm = input_counts.copy()
    sc.pp.normalize_per_cell(tpm, counts_per_cell_after=1e6)
    return(tpm)

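# Illustrative sketch (added for documentation): compute_tpm() rescales each
# cell so its counts sum to 1e6 (a TPM-like per-cell normalization). The toy
# AnnData object below is only an example.
def _example_compute_tpm_usage():
    counts = sc.AnnData(np.array([[1., 2., 3.], [4., 5., 6.]]))
    tpm = compute_tpm(counts)
    # every cell (row) should now sum to one million
    assert np.allclose(np.asarray(tpm.X).sum(axis=1), 1e6)
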

class cNMF():


    def __init__(self, output_dir=".", name=None):
        """
        Parameters
        ----------

        output_dir : path, optional (default=".")
            Output directory for analysis files.

        name : string, optional (default=None)
            A name for this analysis. Will be prefixed to all output files.
            If set to None, will be automatically generated from date (and random string).
        """

        self.output_dir = output_dir
        if name is None:
            now = datetime.datetime.now()
            rand_hash =  uuid.uuid4().hex[:6]
            name = '%s_%s' % (now.strftime("%Y_%m_%d"), rand_hash)
        self.name = name
        self.paths = None


    def _initialize_dirs(self):
        if self.paths is None:
            # Check that output directory exists, create it if needed.
            check_dir_exists(self.output_dir)
            check_dir_exists(os.path.join(self.output_dir, self.name))
            check_dir_exists(os.path.join(self.output_dir, self.name, 'cnmf_tmp'))

            self.paths = {
                'normalized_counts' : os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.norm_counts.h5ad'),
                'nmf_replicate_parameters' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.nmf_params.df.npz'),
                'nmf_run_parameters' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.nmf_idvrun_params.yaml'),
                'nmf_genes_list' :  os.path.join(self.output_dir, self.name, self.name+'.overdispersed_genes.txt'),

                'tpm' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.tpm.h5ad'),
                'tpm_stats' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.tpm_stats.df.npz'),

                'iter_spectra' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.spectra.k_%d.iter_%d.df.npz'),
                'iter_usages' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.usages.k_%d.iter_%d.df.npz'),
                'merged_spectra': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.spectra.k_%d.merged.df.npz'),

                'local_density_cache': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.local_density_cache.k_%d.merged.df.npz'),
                'consensus_spectra': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.spectra.k_%d.dt_%s.consensus.df.npz'),
                'consensus_spectra__txt': os.path.join(self.output_dir, self.name, self.name+'.spectra.k_%d.dt_%s.consensus.txt'),
                'consensus_usages': os.path.join(self.output_dir, self.name, 'cnmf_tmp',self.name+'.usages.k_%d.dt_%s.consensus.df.npz'),
                'consensus_usages__txt': os.path.join(self.output_dir, self.name, self.name+'.usages.k_%d.dt_%s.consensus.txt'),

                'consensus_stats': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.stats.k_%d.dt_%s.df.npz'),

                'clustering_plot': os.path.join(self.output_dir, self.name, self.name+'.clustering.k_%d.dt_%s.png'),
                'gene_spectra_score': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.gene_spectra_score.k_%d.dt_%s.df.npz'),
                'gene_spectra_score__txt': os.path.join(self.output_dir, self.name, self.name+'.gene_spectra_score.k_%d.dt_%s.txt'),
                'gene_spectra_tpm': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.gene_spectra_tpm.k_%d.dt_%s.df.npz'),
                'gene_spectra_tpm__txt': os.path.join(self.output_dir, self.name, self.name+'.gene_spectra_tpm.k_%d.dt_%s.txt'),

                'k_selection_plot' :  os.path.join(self.output_dir, self.name, self.name+'.k_selection.png'),
                'k_selection_stats' :  os.path.join(self.output_dir, self.name, self.name+'.k_selection_stats.df.npz'),
            }


    def get_norm_counts(self, counts, tpm,
                         high_variance_genes_filter = None,
                         num_highvar_genes = None
                         ):
        """
        Parameters
        ----------

        counts : anndata.AnnData
            Scanpy AnnData object (cells x genes) containing raw counts. Filtered such that
            no genes or cells with 0 counts

        tpm : anndata.AnnData
            Scanpy AnnData object (cells x genes) containing tpm normalized data matching
            counts

        high_variance_genes_filter : np.array, optional (default=None)
            A pre-specified list of genes considered to be high-variance.
            Only these genes will be used during factorization of the counts matrix.
            Must match the .var index of counts and tpm.
            If set to None, high-variance genes will be automatically computed, using the
            parameters below.

        num_highvar_genes : int, optional (default=None)
            Instead of providing an array of high-variance genes, identify this many most overdispersed genes
            for filtering

        Returns
        -------

        normcounts : anndata.AnnData, shape (cells, num_highvar_genes)
            A counts matrix containing only the high variance genes and with columns (genes)normalized to unit
            variance

        """

        if high_variance_genes_filter is None:
            ## Get list of high-var genes if one wasn't provided
            if sp.issparse(tpm.X):
                (gene_counts_stats, gene_fano_params) = get_highvar_genes_sparse(tpm.X, numgenes=num_highvar_genes)  
            else:
                (gene_counts_stats, gene_fano_params) = get_highvar_genes(np.array(tpm.X), numgenes=num_highvar_genes)

            high_variance_genes_filter = list(tpm.var.index[gene_counts_stats.high_var.values])

        ## Subset out high-variance genes
        norm_counts = counts[:, high_variance_genes_filter]

        ## Scale genes to unit variance
        if sp.issparse(tpm.X):
            sc.pp.scale(norm_counts, zero_center=False)
            if np.isnan(norm_counts.X.data).sum() > 0:
                print('Warning NaNs in normalized counts matrix')                       
        else:
            norm_counts.X /= norm_counts.X.std(axis=0, ddof=1)
            if np.isnan(norm_counts.X).sum().sum() > 0:
                print('Warning NaNs in normalized counts matrix')                    

        ## Save a \n-delimited list of the high-variance genes used for factorization
        open(self.paths['nmf_genes_list'], 'w').write('\n'.join(high_variance_genes_filter))

        ## Check for any cells that have 0 counts of the overdispersed genes
        zerocells = norm_counts.X.sum(axis=1)==0
        if zerocells.sum()>0:
            examples = norm_counts.obs.index[zerocells]
            print('Warning: %d cells have zero counts of overdispersed genes. E.g. %s' % (zerocells.sum(), examples[0]))
            print('Consensus step may not run when this is the case')

        return(norm_counts)

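    # Illustrative usage sketch (hedged; the object and variable names below are
    # assumptions, not part of the pipeline): given AnnData objects `counts` and
    # `tpm` with matching cells and genes, one might select the top 2000
    # overdispersed genes and variance-scale them like so:
    #
    #   cnmf_obj = cNMF(output_dir='./cnmf_out', name='example_run')
    #   cnmf_obj._initialize_dirs()
    #   norm_counts = cnmf_obj.get_norm_counts(counts, tpm, num_highvar_genes=2000)
    #   cnmf_obj.save_norm_counts(norm_counts)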

    def save_norm_counts(self, norm_counts):
        self._initialize_dirs()
        sc.write(self.paths['normalized_counts'], norm_counts)


    def get_nmf_iter_params(self, ks, n_iter = 100,
                               random_state_seed = None,
                               beta_loss = 'kullback-leibler'):
        """
        Create a DataFrame with parameters for NMF iterations.


        Parameters
        ----------
        ks : integer, or list-like.
            Number of topics (components) for factorization.
            Several values can be specified at the same time, which will be run independently.

        n_iter : integer, optional (default=100)
            Number of iterations for factorization. If several ``k`` are specified, this many
            iterations will be run for each value of ``k``.

        random_state_seed : int or None, optional (default=None)
            Seed for sklearn random state.

        """

        if type(ks) is int:
            ks = [ks]

        # Remove any repeated k values, and order.
        k_list = sorted(set(list(ks)))

        n_runs = len(ks)* n_iter

        np.random.seed(seed=random_state_seed)
        nmf_seeds = np.random.randint(low=1, high=(2**32)-1, size=n_runs)

        replicate_params = []
        for i, (k, r) in enumerate(itertools.product(k_list, range(n_iter))):
            replicate_params.append([k, r, nmf_seeds[i]])
        replicate_params = pd.DataFrame(replicate_params, columns = ['n_components', 'iter', 'nmf_seed'])

        _nmf_kwargs = dict(
                        alpha=0.0,
                        l1_ratio=0.0,
                        beta_loss=beta_loss,
                        solver='mu',
                        tol=1e-4,
                        max_iter=1000,
                        regularization=None,
                        init='random'
                        )

        ## Coordinate descent is faster than multiplicative update but only works for frobenius
        if beta_loss == 'frobenius':
            _nmf_kwargs['solver'] = 'cd'

        return(replicate_params, _nmf_kwargs)

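    # Illustrative usage sketch (hedged; names are assumptions): prepare
    # factorization replicates for a few K values with a fixed seed, then save
    # the parameters so that run_nmf() workers can load them:
    #
    #   replicate_params, run_kwargs = cnmf_obj.get_nmf_iter_params(
    #       ks=[7, 8, 9], n_iter=100, random_state_seed=14)
    #   cnmf_obj.save_nmf_iter_params(replicate_params, run_kwargs)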

    def save_nmf_iter_params(self, replicate_params, run_params):
        self._initialize_dirs()
        save_df_to_npz(replicate_params, self.paths['nmf_replicate_parameters'])
        with open(self.paths['nmf_run_parameters'], 'w') as F:
            yaml.dump(run_params, F)


    def _nmf(self, X, nmf_kwargs):
        """
        Parameters
        ----------
        X : pandas.DataFrame,
            Normalized counts dataFrame to be factorized.

        nmf_kwargs : dict,
            Arguments to be passed to ``non_negative_factorization``

        """
        (usages, spectra, niter) = non_negative_factorization(X, **nmf_kwargs)

        return(spectra, usages)


    def run_nmf(self,
                worker_i=1, total_workers=1,
                ):
        """
        Iteratively run NMF with prespecified parameters.

        Use the `worker_i` and `total_workers` parameters for parallelization.

        Generic kwargs for NMF are loaded from self.paths['nmf_run_parameters'], defaults below::

            ``non_negative_factorization`` default arguments:
                alpha=0.0
                l1_ratio=0.0
                beta_loss='kullback-leibler'
                solver='mu'
                tol=1e-4,
                max_iter=1000
                regularization=None
                init='random'
                random_state, n_components are both set by the prespecified self.paths['nmf_replicate_parameters'].


        Parameters
        ----------
        norm_counts : pandas.DataFrame,
            Normalized counts dataFrame to be factorized.
            (Output of ``normalize_counts``)

        run_params : pandas.DataFrame,
            Parameters for NMF iterations.
            (Output of ``prepare_nmf_iter_params``)

        """
        self._initialize_dirs()
        run_params = load_df_from_npz(self.paths['nmf_replicate_parameters'])
        norm_counts = sc.read(self.paths['normalized_counts'])
        _nmf_kwargs = yaml.load(open(self.paths['nmf_run_parameters']), Loader=yaml.FullLoader)

        jobs_for_this_worker = worker_filter(range(len(run_params)), worker_i, total_workers)
        for idx in jobs_for_this_worker:

            p = run_params.iloc[idx, :]
            print('[Worker %d]. Starting task %d.' % (worker_i, idx))
            _nmf_kwargs['random_state'] = p['nmf_seed']
            _nmf_kwargs['n_components'] = p['n_components']

            (spectra, usages) = self._nmf(norm_counts.X, _nmf_kwargs)
            spectra = pd.DataFrame(spectra,
                                   index=np.arange(1, _nmf_kwargs['n_components']+1),
                                   columns=norm_counts.var.index)
            save_df_to_npz(spectra, self.paths['iter_spectra'] % (p['n_components'], p['iter']))

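    # Illustrative parallelization sketch (hedged; the driver loop below is an
    # assumption about how one might call run_nmf(), e.g. one call per job-array
    # task rather than a Python loop):
    #
    #   total_workers = 4
    #   for worker_i in range(total_workers):
    #       cnmf_obj.run_nmf(worker_i=worker_i, total_workers=total_workers)
    #
    # Each worker factorizes every total_workers-th (k, iter) replicate, so the
    # replicates are covered exactly once across workers.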

    def combine_nmf(self, k, remove_individual_iterations=False):
        run_params = load_df_from_npz(self.paths['nmf_replicate_parameters'])
        print('Combining factorizations for k=%d.'%k)

        self._initialize_dirs()

        combined_spectra = None
        n_iter = sum(run_params.n_components==k)

        run_params_subset = run_params[run_params.n_components==k].sort_values('iter')
        spectra_labels = []

        for i,p in run_params_subset.iterrows():

            spectra = load_df_from_npz(self.paths['iter_spectra'] % (p['n_components'], p['iter']))
            if combined_spectra is None:
                combined_spectra = np.zeros((n_iter, k, spectra.shape[1]))
            combined_spectra[p['iter'], :, :] = spectra.values

            for t in range(k):
                spectra_labels.append('iter%d_topic%d'%(p['iter'], t+1))

        combined_spectra = combined_spectra.reshape(-1, combined_spectra.shape[-1])
        combined_spectra = pd.DataFrame(combined_spectra, columns=spectra.columns, index=spectra_labels)

        save_df_to_npz(combined_spectra, self.paths['merged_spectra']%k)
        return combined_spectra


    def consensus(self, k, density_threshold_str='0.5', local_neighborhood_size = 0.30,show_clustering = False,
                  skip_density_and_return_after_stats = False, close_clustergram_fig=True):
        merged_spectra = load_df_from_npz(self.paths['merged_spectra']%k)
        norm_counts = sc.read(self.paths['normalized_counts'])
        ##here210830
        if norm_counts.X.dtype != np.float64:
            norm_counts.X = norm_counts.X.astype(np.float64)


        if skip_density_and_return_after_stats:
            density_threshold_str = '2'
        density_threshold_repl = density_threshold_str.replace('.', '_')
        density_threshold = float(density_threshold_str)
        n_neighbors = int(local_neighborhood_size * merged_spectra.shape[0]/k)

        # Rescale each topic (spectra row) to unit L2 length.
        l2_spectra = (merged_spectra.T/np.sqrt((merged_spectra**2).sum(axis=1))).T


        if not skip_density_and_return_after_stats:
            # Compute the local density matrix (if not previously cached)
            topics_dist = None
            if os.path.isfile(self.paths['local_density_cache'] % k):
                local_density = load_df_from_npz(self.paths['local_density_cache'] % k)
            else:
                #   first find the full distance matrix
                topics_dist = squareform(fast_euclidean(l2_spectra.values))
                #   partition based on the first n neighbors
                partitioning_order  = np.argpartition(topics_dist, n_neighbors+1)[:, :n_neighbors+1]
                #   find the mean over those n_neighbors (excluding self, which has a distance of 0)
                distance_to_nearest_neighbors = topics_dist[np.arange(topics_dist.shape[0])[:, None], partitioning_order]
                local_density = pd.DataFrame(distance_to_nearest_neighbors.sum(1)/(n_neighbors),
                                             columns=['local_density'],
                                             index=l2_spectra.index)
                save_df_to_npz(local_density, self.paths['local_density_cache'] % k)
                del(partitioning_order)
                del(distance_to_nearest_neighbors)


            density_filter = local_density.iloc[:, 0] < density_threshold
            l2_spectra = l2_spectra.loc[density_filter, :]

        kmeans_model = KMeans(n_clusters=k, n_init=10, random_state=1)
        kmeans_model.fit(l2_spectra)
        kmeans_cluster_labels = pd.Series(kmeans_model.labels_+1, index=l2_spectra.index)

        # Find median usage for each gene across cluster
        median_spectra = l2_spectra.groupby(kmeans_cluster_labels).median()

        # Normalize median spectra to probability distributions.
        median_spectra = (median_spectra.T/median_spectra.sum(1)).T

        # Compute the silhouette score
        stability = silhouette_score(l2_spectra.values, kmeans_cluster_labels, metric='euclidean')

        # Obtain the reconstructed count matrix by re-fitting the usage matrix and computing the dot product: usage.dot(spectra)
        refit_nmf_kwargs = yaml.load(open(self.paths['nmf_run_parameters']), Loader=yaml.FullLoader)
        refit_nmf_kwargs.update(dict(
                                    n_components = k,
                                    H = median_spectra.values,
                                    update_H = False
                                    ))

        # change refit_nmf_kwargs['H'] data type to match with norm_counts.X's
        refit_nmf_kwargs['H'] = refit_nmf_kwargs['H'].astype(norm_counts.X.dtype)

        ##here210830
        _, rf_usages = self._nmf(norm_counts.X,
                                          nmf_kwargs=refit_nmf_kwargs)
        rf_usages = pd.DataFrame(rf_usages, index=norm_counts.obs.index, columns=median_spectra.index)
        rf_pred_norm_counts = rf_usages.dot(median_spectra)

        # Compute the prediction error as the squared Frobenius norm of the residual
        if sp.issparse(norm_counts.X):
            prediction_error = ((norm_counts.X.todense() - rf_pred_norm_counts)**2).sum().sum()
        else:
            prediction_error = ((norm_counts.X - rf_pred_norm_counts)**2).sum().sum()

        consensus_stats = pd.DataFrame([k, density_threshold, stability, prediction_error],
                    index = ['k', 'local_density_threshold', 'stability', 'prediction_error'],
                    columns = ['stats'])

        if skip_density_and_return_after_stats:
            return consensus_stats

        save_df_to_npz(median_spectra, self.paths['consensus_spectra']%(k, density_threshold_repl))
        save_df_to_npz(rf_usages, self.paths['consensus_usages']%(k, density_threshold_repl))
        save_df_to_npz(consensus_stats, self.paths['consensus_stats']%(k, density_threshold_repl))
        save_df_to_text(median_spectra, self.paths['consensus_spectra__txt']%(k, density_threshold_repl))
        save_df_to_text(rf_usages, self.paths['consensus_usages__txt']%(k, density_threshold_repl))

        # Compute gene-scores for each GEP by regressing usage on Z-scores of TPM
        tpm = sc.read(self.paths['tpm'])
        tpm_stats = load_df_from_npz(self.paths['tpm_stats'])

        if sp.issparse(tpm.X):
            norm_tpm = (np.array(tpm.X.todense()) - tpm_stats['__mean'].values) / tpm_stats['__std'].values
        else:
            norm_tpm = (tpm.X - tpm_stats['__mean'].values) / tpm_stats['__std'].values

        # if norm_tpm.dtype != np.float64:
        #     norm_tpm = norm_tpm.astype(np.float64)

        usage_coef = fast_ols_all_cols(rf_usages.values, norm_tpm)
        usage_coef = pd.DataFrame(usage_coef, index=rf_usages.columns, columns=tpm.var.index)

        save_df_to_npz(usage_coef, self.paths['gene_spectra_score']%(k, density_threshold_repl))
        save_df_to_text(usage_coef, self.paths['gene_spectra_score__txt']%(k, density_threshold_repl))

        # Convert spectra to TPM units, and obtain results for all genes by running last step of NMF
        # with usages fixed and TPM as the input matrix
        norm_usages = rf_usages.div(rf_usages.sum(axis=1), axis=0)
        refit_nmf_kwargs.update(dict(
                                    H = norm_usages.T.values,
                                ))

        # Needed because _nmf would otherwise crash on inconsistent dtypes
        if tpm.X.dtype != np.float64:
            tpm.X = tpm.X.astype(np.float64)

        _, spectra_tpm = self._nmf(tpm.X.T, nmf_kwargs=refit_nmf_kwargs)
        spectra_tpm = pd.DataFrame(spectra_tpm.T, index=rf_usages.columns, columns=tpm.var.index)
        save_df_to_npz(spectra_tpm, self.paths['gene_spectra_tpm']%(k, density_threshold_repl))
        save_df_to_text(spectra_tpm, self.paths['gene_spectra_tpm__txt']%(k, density_threshold_repl))

        if show_clustering:
            if topics_dist is None:
                topics_dist = squareform(fast_euclidean(l2_spectra.values))
                # (l2_spectra was already filtered using the density filter)
            else:
                # (but the previously computed topics_dist was not!)
                topics_dist = topics_dist[density_filter.values, :][:, density_filter.values]


            spectra_order = []
            for cl in sorted(set(kmeans_cluster_labels)):

                cl_filter = kmeans_cluster_labels==cl

                if cl_filter.sum() > 1:
                    cl_dist = squareform(topics_dist[cl_filter, :][:, cl_filter])
                    cl_dist[cl_dist < 0] = 0 #Rarely get floating point arithmetic issues
                    cl_link = linkage(cl_dist, 'average')
                    cl_leaves_order = leaves_list(cl_link)

                    spectra_order += list(np.where(cl_filter)[0][cl_leaves_order])
                else:
                    ## Corner case where a component only has one element
                    spectra_order += list(np.where(cl_filter)[0])


            from matplotlib import gridspec
            import matplotlib.pyplot as plt

            width_ratios = [0.5, 9, 0.5, 4, 1]
            height_ratios = [0.5, 9]
            fig = plt.figure(figsize=(sum(width_ratios), sum(height_ratios)))
            gs = gridspec.GridSpec(len(height_ratios), len(width_ratios), fig,
                                    0.01, 0.01, 0.98, 0.98,
                                   height_ratios=height_ratios,
                                   width_ratios=width_ratios,
                                   wspace=0, hspace=0)

            dist_ax = fig.add_subplot(gs[1,1], xscale='linear', yscale='linear',
                                      xticks=[], yticks=[],xlabel='', ylabel='',
                                      frameon=True)

            D = topics_dist[spectra_order, :][:, spectra_order]
            dist_im = dist_ax.imshow(D, interpolation='none', cmap='viridis', aspect='auto',
                                rasterized=True)

            left_ax = fig.add_subplot(gs[1,0], xscale='linear', yscale='linear', xticks=[], yticks=[],
                xlabel='', ylabel='', frameon=True)
            left_ax.imshow(kmeans_cluster_labels.values[spectra_order].reshape(-1, 1),
                            interpolation='none', cmap='Spectral', aspect='auto',
                            rasterized=True)


            top_ax = fig.add_subplot(gs[0,1], xscale='linear', yscale='linear', xticks=[], yticks=[],
                xlabel='', ylabel='', frameon=True)
            top_ax.imshow(kmeans_cluster_labels.values[spectra_order].reshape(1, -1),
                              interpolation='none', cmap='Spectral', aspect='auto',
                                rasterized=True)


            hist_gs = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[1, 3],
                                   wspace=0, hspace=0)

            hist_ax = fig.add_subplot(hist_gs[0,0], xscale='linear', yscale='linear',
                xlabel='', ylabel='', frameon=True, title='Local density histogram')
            hist_ax.hist(local_density.values, bins=np.linspace(0, 1, 50))
            hist_ax.yaxis.tick_right()

            xlim = hist_ax.get_xlim()
            ylim = hist_ax.get_ylim()
            if density_threshold < xlim[1]:
                hist_ax.axvline(density_threshold, linestyle='--', color='k')
                hist_ax.text(density_threshold  + 0.02, ylim[1] * 0.95, 'filtering\nthreshold\n\n', va='top')
            hist_ax.set_xlim(xlim)
            hist_ax.set_xlabel('Mean distance to k nearest neighbors\n\n%d/%d (%.0f%%) spectra above threshold\nwere removed prior to clustering'%(sum(~density_filter), len(density_filter), 100*(~density_filter).mean()))

            ## Add colorbar
            cbar_gs = gridspec.GridSpecFromSubplotSpec(8, 1, subplot_spec=hist_gs[1, 0],
                                   wspace=0, hspace=0)
            cbar_ax = fig.add_subplot(cbar_gs[4,0], xscale='linear', yscale='linear',
                xlabel='', ylabel='', frameon=True, title='Euclidean Distance')
            vmin = D.min().min()
            vmax = D.max().max()
            fig.colorbar(dist_im, cax=cbar_ax,
            ticks=np.linspace(vmin, vmax, 3),
            orientation='horizontal')


            #hist_ax.hist(local_density.values, bins=np.linspace(0, 1, 50))
            #hist_ax.yaxis.tick_right()            

            fig.savefig(self.paths['clustering_plot']%(k, density_threshold_repl), dpi=250)
            if close_clustergram_fig:
                plt.close(fig)


    def k_selection_plot(self, close_fig=True):
        '''
        Borrowed from Alexandrov et al. 2013, "Deciphering Signatures of
        Mutational Processes Operative in Human Cancer", Cell Reports.
        '''
        run_params = load_df_from_npz(self.paths['nmf_replicate_parameters'])
        stats = []
        for k in sorted(set(run_params.n_components)):

            stats.append(self.consensus(k, skip_density_and_return_after_stats=True).stats)

        stats = pd.DataFrame(stats)
        stats.reset_index(drop = True, inplace = True)

        save_df_to_npz(stats, self.paths['k_selection_stats'])

        fig = plt.figure(figsize=(6, 4))
        ax1 = fig.add_subplot(111)
        ax2 = ax1.twinx()


        ax1.plot(stats.k, stats.stability, 'o-', color='b')
        ax1.set_ylabel('Stability', color='b', fontsize=15)
        for tl in ax1.get_yticklabels():
            tl.set_color('b')
        #ax1.set_xlabel('K', fontsize=15)

        ax2.plot(stats.k, stats.prediction_error, 'o-', color='r')
        ax2.set_ylabel('Error', color='r', fontsize=15)
        for tl in ax2.get_yticklabels():
            tl.set_color('r')

        ax1.set_xlabel('Number of Components', fontsize=15)
        ax1.grid('on')
        plt.tight_layout()
        fig.savefig(self.paths['k_selection_plot'], dpi=250)
        if close_fig:
            plt.close(fig)



if __name__=="__main__":
    """
    Example commands for now:

        output_dir="/Users/averes/Projects/Melton/Notebooks/2018/07-2018/cnmf_test/"


        python cnmf.py prepare --output-dir $output_dir \
           --name test --counts /Users/averes/Projects/Melton/Notebooks/2018/07-2018/cnmf_test/test_data.df.npz \
           -k 6 7 8 9 --n-iter 5

        python cnmf.py factorize  --name test --output-dir $output_dir

        This can be parallelized as follows:

        python cnmf.py factorize  --name test --output-dir $output_dir --total-workers 2 --worker-index WORKER_INDEX (where worker_index starts with 0)

        python cnmf.py combine  --name test --output-dir $output_dir

        python cnmf.py consensus  --name test --output-dir $output_dir
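
        The k selection plot can be generated analogously (assumed usage,
        following the same pattern as the commands above):

        python cnmf.py k_selection_plot  --name test --output-dir $output_dir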

    """

    import sys, argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('command', type=str, choices=['prepare', 'factorize', 'combine', 'consensus', 'k_selection_plot'])
    parser.add_argument('--name', type=str, help='[all] Name for analysis. All output will be placed in [output-dir]/[name]/...', nargs='?', default='cNMF')
    parser.add_argument('--output-dir', type=str, help='[all] Output directory. All output will be placed in [output-dir]/[name]/...', nargs='?', default='.')

    parser.add_argument('-c', '--counts', type=str, help='[prepare] Input (cell x gene) counts matrix as df.npz or tab delimited text file')
    parser.add_argument('-k', '--components', type=int, help='[prepare] Number of components (k) for matrix factorization. Several can be specified with "-k 8 9 10"', nargs='+')
    parser.add_argument('-n', '--n-iter', type=int, help='[prepare] Number of factorization replicates', default=100)
    parser.add_argument('--total-workers', type=int, help='[all] Total number of workers to distribute jobs to', default=1)
    parser.add_argument('--seed', type=int, help='[prepare] Seed for pseudorandom number generation', default=None)
    parser.add_argument('--genes-file', type=str, help='[prepare] File containing a list of genes to include, one gene per line. Must match column labels of counts matrix.', default=None)
    parser.add_argument('--numgenes', type=int, help='[prepare] Number of high variance genes to use for matrix factorization.', default=2000)
    parser.add_argument('--tpm', type=str, help='[prepare] Pre-computed (cell x gene) TPM values as df.npz or tab separated txt file. If not provided TPM will be calculated automatically', default=None)
    parser.add_argument('--beta-loss', type=str, choices=['frobenius', 'kullback-leibler', 'itakura-saito'], help='[prepare] Loss function for NMF.', default='frobenius')
    parser.add_argument('--densify', dest='densify', help='[prepare] Treat the input data as non-sparse', action='store_true', default=False)


    parser.add_argument('--worker-index', type=int, help='[factorize] Index of current worker (the first worker should have index 0)', default=0)

    parser.add_argument('--local-density-threshold', type=str, help='[consensus] Threshold for the local density filtering. This string must convert to a float >0 and <=2', default='0.5')
    parser.add_argument('--local-neighborhood-size', type=float, help='[consensus] Fraction of the number of replicates to use as nearest neighbors for local density filtering', default=0.30)
    parser.add_argument('--show-clustering', dest='show_clustering', help='[consensus] Produce a clustergram figure summarizing the spectra clustering', action='store_true')

    args = parser.parse_args()

    cnmf_obj = cNMF(output_dir=args.output_dir, name=args.name)
    cnmf_obj._initialize_dirs()

    if args.command == 'prepare':

        if args.counts.endswith('.h5ad'):
            input_counts = sc.read(args.counts)
        else:
            ## Load txt or compressed dataframe and convert to scanpy object
            if args.counts.endswith('.npz'):
                input_counts = load_df_from_npz(args.counts)
            else:
                input_counts = pd.read_csv(args.counts, sep='\t', index_col=0)

            if args.densify:
                input_counts = sc.AnnData(X=input_counts.values,
                                       obs=pd.DataFrame(index=input_counts.index),
                                       var=pd.DataFrame(index=input_counts.columns))
            else:
                input_counts = sc.AnnData(X=sp.csr_matrix(input_counts.values),
                                       obs=pd.DataFrame(index=input_counts.index),
                                       var=pd.DataFrame(index=input_counts.columns))


        if sp.issparse(input_counts.X) & args.densify:
            input_counts.X = np.array(input_counts.X.todense())

        if args.tpm is None:
            tpm = compute_tpm(input_counts)
            sc.write(cnmf_obj.paths['tpm'], tpm)
        elif args.tpm.endswith('.h5ad'):
            subprocess.call('cp %s %s' % (args.tpm, cnmf_obj.paths['tpm']), shell=True)
            tpm = sc.read(cnmf_obj.paths['tpm'])
        else:
            if args.tpm.endswith('.npz'):
                tpm = load_df_from_npz(args.tpm)
            else:
                tpm = pd.read_csv(args.tpm, sep='\t', index_col=0)

            if args.densify:
                tpm = sc.AnnData(X=tpm.values,
                            obs=pd.DataFrame(index=tpm.index),
                            var=pd.DataFrame(index=tpm.columns)) 
            else:
                tpm = sc.AnnData(X=sp.csr_matrix(tpm.values),
                            obs=pd.DataFrame(index=tpm.index),
                            var=pd.DataFrame(index=tpm.columns)) 

            sc.write(cnmf_obj.paths['tpm'], tpm)

        if sp.issparse(tpm.X):
            gene_tpm_mean = np.array(tpm.X.mean(axis=0)).reshape(-1)
            gene_tpm_stddev = var_sparse_matrix(tpm.X)**.5
        else:
            gene_tpm_mean = np.array(tpm.X.mean(axis=0)).reshape(-1)
            gene_tpm_stddev = np.array(tpm.X.std(axis=0, ddof=0)).reshape(-1)


        input_tpm_stats = pd.DataFrame([gene_tpm_mean, gene_tpm_stddev],
             index = ['__mean', '__std']).T
        save_df_to_npz(input_tpm_stats, cnmf_obj.paths['tpm_stats'])

        if args.genes_file is not None:
            highvargenes = open(args.genes_file).read().rstrip().split('\n')
        else:
            highvargenes = None

        norm_counts = cnmf_obj.get_norm_counts(input_counts, tpm, num_highvar_genes=args.numgenes,
                                               high_variance_genes_filter=highvargenes)


        if norm_counts.X.dtype != np.float64:
            norm_counts.X = norm_counts.X.astype(np.float64)

        cnmf_obj.save_norm_counts(norm_counts)
        (replicate_params, run_params) = cnmf_obj.get_nmf_iter_params(ks=args.components, n_iter=args.n_iter, random_state_seed=args.seed, beta_loss=args.beta_loss)
        cnmf_obj.save_nmf_iter_params(replicate_params, run_params)


    elif args.command == 'factorize':
        cnmf_obj.run_nmf(worker_i=args.worker_index, total_workers=args.total_workers)

    elif args.command == 'combine':
        run_params = load_df_from_npz(cnmf_obj.paths['nmf_replicate_parameters'])

        if type(args.components) is int:
            ks = [args.components]
        elif args.components is None:
            ks = sorted(set(run_params.n_components))
        else:
            ks = args.components

        for k in ks:
            cnmf_obj.combine_nmf(k)

    elif args.command == 'consensus':
        run_params = load_df_from_npz(cnmf_obj.paths['nmf_replicate_parameters'])

        if type(args.components) is int:
            ks = [args.components]
        elif args.components is None:
            ks = sorted(set(run_params.n_components))
        else:
            ks = args.components

        for k in ks:
            merged_spectra = load_df_from_npz(cnmf_obj.paths['merged_spectra']%k)
            cnmf_obj.consensus(k, args.local_density_threshold, args.local_neighborhood_size, args.show_clustering)

    elif args.command == 'k_selection_plot':
        cnmf_obj.k_selection_plot()
import numpy as np
import pandas as pd
import os, errno
import datetime
import uuid
import itertools
import yaml
import subprocess
import scipy.sparse as sp


from scipy.spatial.distance import squareform
from sklearn.decomposition import non_negative_factorization
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.utils import sparsefuncs


from fastcluster import linkage
from scipy.cluster.hierarchy import leaves_list

import matplotlib.pyplot as plt

import scanpy as sc

def save_df_to_npz(obj, filename):
    np.savez_compressed(filename, data=obj.values, index=obj.index.values, columns=obj.columns.values)

def save_df_to_text(obj, filename):
    obj.to_csv(filename, sep='\t')

def load_df_from_npz(filename):
    with np.load(filename, allow_pickle=True) as f:
        obj = pd.DataFrame(**f)
    return obj

def check_dir_exists(path):
    """
    Checks if directory already exists or not and creates it if it doesn't
    """
    try:
        os.makedirs(path)
    except OSError as exception:
        if exception.errno != errno.EEXIST:
            raise

def worker_filter(iterable, worker_index, total_workers):
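    # Round-robin job assignment: yield every total_workers-th element of `iterable`,
    # starting at position worker_index.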
    return (p for i,p in enumerate(iterable) if (i-worker_index)%total_workers==0)

def fast_euclidean(mat):
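    # Compute all pairwise Euclidean distances between the rows of `mat` using the
    # Gram-matrix identity ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i.x_j,
    # and return them in condensed (scipy `squareform`) format.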
    D = mat.dot(mat.T)
    squared_norms = np.diag(D).copy()
    D *= -2.0
    D += squared_norms.reshape((-1,1))
    D += squared_norms.reshape((1,-1))
    D = np.sqrt(D)
    D[D < 0] = 0
    return squareform(D, checks=False)

def fast_ols_all_cols(X, Y):
    pinv = np.linalg.pinv(X)
    beta = np.dot(pinv, Y)
    return(beta)

def fast_ols_all_cols_df(X,Y):
    beta = fast_ols_all_cols(X, Y)
    beta = pd.DataFrame(beta, index=X.columns, columns=Y.columns)
    return(beta)

def var_sparse_matrix(X):
    mean = np.array(X.mean(axis=0)).reshape(-1)
    Xcopy = X.copy()
    Xcopy.data **= 2
    var = np.array(Xcopy.mean(axis=0)).reshape(-1) - (mean**2)
    return(var)


def get_highvar_genes_sparse(expression, expected_fano_threshold=None,
                       minimal_mean=0.5, numgenes=None):
    # Find high variance genes within those cells
    gene_mean = np.array(expression.mean(axis=0)).astype(float).reshape(-1)
    E2 = expression.copy(); E2.data **= 2; gene2_mean = np.array(E2.mean(axis=0)).reshape(-1)
    gene_var = pd.Series(gene2_mean - (gene_mean**2))
    del(E2)
    gene_mean = pd.Series(gene_mean)
    gene_fano = gene_var / gene_mean

    # Find parameters for expected fano line
    top_genes = gene_mean.sort_values(ascending=False)[:20].index
    A = (np.sqrt(gene_var)/gene_mean)[top_genes].min()

    w_mean_low, w_mean_high = gene_mean.quantile([0.10, 0.90])
    w_fano_low, w_fano_high = gene_fano.quantile([0.10, 0.90])
    winsor_box = ((gene_fano > w_fano_low) &
                    (gene_fano < w_fano_high) &
                    (gene_mean > w_mean_low) &
                    (gene_mean < w_mean_high))
    fano_median = gene_fano[winsor_box].median()
    B = np.sqrt(fano_median)

    gene_expected_fano = (A**2)*gene_mean + (B**2)
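    # A is the smallest coefficient of variation among the 20 highest-mean genes and
    # B^2 is the median Fano factor within the winsorized (10th-90th percentile) box,
    # so expected_fano = A^2 * mean + B^2 models the mean-dependent baseline.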
    fano_ratio = (gene_fano/gene_expected_fano)

    # Identify high var genes
    if numgenes is not None:
        highvargenes = fano_ratio.sort_values(ascending=False).index[:numgenes]
        high_var_genes_ind = fano_ratio.index.isin(highvargenes)
        T=None


    else:
        if not expected_fano_threshold:
            T = (1. + gene_fano[winsor_box].std())
        else:
            T = expected_fano_threshold

        high_var_genes_ind = (fano_ratio > T) & (gene_mean > minimal_mean)

    gene_counts_stats = pd.DataFrame({
        'mean': gene_mean,
        'var': gene_var,
        'fano': gene_fano,
        'expected_fano': gene_expected_fano,
        'high_var': high_var_genes_ind,
        'fano_ratio': fano_ratio
        })
    gene_fano_parameters = {
            'A': A, 'B': B, 'T':T, 'minimal_mean': minimal_mean,
        }
    return(gene_counts_stats, gene_fano_parameters)



def get_highvar_genes(input_counts, expected_fano_threshold=None,
                       minimal_mean=0.5, numgenes=None):
    # Find high variance genes within those cells
    gene_counts_mean = pd.Series(input_counts.mean(axis=0).astype(float))
    gene_counts_var = pd.Series(input_counts.var(ddof=0, axis=0).astype(float))
    gene_counts_fano = pd.Series(gene_counts_var/gene_counts_mean)

    # Find parameters for expected fano line
    top_genes = gene_counts_mean.sort_values(ascending=False)[:20].index
    A = (np.sqrt(gene_counts_var)/gene_counts_mean)[top_genes].min()

    w_mean_low, w_mean_high = gene_counts_mean.quantile([0.10, 0.90])
    w_fano_low, w_fano_high = gene_counts_fano.quantile([0.10, 0.90])
    winsor_box = ((gene_counts_fano > w_fano_low) &
                    (gene_counts_fano < w_fano_high) &
                    (gene_counts_mean > w_mean_low) &
                    (gene_counts_mean < w_mean_high))
    fano_median = gene_counts_fano[winsor_box].median()
    B = np.sqrt(fano_median)

    gene_expected_fano = (A**2)*gene_counts_mean + (B**2)

    fano_ratio = (gene_counts_fano/gene_expected_fano)

    # Identify high var genes
    if numgenes is not None:
        highvargenes = fano_ratio.sort_values(ascending=False).index[:numgenes]
        high_var_genes_ind = fano_ratio.index.isin(highvargenes)
        T=None


    else:
        if not expected_fano_threshold:
            T = (1. + gene_counts_fano[winsor_box].std())
        else:
            T = expected_fano_threshold

        high_var_genes_ind = (fano_ratio > T) & (gene_counts_mean > minimal_mean)

    gene_counts_stats = pd.DataFrame({
        'mean': gene_counts_mean,
        'var': gene_counts_var,
        'fano': gene_counts_fano,
        'expected_fano': gene_expected_fano,
        'high_var': high_var_genes_ind,
        'fano_ratio': fano_ratio
        })
    gene_fano_parameters = {
            'A': A, 'B': B, 'T':T, 'minimal_mean': minimal_mean,
        }
    return(gene_counts_stats, gene_fano_parameters)


def compute_tpm(input_counts):
    """
    Default TPM normalization
    """
    tpm = input_counts.copy()
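    # Scale each cell's counts so they sum to 1e6 (counts-per-million);
    # the pipeline refers to these values as TPM.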
    sc.pp.normalize_per_cell(tpm, counts_per_cell_after=1e6)
    return(tpm)


class cNMF():


    def __init__(self, output_dir=".", name=None):
        """
        Parameters
        ----------

        output_dir : path, optional (default=".")
            Output directory for analysis files.

        name : string, optional (default=None)
            A name for this analysis. Will be prefixed to all output files.
            If set to None, will be automatically generated from date (and random string).
        """

        self.output_dir = output_dir
        if name is None:
            now = datetime.datetime.now()
            rand_hash =  uuid.uuid4().hex[:6]
            name = '%s_%s' % (now.strftime("%Y_%m_%d"), rand_hash)
        self.name = name
        self.paths = None


    def _initialize_dirs(self):
        if self.paths is None:
            # Check that output directory exists, create it if needed.
            check_dir_exists(self.output_dir)
            check_dir_exists(os.path.join(self.output_dir, self.name))
            check_dir_exists(os.path.join(self.output_dir, self.name, 'cnmf_tmp'))

            self.paths = {
                'normalized_counts' : os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.norm_counts.h5ad'),
                'nmf_replicate_parameters' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.nmf_params.df.npz'),
                'nmf_run_parameters' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.nmf_idvrun_params.yaml'),
                'nmf_genes_list' :  os.path.join(self.output_dir, self.name, self.name+'.overdispersed_genes.txt'),

                'tpm' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.tpm.h5ad'),
                'tpm_stats' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.tpm_stats.df.npz'),

                'iter_spectra' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.spectra.k_%d.iter_%d.df.npz'),
                'iter_usages' :  os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.usages.k_%d.iter_%d.df.npz'),
                'merged_spectra': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.spectra.k_%d.merged.df.npz'),

                'local_density_cache': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.local_density_cache.k_%d.merged.df.npz'),
                'consensus_spectra': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.spectra.k_%d.dt_%s.consensus.df.npz'),
                'consensus_spectra__txt': os.path.join(self.output_dir, self.name, self.name+'.spectra.k_%d.dt_%s.consensus.txt'),
                'consensus_usages': os.path.join(self.output_dir, self.name, 'cnmf_tmp',self.name+'.usages.k_%d.dt_%s.consensus.df.npz'),
                'consensus_usages__txt': os.path.join(self.output_dir, self.name, self.name+'.usages.k_%d.dt_%s.consensus.txt'),

                'consensus_stats': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.stats.k_%d.dt_%s.df.npz'),

                'clustering_plot': os.path.join(self.output_dir, self.name, self.name+'.clustering.k_%d.dt_%s.png'),
                'gene_spectra_score': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.gene_spectra_score.k_%d.dt_%s.df.npz'),
                'gene_spectra_score__txt': os.path.join(self.output_dir, self.name, self.name+'.gene_spectra_score.k_%d.dt_%s.txt'),
                'gene_spectra_tpm': os.path.join(self.output_dir, self.name, 'cnmf_tmp', self.name+'.gene_spectra_tpm.k_%d.dt_%s.df.npz'),
                'gene_spectra_tpm__txt': os.path.join(self.output_dir, self.name, self.name+'.gene_spectra_tpm.k_%d.dt_%s.txt'),

                'k_selection_plot' :  os.path.join(self.output_dir, self.name, self.name+'.k_selection.png'),
                'k_selection_stats' :  os.path.join(self.output_dir, self.name, self.name+'.k_selection_stats.df.npz'),
            }


    def get_norm_counts(self, counts, tpm,
                         high_variance_genes_filter = None,
                         num_highvar_genes = None
                         ):
        """
        Parameters
        ----------

        counts : anndata.AnnData
            Scanpy AnnData object (cells x genes) containing raw counts. Filtered such that
            no genes or cells with 0 counts

        tpm : anndata.AnnData
            Scanpy AnnData object (cells x genes) containing tpm normalized data matching
            counts

        high_variance_genes_filter : np.array, optional (default=None)
            A pre-specified list of genes considered to be high-variance.
            Only these genes will be used during factorization of the counts matrix.
            Must match the .var index of counts and tpm.
            If set to None, high-variance genes will be automatically computed, using the
            parameters below.

        num_highvar_genes : int, optional (default=None)
            Instead of providing an array of high-variance genes, identify this many most overdispersed genes
            for filtering

        Returns
        -------

        normcounts : anndata.AnnData, shape (cells, num_highvar_genes)
            A counts matrix containing only the high variance genes and with columns (genes)normalized to unit
            variance

        """

        if high_variance_genes_filter is None:
            ## Get list of high-var genes if one wasn't provided
            if sp.issparse(tpm.X):
                (gene_counts_stats, gene_fano_params) = get_highvar_genes_sparse(tpm.X, numgenes=num_highvar_genes)  
            else:
                (gene_counts_stats, gene_fano_params) = get_highvar_genes(np.array(tpm.X), numgenes=num_highvar_genes)

            high_variance_genes_filter = list(tpm.var.index[gene_counts_stats.high_var.values])

        ## Subset out high-variance genes
        norm_counts = counts[:, high_variance_genes_filter]

        ## Scale genes to unit variance
        if sp.issparse(tpm.X):
            sc.pp.scale(norm_counts, zero_center=False)
            if np.isnan(norm_counts.X.data).sum() > 0:
                print('Warning: NaNs in normalized counts matrix')
        else:
            norm_counts.X /= norm_counts.X.std(axis=0, ddof=1)
            if np.isnan(norm_counts.X).sum().sum() > 0:
                print('Warning: NaNs in normalized counts matrix')

        ## Save a \n-delimited list of the high-variance genes used for factorization
        open(self.paths['nmf_genes_list'], 'w').write('\n'.join(high_variance_genes_filter))

        ## Check for any cells that have 0 counts of the overdispersed genes
        zerocells = norm_counts.X.sum(axis=1)==0
        if zerocells.sum()>0:
            examples = norm_counts.obs.index[zerocells]
            print('Warning: %d cells have zero counts of overdispersed genes. E.g. %s' % (zerocells.sum(), examples[0]))
            print('Consensus step may not run when this is the case')

        return(norm_counts)


    def save_norm_counts(self, norm_counts):
        self._initialize_dirs()
        sc.write(self.paths['normalized_counts'], norm_counts)


    def get_nmf_iter_params(self, ks, n_iter = 100,
                               random_state_seed = None,
                               beta_loss = 'kullback-leibler'):
        """
        Create a DataFrame with parameters for NMF iterations.


        Parameters
        ----------
        ks : integer, or list-like.
            Number of topics (components) for factorization.
            Several values can be specified at the same time, which will be run independently.

        n_iter : integer, optional (default=100)
            Number of iterations for factorization. If several ``k`` are specified, this many
            iterations will be run for each value of ``k``.

        random_state_seed : int or None, optional (default=None)
            Seed for sklearn random state.

        beta_loss : str, optional (default='kullback-leibler')
            Beta divergence passed to ``non_negative_factorization``
            ('frobenius', 'kullback-leibler', or 'itakura-saito').

        """

        if type(ks) is int:
            ks = [ks]

        # Remove any repeated k values, and order.
        k_list = sorted(set(list(ks)))

        n_runs = len(ks)* n_iter

        np.random.seed(seed=random_state_seed)
        nmf_seeds = np.random.randint(low=1, high=(2**32)-1, size=n_runs)

        replicate_params = []
        for i, (k, r) in enumerate(itertools.product(k_list, range(n_iter))):
            replicate_params.append([k, r, nmf_seeds[i]])
        replicate_params = pd.DataFrame(replicate_params, columns = ['n_components', 'iter', 'nmf_seed'])

        _nmf_kwargs = dict(
                        alpha=0.0,
                        l1_ratio=0.0,
                        beta_loss=beta_loss,
                        solver='mu',
                        tol=1e-4,
                        max_iter=1000,
                        regularization=None,
                        init='random'
                        )

        ## Coordinate descent is faster than multiplicative update but only works for frobenius
        if beta_loss == 'frobenius':
            _nmf_kwargs['solver'] = 'cd'

        return(replicate_params, _nmf_kwargs)


    def save_nmf_iter_params(self, replicate_params, run_params):
        self._initialize_dirs()
        save_df_to_npz(replicate_params, self.paths['nmf_replicate_parameters'])
        with open(self.paths['nmf_run_parameters'], 'w') as F:
            yaml.dump(run_params, F)


    def _nmf(self, X, nmf_kwargs):
        """
        Parameters
        ----------
        X : pandas.DataFrame,
            Normalized counts dataFrame to be factorized.

        nmf_kwargs : dict,
            Arguments to be passed to ``non_negative_factorization``

        """
        (usages, spectra, niter) = non_negative_factorization(X, **nmf_kwargs)

        return(spectra, usages)


    def run_nmf(self,
                worker_i=1, total_workers=1,
                ):
        """
        Iteratively run NMF with prespecified parameters.

        Use the `worker_i` and `total_workers` parameters for parallelization.

        Generic kwargs for NMF are loaded from self.paths['nmf_run_parameters'], defaults below::

            ``non_negative_factorization`` default arguments:
                alpha=0.0
                l1_ratio=0.0
                beta_loss='kullback-leibler'
                solver='mu'
                tol=1e-4,
                max_iter=1000
                regularization=None
                init='random'
                random_state, n_components are both set by the prespecified self.paths['nmf_replicate_parameters'].


        Parameters
        ----------
        worker_i : int, optional (default=1)
            Index of the current worker; together with ``total_workers`` this
            selects which replicates the current process runs.

        total_workers : int, optional (default=1)
            Total number of workers the replicates are distributed across.

        The normalized counts (output of ``get_norm_counts``) and the replicate
        parameters (output of ``get_nmf_iter_params``) are loaded from
        ``self.paths`` rather than passed as arguments.

        """
        self._initialize_dirs()
        run_params = load_df_from_npz(self.paths['nmf_replicate_parameters'])
        norm_counts = sc.read(self.paths['normalized_counts'])
        _nmf_kwargs = yaml.load(open(self.paths['nmf_run_parameters']), Loader=yaml.FullLoader)

        jobs_for_this_worker = worker_filter(range(len(run_params)), worker_i, total_workers)
        for idx in jobs_for_this_worker:

            p = run_params.iloc[idx, :]
            print('[Worker %d]. Starting task %d.' % (worker_i, idx))
            _nmf_kwargs['random_state'] = p['nmf_seed']
            _nmf_kwargs['n_components'] = p['n_components']

            (spectra, usages) = self._nmf(norm_counts.X, _nmf_kwargs)
            spectra = pd.DataFrame(spectra,
                                   index=np.arange(1, _nmf_kwargs['n_components']+1),
                                   columns=norm_counts.var.index)
            save_df_to_npz(spectra, self.paths['iter_spectra'] % (p['n_components'], p['iter']))


    def combine_nmf(self, k, remove_individual_iterations=False):
        run_params = load_df_from_npz(self.paths['nmf_replicate_parameters'])
        print('Combining factorizations for k=%d.'%k)

        self._initialize_dirs()

        combined_spectra = None
        n_iter = sum(run_params.n_components==k)

        run_params_subset = run_params[run_params.n_components==k].sort_values('iter')
        spectra_labels = []

        for i,p in run_params_subset.iterrows():

            spectra = load_df_from_npz(self.paths['iter_spectra'] % (p['n_components'], p['iter']))
            if combined_spectra is None:
                combined_spectra = np.zeros((n_iter, k, spectra.shape[1]))
            combined_spectra[p['iter'], :, :] = spectra.values

            for t in range(k):
                spectra_labels.append('iter%d_topic%d'%(p['iter'], t+1))

        combined_spectra = combined_spectra.reshape(-1, combined_spectra.shape[-1])
        combined_spectra = pd.DataFrame(combined_spectra, columns=spectra.columns, index=spectra_labels)

        save_df_to_npz(combined_spectra, self.paths['merged_spectra']%k)
        return combined_spectra


    def consensus(self, k, density_threshold_str='0.5', local_neighborhood_size = 0.30,show_clustering = False,
                  skip_density_and_return_after_stats = False, close_clustergram_fig=True):
        merged_spectra = load_df_from_npz(self.paths['merged_spectra']%k)
        norm_counts = sc.read(self.paths['normalized_counts'])

        if skip_density_and_return_after_stats:
            density_threshold_str = '2'
        density_threshold_repl = density_threshold_str.replace('.', '_')
        density_threshold = float(density_threshold_str)
        n_neighbors = int(local_neighborhood_size * merged_spectra.shape[0]/k)

        # Rescale topics so that each has unit length (L2 norm of 1).
        l2_spectra = (merged_spectra.T/np.sqrt((merged_spectra**2).sum(axis=1))).T


        if not skip_density_and_return_after_stats:
            # Compute the local density matrix (if not previously cached)
            topics_dist = None
            if os.path.isfile(self.paths['local_density_cache'] % k):
                local_density = load_df_from_npz(self.paths['local_density_cache'] % k)
            else:
                #   first find the full distance matrix
                topics_dist = squareform(fast_euclidean(l2_spectra.values))
                #   partition based on the first n neighbors
                partitioning_order  = np.argpartition(topics_dist, n_neighbors+1)[:, :n_neighbors+1]
                #   find the mean over those n_neighbors (excluding self, which has a distance of 0)
                distance_to_nearest_neighbors = topics_dist[np.arange(topics_dist.shape[0])[:, None], partitioning_order]
                local_density = pd.DataFrame(distance_to_nearest_neighbors.sum(1)/(n_neighbors),
                                             columns=['local_density'],
                                             index=l2_spectra.index)
                save_df_to_npz(local_density, self.paths['local_density_cache'] % k)
                del(partitioning_order)
                del(distance_to_nearest_neighbors)

            density_filter = local_density.iloc[:, 0] < density_threshold
            l2_spectra = l2_spectra.loc[density_filter, :]

        kmeans_model = KMeans(n_clusters=k, n_init=10, random_state=1)
        kmeans_model.fit(l2_spectra)
        kmeans_cluster_labels = pd.Series(kmeans_model.labels_+1, index=l2_spectra.index)

        # Find median usage for each gene across cluster
        median_spectra = l2_spectra.groupby(kmeans_cluster_labels).median()

        # Normalize median spectra to probability distributions.
        median_spectra = (median_spectra.T/median_spectra.sum(1)).T

        # Compute the silhouette score
        stability = silhouette_score(l2_spectra.values, kmeans_cluster_labels, metric='euclidean')

        # Obtain the reconstructed count matrix by re-fitting the usage matrix and computing the dot product: usage.dot(spectra)
        refit_nmf_kwargs = yaml.load(open(self.paths['nmf_run_parameters']), Loader=yaml.FullLoader)
        refit_nmf_kwargs.update(dict(
                                    n_components = k,
                                    H = median_spectra.values,
                                    update_H = False
                                    ))
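        # With update_H=False and H fixed to the consensus (median) spectra,
        # non_negative_factorization only solves for the usage matrix.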

        _, rf_usages = self._nmf(norm_counts.X,
                                          nmf_kwargs=refit_nmf_kwargs)
        rf_usages = pd.DataFrame(rf_usages, index=norm_counts.obs.index, columns=median_spectra.index)
        rf_pred_norm_counts = rf_usages.dot(median_spectra)

        # Compute the prediction error as the squared Frobenius norm of the residual
        if sp.issparse(norm_counts.X):
            prediction_error = ((norm_counts.X.todense() - rf_pred_norm_counts)**2).sum().sum()
        else:
            prediction_error = ((norm_counts.X - rf_pred_norm_counts)**2).sum().sum()

        consensus_stats = pd.DataFrame([k, density_threshold, stability, prediction_error],
                    index = ['k', 'local_density_threshold', 'stability', 'prediction_error'],
                    columns = ['stats'])

        if skip_density_and_return_after_stats:
            return consensus_stats

        save_df_to_npz(median_spectra, self.paths['consensus_spectra']%(k, density_threshold_repl))
        save_df_to_npz(rf_usages, self.paths['consensus_usages']%(k, density_threshold_repl))
        save_df_to_npz(consensus_stats, self.paths['consensus_stats']%(k, density_threshold_repl))
        save_df_to_text(median_spectra, self.paths['consensus_spectra__txt']%(k, density_threshold_repl))
        save_df_to_text(rf_usages, self.paths['consensus_usages__txt']%(k, density_threshold_repl))

        # Compute gene-scores for each GEP by regressing usage on Z-scores of TPM
        tpm = sc.read(self.paths['tpm'])
        tpm_stats = load_df_from_npz(self.paths['tpm_stats'])

        if sp.issparse(tpm.X):
            norm_tpm = (np.array(tpm.X.todense()) - tpm_stats['__mean'].values) / tpm_stats['__std'].values
        else:
            norm_tpm = (tpm.X - tpm_stats['__mean'].values) / tpm_stats['__std'].values

        if norm_tpm.dtype != np.float64:
            norm_tpm = norm_tpm.astype(np.float64)

        usage_coef = fast_ols_all_cols(rf_usages.values, norm_tpm)
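        # usage_coef has one row per program and one column per gene: the OLS
        # (pseudoinverse) coefficients of the Z-scored TPM regressed on the refit usages.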
        usage_coef = pd.DataFrame(usage_coef, index=rf_usages.columns, columns=tpm.var.index)

        save_df_to_npz(usage_coef, self.paths['gene_spectra_score']%(k, density_threshold_repl))
        save_df_to_text(usage_coef, self.paths['gene_spectra_score__txt']%(k, density_threshold_repl))

        # Convert spectra to TPM units, and obtain results for all genes by running last step of NMF
        # with usages fixed and TPM as the input matrix
        norm_usages = rf_usages.div(rf_usages.sum(axis=1), axis=0)
        refit_nmf_kwargs.update(dict(
                                    H = norm_usages.T.values,
                                ))

        # Needed, otherwise _nmf will crash because of inconsistent dtypes
        if tpm.X.dtype != np.float64:
            tpm.X = tpm.X.astype(np.float64)

        _, spectra_tpm = self._nmf(tpm.X.T, nmf_kwargs=refit_nmf_kwargs)
        spectra_tpm = pd.DataFrame(spectra_tpm.T, index=rf_usages.columns, columns=tpm.var.index)
        save_df_to_npz(spectra_tpm, self.paths['gene_spectra_tpm']%(k, density_threshold_repl))
        save_df_to_text(spectra_tpm, self.paths['gene_spectra_tpm__txt']%(k, density_threshold_repl))

        if show_clustering:
            if topics_dist is None:
                topics_dist = squareform(fast_euclidean(l2_spectra.values))
                # (l2_spectra was already filtered using the density filter)
            else:
                # (but the previously computed topics_dist was not!)
                topics_dist = topics_dist[density_filter.values, :][:, density_filter.values]


            spectra_order = []
            for cl in sorted(set(kmeans_cluster_labels)):

                cl_filter = kmeans_cluster_labels==cl

                if cl_filter.sum() > 1:
                    cl_dist = squareform(topics_dist[cl_filter, :][:, cl_filter])
                    cl_dist[cl_dist < 0] = 0 #Rarely get floating point arithmetic issues
                    cl_link = linkage(cl_dist, 'average')
                    cl_leaves_order = leaves_list(cl_link)

                    spectra_order += list(np.where(cl_filter)[0][cl_leaves_order])
                else:
                    ## Corner case where a component only has one element
                    spectra_order += list(np.where(cl_filter)[0])


            from matplotlib import gridspec
            import matplotlib.pyplot as plt

            width_ratios = [0.5, 9, 0.5, 4, 1]
            height_ratios = [0.5, 9]
            fig = plt.figure(figsize=(sum(width_ratios), sum(height_ratios)))
            gs = gridspec.GridSpec(len(height_ratios), len(width_ratios), fig,
                                    0.01, 0.01, 0.98, 0.98,
                                   height_ratios=height_ratios,
                                   width_ratios=width_ratios,
                                   wspace=0, hspace=0)

            dist_ax = fig.add_subplot(gs[1,1], xscale='linear', yscale='linear',
                                      xticks=[], yticks=[],xlabel='', ylabel='',
                                      frameon=True)

            D = topics_dist[spectra_order, :][:, spectra_order]
            dist_im = dist_ax.imshow(D, interpolation='none', cmap='viridis', aspect='auto',
                                rasterized=True)

            left_ax = fig.add_subplot(gs[1,0], xscale='linear', yscale='linear', xticks=[], yticks=[],
                xlabel='', ylabel='', frameon=True)
            left_ax.imshow(kmeans_cluster_labels.values[spectra_order].reshape(-1, 1),
                            interpolation='none', cmap='Spectral', aspect='auto',
                            rasterized=True)


            top_ax = fig.add_subplot(gs[0,1], xscale='linear', yscale='linear', xticks=[], yticks=[],
                xlabel='', ylabel='', frameon=True)
            top_ax.imshow(kmeans_cluster_labels.values[spectra_order].reshape(1, -1),
                              interpolation='none', cmap='Spectral', aspect='auto',
                                rasterized=True)


            hist_gs = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[1, 3],
                                   wspace=0, hspace=0)

            hist_ax = fig.add_subplot(hist_gs[0,0], xscale='linear', yscale='linear',
                xlabel='', ylabel='', frameon=True, title='Local density histogram')
            hist_ax.hist(local_density.values, bins=np.linspace(0, 1, 50))
            hist_ax.yaxis.tick_right()

            xlim = hist_ax.get_xlim()
            ylim = hist_ax.get_ylim()
            if density_threshold < xlim[1]:
                hist_ax.axvline(density_threshold, linestyle='--', color='k')
                hist_ax.text(density_threshold  + 0.02, ylim[1] * 0.95, 'filtering\nthreshold\n\n', va='top')
            hist_ax.set_xlim(xlim)
            hist_ax.set_xlabel('Mean distance to k nearest neighbors\n\n%d/%d (%.0f%%) spectra above threshold\nwere removed prior to clustering'%(sum(~density_filter), len(density_filter), 100*(~density_filter).mean()))

            ## Add colorbar
            cbar_gs = gridspec.GridSpecFromSubplotSpec(8, 1, subplot_spec=hist_gs[1, 0],
                                   wspace=0, hspace=0)
            cbar_ax = fig.add_subplot(cbar_gs[4,0], xscale='linear', yscale='linear',
                xlabel='', ylabel='', frameon=True, title='Euclidean Distance')
            vmin = D.min().min()
            vmax = D.max().max()
            fig.colorbar(dist_im, cax=cbar_ax,
            ticks=np.linspace(vmin, vmax, 3),
            orientation='horizontal')


            #hist_ax.hist(local_density.values, bins=np.linspace(0, 1, 50))
            #hist_ax.yaxis.tick_right()            

            fig.savefig(self.paths['clustering_plot']%(k, density_threshold_repl), dpi=250)
            if close_clustergram_fig:
                plt.close(fig)


    def k_selection_plot(self, close_fig=True):
        '''
        Borrowed from Alexandrov et al. 2013, "Deciphering Signatures of
        Mutational Processes Operative in Human Cancer", Cell Reports.
        '''
        run_params = load_df_from_npz(self.paths['nmf_replicate_parameters'])
        stats = []
        for k in sorted(set(run_params.n_components)):

            stats.append(self.consensus(k, skip_density_and_return_after_stats=True).stats)

        stats = pd.DataFrame(stats)
        stats.reset_index(drop = True, inplace = True)

        save_df_to_npz(stats, self.paths['k_selection_stats'])

        fig = plt.figure(figsize=(6, 4))
        ax1 = fig.add_subplot(111)
        ax2 = ax1.twinx()


        ax1.plot(stats.k, stats.stability, 'o-', color='b')
        ax1.set_ylabel('Stability', color='b', fontsize=15)
        for tl in ax1.get_yticklabels():
            tl.set_color('b')
        #ax1.set_xlabel('K', fontsize=15)

        ax2.plot(stats.k, stats.prediction_error, 'o-', color='r')
        ax2.set_ylabel('Error', color='r', fontsize=15)
        for tl in ax2.get_yticklabels():
            tl.set_color('r')

        ax1.set_xlabel('Number of Components', fontsize=15)
        ax1.grid('on')
        plt.tight_layout()
        fig.savefig(self.paths['k_selection_plot'], dpi=250)
        if close_fig:
            plt.close(fig)



if __name__=="__main__":
    """
    Example commands for now:

        output_dir="/Users/averes/Projects/Melton/Notebooks/2018/07-2018/cnmf_test/"


        python cnmf.py prepare --output-dir $output_dir \
           --name test --counts /Users/averes/Projects/Melton/Notebooks/2018/07-2018/cnmf_test/test_data.df.npz \
           -k 6 7 8 9 --n-iter 5

        python cnmf.py factorize  --name test --output-dir $output_dir

        This can be parallelized as follows:

        python cnmf.py factorize  --name test --output-dir $output_dir --total-workers 2 --worker-index WORKER_INDEX (where worker_index starts with 0)

        python cnmf.py combine  --name test --output-dir $output_dir

        python cnmf.py consensus  --name test --output-dir $output_dir
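
        The k selection plot can be generated analogously (assumed usage,
        following the same pattern as the commands above):

        python cnmf.py k_selection_plot  --name test --output-dir $output_dir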

    """

    import sys, argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('command', type=str, choices=['prepare', 'factorize', 'combine', 'consensus', 'k_selection_plot'])
    parser.add_argument('--name', type=str, help='[all] Name for analysis. All output will be placed in [output-dir]/[name]/...', nargs='?', default='cNMF')
    parser.add_argument('--output-dir', type=str, help='[all] Output directory. All output will be placed in [output-dir]/[name]/...', nargs='?', default='.')

    parser.add_argument('-c', '--counts', type=str, help='[prepare] Input (cell x gene) counts matrix as df.npz or tab delimited text file')
    parser.add_argument('-k', '--components', type=int, help='[prepare] Number of components (k) for matrix factorization. Several can be specified with "-k 8 9 10"', nargs='+')
    parser.add_argument('-n', '--n-iter', type=int, help='[prepare] Number of factorization replicates', default=100)
    parser.add_argument('--total-workers', type=int, help='[all] Total number of workers to distribute jobs to', default=1)
    parser.add_argument('--seed', type=int, help='[prepare] Seed for pseudorandom number generation', default=None)
    parser.add_argument('--genes-file', type=str, help='[prepare] File containing a list of genes to include, one gene per line. Must match column labels of counts matrix.', default=None)
    parser.add_argument('--numgenes', type=int, help='[prepare] Number of high variance genes to use for matrix factorization.', default=2000)
    parser.add_argument('--tpm', type=str, help='[prepare] Pre-computed (cell x gene) TPM values as df.npz or tab separated txt file. If not provided TPM will be calculated automatically', default=None)
    parser.add_argument('--beta-loss', type=str, choices=['frobenius', 'kullback-leibler', 'itakura-saito'], help='[prepare] Loss function for NMF.', default='frobenius')
    parser.add_argument('--densify', dest='densify', help='[prepare] Treat the input data as non-sparse', action='store_true', default=False)


    parser.add_argument('--worker-index', type=int, help='[factorize] Index of current worker (the first worker should have index 0)', default=0)

    parser.add_argument('--local-density-threshold', type=str, help='[consensus] Threshold for the local density filtering. This string must convert to a float >0 and <=2', default='0.5')
    parser.add_argument('--local-neighborhood-size', type=float, help='[consensus] Fraction of the number of replicates to use as nearest neighbors for local density filtering', default=0.30)
    parser.add_argument('--show-clustering', dest='show_clustering', help='[consensus] Produce a clustergram figure summarizing the spectra clustering', action='store_true')

    args = parser.parse_args()

    cnmf_obj = cNMF(output_dir=args.output_dir, name=args.name)
    cnmf_obj._initialize_dirs()

    if args.command == 'prepare':

        if args.counts.endswith('.h5ad'):
            input_counts = sc.read(args.counts)
        else:
            ## Load txt or compressed dataframe and convert to scanpy object
            if args.counts.endswith('.npz'):
                input_counts = load_df_from_npz(args.counts)
            else:
                input_counts = pd.read_csv(args.counts, sep='\t', index_col=0)

            if args.densify:
                input_counts = sc.AnnData(X=input_counts.values,
                                       obs=pd.DataFrame(index=input_counts.index),
                                       var=pd.DataFrame(index=input_counts.columns))
            else:
                input_counts = sc.AnnData(X=sp.csr_matrix(input_counts.values),
                                       obs=pd.DataFrame(index=input_counts.index),
                                       var=pd.DataFrame(index=input_counts.columns))


        if sp.issparse(input_counts.X) & args.densify:
            input_counts.X = np.array(input_counts.X.todense())

        if args.tpm is None:
            tpm = compute_tpm(input_counts)
            sc.write(cnmf_obj.paths['tpm'], tpm)
        elif args.tpm.endswith('.h5ad'):
            subprocess.call('cp %s %s' % (args.tpm, cnmf_obj.paths['tpm']), shell=True)
            tpm = sc.read(cnmf_obj.paths['tpm'])
        else:
            if args.tpm.endswith('.npz'):
                tpm = load_df_from_npz(args.tpm)
            else:
                tpm = pd.read_csv(args.tpm, sep='\t', index_col=0)

            if args.densify:
                tpm = sc.AnnData(X=tpm.values,
                            obs=pd.DataFrame(index=tpm.index),
                            var=pd.DataFrame(index=tpm.columns)) 
            else:
                tpm = sc.AnnData(X=sp.csr_matrix(tpm.values),
                            obs=pd.DataFrame(index=tpm.index),
                            var=pd.DataFrame(index=tpm.columns)) 

            sc.write(cnmf_obj.paths['tpm'], tpm)

        if sp.issparse(tpm.X):
            gene_tpm_mean = np.array(tpm.X.mean(axis=0)).reshape(-1)
            gene_tpm_stddev = var_sparse_matrix(tpm.X)**.5
        else:
            gene_tpm_mean = np.array(tpm.X.mean(axis=0)).reshape(-1)
            gene_tpm_stddev = np.array(tpm.X.std(axis=0, ddof=0)).reshape(-1)


        input_tpm_stats = pd.DataFrame([gene_tpm_mean, gene_tpm_stddev],
             index = ['__mean', '__std']).T
        save_df_to_npz(input_tpm_stats, cnmf_obj.paths['tpm_stats'])

        if args.genes_file is not None:
            highvargenes = open(args.genes_file).read().rstrip().split('\n')
        else:
            highvargenes = None

        norm_counts = cnmf_obj.get_norm_counts(input_counts, tpm, num_highvar_genes=args.numgenes,
                                               high_variance_genes_filter=highvargenes)


        if norm_counts.X.dtype != np.float64:
            norm_counts.X = norm_counts.X.astype(np.float64)

        cnmf_obj.save_norm_counts(norm_counts)
        (replicate_params, run_params) = cnmf_obj.get_nmf_iter_params(ks=args.components, n_iter=args.n_iter, random_state_seed=args.seed, beta_loss=args.beta_loss)
        cnmf_obj.save_nmf_iter_params(replicate_params, run_params)


    elif args.command == 'factorize':
        cnmf_obj.run_nmf(worker_i=args.worker_index, total_workers=args.total_workers)

    elif args.command == 'combine':
        run_params = load_df_from_npz(cnmf_obj.paths['nmf_replicate_parameters'])

        if type(args.components) is int:
            ks = [args.components]
        elif args.components is None:
            ks = sorted(set(run_params.n_components))
        else:
            ks = args.components

        for k in ks:
            cnmf_obj.combine_nmf(k)

    elif args.command == 'consensus':
        run_params = load_df_from_npz(cnmf_obj.paths['nmf_replicate_parameters'])

        if type(args.components) is int:
            ks = [args.components]
        elif args.components is None:
            ks = sorted(set(run_params.n_components))
        else:
            ks = args.components

        for k in ks:
            merged_spectra = load_df_from_npz(cnmf_obj.paths['merged_spectra']%k)
            cnmf_obj.consensus(k, args.local_density_threshold, args.local_neighborhood_size, args.show_clustering)

    elif args.command == 'k_selection_plot':
        cnmf_obj.k_selection_plot()
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(ggpubr))
suppressPackageStartupMessages(library(data.table))
suppressPackageStartupMessages(library(tidyr))
suppressPackageStartupMessages(library(readxl))
suppressPackageStartupMessages(library(ggrepel))
suppressPackageStartupMessages(library(optparse))
suppressPackageStartupMessages(library(gplots))
suppressPackageStartupMessages(library(cowplot))

## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/2kG.library/acrossK/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/acrossK/", help="Output directory"),
  make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  make_option("--p.adj.threshold", type="numeric", default=0.1, help="Threshold for fdr and adjusted p-value"),
  make_option("--aggregated.data", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis//all_genes/Perturb_2kG_dup4/acrossK/aggregated.outputs.findK.perturb-seq.RData")
)
opt <- parse_args(OptionParser(option_list=option.list))
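
## Example invocation (a sketch; the script path and all file paths below are
## placeholders, not taken from this repository):
## Rscript path/to/this_script.R \
##   --figdir /path/to/figures/acrossK/ \
##   --outdir /path/to/analysis/acrossK/ \
##   --sampleName 2kG.library \
##   --p.adj.threshold 0.1 \
##   --aggregated.data /path/to/aggregated.outputs.findK.perturb-seq.RData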

## ## sdev for 2n1.99x singlets
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/figures/all_genes/Perturb_2kG_dup4/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/outputs/all_genes/Perturb_2kG_dup4/acrossK/"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/acrossK/aggregated.outputs.findK.perturb-seq.RData"
## opt$sampleName <- "Perturb_2kG_dup4"

## ## for all genes (in sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/2kG.library/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/acrossK/"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/acrossK//aggregated.outputs.findK.perturb-seq.RData"


## ## ## for testing cNMF_ pipeline
## ## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/2kG.library/acrossK/"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_2min/acrossK/aggregated.outputs.findK.RData"


## ## for testing cNMF_pipeline with FT010_fresh_4min
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/FT010_fresh_4min/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_4min/acrossK/"
## opt$sampleName <- "FT010_fresh_4min"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_4min/acrossK/aggregated.outputs.findK.RData"


## ## for testing findK_plots for scRNAseq_2kG_11AMDox_1
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/figures/all_genes/scRNAseq_2kG_11AMDox_1/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/acrossK/"
## opt$sampleName <- "scRNAseq_2kG_11AMDox_1"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/acrossK/aggregated.outputs.findK.RData"

## ## for testing findK_plots for control only cells
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/figures/all_genes/2kG.library.ctrl.only/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/acrossK/"
## opt$sampleName <- "2kG.library.ctrl.only"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/acrossK/aggregated.outputs.findK.RData"

## ## for testing findK_plots for perturb-seq only data for control only cells
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/acrossK/aggregated.outputs.findK.perturb-seq.RData"

## ## K562 gwps 2k overdispersed genes
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/acrossK/"
## opt$sampleName <- "WeissmanK562gwps"
## opt$p.adj.threshold <- 0.1
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/acrossK/aggregated.outputs.findK.perturb-seq.RData"

## Directories and Constants
SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
DATADIR=opt$datadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
FIGDIR=opt$figdir
check.dir <- c(OUTDIR, FIGDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))
fdr.thr <- opt$p.adj.threshold

## load("/Volumes/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/analysis/no_IL1B/aggregated.outputs.findK.RData")
## load("/Volumes/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/differential_expression/210526_SeuratDE/outputs/no_IL1B/de.markers.RData")
mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold"))
mytheme <- theme_classic() + theme(axis.text = element_text(size = 5),
                                   axis.title = element_text(size = 6),
                                   plot.title = element_text(hjust = 0.5, face = "bold", size=7),
                                   axis.line = element_line(color = "black", size = 0.25),
                                   axis.ticks = element_line(color = "black", size = 0.25),
                                   legend.title  = element_text(size=6),
                                   legend.text = element_text(size=6))
palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)

## Load data
load(opt$aggregated.data) ## MAST.df, all.test.df, all.enhancer.ttest.df, all.promoter.ttest.df
## ref.table <- read_xlsx(opt$reference.table, sheet="2000_gene_library_annotated") 

## process mast data
MAST.original <- MAST.df <- MAST.df %>% mutate(ProgramID = paste0("K", K, "_", gsub("topic_", "", primerid))) %>%
    mutate(perturbation = gsub("MESDC1", "TLNRD1", perturbation))
## for antiparallel genes, GeneA-and-GeneB, keep {GeneA, GeneB}_multiTarget and remove {GeneA, GeneB} perturbations
if(grepl("2kG.library", SAMPLE)) {
    antiparallel.perturbation <- MAST.df %>% subset(grepl("multiTarget", perturbation)) %>% pull(perturbation) %>% unique %>% gsub("_multiTarget", "", .)
    MAST.df <- MAST.df %>% subset(!(perturbation %in% antiparallel.perturbation))
}
MAST.df <- MAST.df %>%
    subset(zlm.model.name == "batch.correction") %>%
    group_by(zlm.model.name, K) %>%
    mutate(fdr.across.ptb = p.adjust(`Pr(>Chisq)`, method="fdr")) %>%
    group_by(ProgramID) %>%
    arrange(desc(coef)) %>%
    mutate(coef_rank = 1:n()) %>%
    as.data.frame
sig.MAST.df <- MAST.df %>% subset(fdr.across.ptb < fdr.thr)

##################################################
## plots
fig.file.name <- paste0(FIGDIR, "/percent.batch.topics")
batch.percent.df$batch.thr <- as.character(batch.percent.df$batch.thr)
## percent of topics correlated with batch over K
pdf(paste0(fig.file.name, ".pdf"), width=3, height=2)
p <- batch.percent.df %>% ggplot(aes(x = K, y = percent.correlated, color = batch.thr)) + geom_line(size=0.5) + geom_point(size=0.5) + mytheme +
    ggtitle(paste0(SAMPLE, " Percent of Programs Correlated with Batch")) +
    ## scale_x_continuous(breaks = batch.percent.df$K %>% unique) + 
    scale_y_continuous(name = "% Programs Correlated with Batch", labels = scales::percent) +
    scale_color_discrete(name = "Pearson correlation")
print(p)
ggsave(paste0(fig.file.name, ".eps"))
dev.off()

## MAST DE topics results
## metrics:
## # programs 
## # unique programs
## fraction of significant programs
MAST.program.summary.df <- sig.MAST.df %>%
    select(K, ProgramID) %>%
    unique %>%
    group_by(K) %>%
    summarize(nPrograms = n()) %>%
    mutate(fractionPrograms = nPrograms / K) %>%
    as.data.frame
MAST.ptb.summary.df <- sig.MAST.df %>%
    select(K, perturbation) %>%
    unique %>%
    group_by(K) %>%
    summarize(nPerturbations = n()) %>%
    as.data.frame
MAST.ptb.program.pair.summary.df <- sig.MAST.df %>%
    select(K, ProgramID, perturbation) %>%
    unique %>%
    group_by(K) %>%
    summarize(nPerturbationProgramPairs = n()) %>%
    mutate(averagePerturbationProgramPairs = nPerturbationProgramPairs / K) %>%
    as.data.frame
MAST.summary.df <- merge(MAST.program.summary.df, MAST.ptb.summary.df, by="K", all=T) %>% merge(MAST.ptb.program.pair.summary.df, by="K", all=T)    
## function to produce find K plot based on MAST result
plotMAST <- function(toplot, MAST.metric, MAST.metric.label) {
    p <- toplot %>% ggplot(aes(x=K, y=get(MAST.metric))) + geom_point(size=0.5) + geom_line(size=0.5) + mytheme +
    xlab("K") + ylab(paste0(MAST.metric.label))
    print(p)
    return(p)
}
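## (illustrative usage, not part of the original script) e.g.:
## plotMAST(MAST.summary.df, "nPrograms", "# Significant\nPrograms")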

toplot <- MAST.summary.df
MAST.metrics <- MAST.summary.df %>% select(-K) %>% colnames
MAST.metric.labels <- c("# Significant\nPrograms", "Fraction of\nSignificant\nPrograms", "# Regulators", "# Regulator x\n Program Pairs", "Average #\nRegulator x\nProgram Pairs")

plotFilename <- paste0(FIGDIR, "MAST")
plot.list <- list()
pdf(paste0(plotFilename, ".pdf"), width=3, height=3)
for(i in 1:length(MAST.metrics)) {
    MAST.metric <- MAST.metrics[i]
    MAST.metric.label <- MAST.metric.labels[i]
    plot.list[[MAST.metric]] <- plotMAST(toplot, MAST.metric, MAST.metric.label)
    eval(parse(text = paste0("p.", MAST.metric, " <- plot.list[[MAST.metric]]"))) ## assign the plot just created (a bare `p` here would pick up an unrelated earlier plot)
}
dev.off()

plotFilename <- paste0(FIGDIR, "All_MAST")
pdf(paste0(plotFilename, ".pdf"), width=2, height=4)
p <- plot_grid(plotlist = plot.list, nrow = length(plot.list), align="v", axis="lr")
p <- annotate_figure(p, top = text_grob("Statistical test by MAST", size=8))
print(p)
ggsave(paste0(plotFilename, ".eps"))
dev.off()


## Wilcoxon Test results


## ## Update the files for 2kG library
## ## Intersecting EdgeR and Topic Model Results
## threshold <- opt$p.adj.threshold
## # % of perturbations with a large number of DE genes in the per-gene analysis that ALSO have significantly DE topics  (p-value 0.005, logFC 1.2)
## strict.DE.all <- read_excel(path="/Users/helenkang/Documents/EngreitzLab/Pertube-seq_Analysis/200_gene_lib_EdgeR_DE.xlsx", sheet="EdgeR_DE_p.001_lfc1.2")
## lenient.DE.all <- read_excel(path="/Users/helenkang/Documents/EngreitzLab/Pertube-seq_Analysis/200_gene_lib_EdgeR_DE.xlsx", sheet="EdgeR_DE_p.01_lfc1.15")

## strict.DE <- strict.DE.all %>% select(Target...13, `NoI_#DE_genes`, `PlusI_#DE_genes`) %>% `colnames<-`(c("Gene", "NoI.num.genes", "PlusI.num.genes")) %>% mutate(cutoff.type = "strict", cutoff.detail = "p < 0.001, logFC > 1.2")
## lenient.DE <- lenient.DE.all %>% select(Target...13, NoI_DE...14, PlusI_DE...15) %>% `colnames<-`(c("Gene", "NoI.num.genes", "PlusI.num.genes")) %>% mutate(cutoff.type = "lenient", cutoff.detail = "p < 0.01, logFC > 1.15")
## edgeR.DE <- rbind(strict.DE, lenient.DE)

## perturbation.test.stat.df <- all.test.df %>% subset(adjusted.p.value < threshold) %>% select(K, Gene, test.type, Topic) %>% group_by(K, Gene, test.type) %>% summarize(topic.count = n())
## edgeR.test.df <- merge(edgeR.DE, perturbation.test.stat.df, by="Gene")


## MSigDB Pathway Enrichment 
## Number of GO enrichment per model


## Notes 210518
## also output pdf
## plot top enriched MSigDB pathways for each topic and K
## also consider eps file
## make sure that p.adjust is calculated over all topics for a model


## plot number of GO enrichment on raw score ranking per K
## par(mar = c(4, 4, .1, .1))
threshold = opt$p.adj.threshold
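For reference, the MAST summary above reduces to counting distinct significant programs, perturbations, and perturbation-program pairs at each K, and dividing by K where a per-program average is wanted. Below is a minimal sketch of that counting pattern on a toy data frame; the values and gene names are made up for illustration and are not pipeline outputs:

suppressPackageStartupMessages(library(dplyr))

## toy stand-in for sig.MAST.df: significant perturbation x program hits at two values of K
toy.sig <- data.frame(K            = c(10, 10, 10, 15, 15),
                      ProgramID    = c("K10_1", "K10_1", "K10_3", "K15_2", "K15_7"),
                      perturbation = c("GOSR2", "EDN1", "EDN1", "GOSR2", "NOS3"))

## distinct significant programs per K, plus the fraction of the K programs that are significant
toy.sig %>%
    distinct(K, ProgramID) %>%
    group_by(K) %>%
    summarize(nPrograms = n()) %>%
    mutate(fractionPrograms = nPrograms / K) %>%
    as.data.frame

The nPerturbations and nPerturbationProgramPairs columns in MAST.summary.df follow the same pattern with different distinct() keys.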
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(ggpubr))
suppressPackageStartupMessages(library(data.table))
suppressPackageStartupMessages(library(tidyr))
suppressPackageStartupMessages(library(readxl))
suppressPackageStartupMessages(library(ggrepel))
suppressPackageStartupMessages(library(optparse))
suppressPackageStartupMessages(library(gplots))
suppressPackageStartupMessages(library(cowplot))
suppressPackageStartupMessages(library(ggpubr)) ## annotate arranged figure

## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/2kG.library/acrossK/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/acrossK/", help="Output directory"),
  make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  make_option("--p.adj.threshold", type="numeric", default=0.1, help="Threshold for fdr and adjusted p-value"),
  make_option("--aggregated.data", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/acrossK/aggregated.outputs.findK.RData")
)
opt <- parse_args(OptionParser(option_list=option.list))


## ## for overdispersedGenes 220617
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/figures/top2000VariableGenes/2kG.library_overdispersedGenes/acrossK/"

## ## for all genes 210707 folder
## ## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/2kG.library/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/acrossK/"
## opt$sampleName <- "2kG.library"

## ## for all genes (in sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/figures/all_genes/Perturb_2kG_dup4/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/acrossK/"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/acrossK//aggregated.outputs.findK.RData"
## opt$sampleName <- "Perturb_2kG_dup4"

## ## ## for testing cNMF_ pipeline
## ## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/2kG.library/acrossK/"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_2min/acrossK/aggregated.outputs.findK.RData"


## ## for testing cNMF_pipeline with FT010_fresh_4min
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/FT010_fresh_4min/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_4min/acrossK/"
## opt$sampleName <- "FT010_fresh_4min"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/all_genes/FT010_fresh_4min/acrossK/aggregated.outputs.findK.RData"


## ## for testing findK_plots for scRNAseq_2kG_11AMDox_1
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/figures/all_genes/scRNAseq_2kG_11AMDox_1/acrossK/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/acrossK/"
## opt$sampleName <- "scRNAseq_2kG_11AMDox_1"
## opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/acrossK/aggregated.outputs.findK.RData"

# ## for testing findK_plots for control only cells
# opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/figures/all_genes/2kG.library.ctrl.only/acrossK/"
# opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/acrossK/"
# opt$sampleName <- "2kG.library.ctrl.only"
# opt$aggregated.data <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/acrossK/aggregated.outputs.findK.RData"


## ## K562 gwps 2k overdispersed genes
## opt$figdir <- ""

## Directories and Constants
SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
threshold <- opt$p.adj.threshold
DATADIR=opt$datadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
FIGDIR=opt$figdir
check.dir <- c(OUTDIR, FIGDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))


## load("/Volumes/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2105_findK/analysis/no_IL1B/aggregated.outputs.findK.RData")
## load("/Volumes/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/differential_expression/210526_SeuratDE/outputs/no_IL1B/de.markers.RData")
mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold"))
mytheme <- theme_classic() + theme(axis.text = element_text(size = 5),
                                   axis.title = element_text(size = 6),
                                   plot.title = element_text(hjust = 0.5, face = "bold", size=7),
                                   axis.line = element_line(color = "black", size = 0.25),
                                   axis.ticks = element_line(color = "black", size = 0.25))
palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)
palette.log2FC = colorRampPalette(c("blue", "white", "red"))(n = 100)
palette.pos = colorRampPalette(c("white", "red"))(n = 100)
adjusted.p.thr <- 0.05
fdr.thr <- 0.05

## Load data
load(opt$aggregated.data) ## all.fdr.df, all.test.df, enhancer.fisher.df, count.by.GWAS, fgsea.results.df, promoter.fisher.df
## ref.table <- read_xlsx(opt$reference.table, sheet="2000_gene_library_annotated") 


## ## Update the files for 2kG library
## ## Intersecting EdgeR and Topic Model Results
## threshold <- opt$p.adj.threshold
## # % of perturbations with a large number of DE genes in the per-gene analysis that ALSO have significantly DE topics  (p-value 0.005, logFC 1.2)
## strict.DE.all <- read_excel(path="/Users/helenkang/Documents/EngreitzLab/Pertube-seq_Analysis/200_gene_lib_EdgeR_DE.xlsx", sheet="EdgeR_DE_p.001_lfc1.2")
## lenient.DE.all <- read_excel(path="/Users/helenkang/Documents/EngreitzLab/Pertube-seq_Analysis/200_gene_lib_EdgeR_DE.xlsx", sheet="EdgeR_DE_p.01_lfc1.15")

## strict.DE <- strict.DE.all %>% select(Target...13, `NoI_#DE_genes`, `PlusI_#DE_genes`) %>% `colnames<-`(c("Gene", "NoI.num.genes", "PlusI.num.genes")) %>% mutate(cutoff.type = "strict", cutoff.detail = "p < 0.001, logFC > 1.2")
## lenient.DE <- lenient.DE.all %>% select(Target...13, NoI_DE...14, PlusI_DE...15) %>% `colnames<-`(c("Gene", "NoI.num.genes", "PlusI.num.genes")) %>% mutate(cutoff.type = "lenient", cutoff.detail = "p < 0.01, logFC > 1.15")
## edgeR.DE <- rbind(strict.DE, lenient.DE)

## perturbation.test.stat.df <- all.test.df %>% subset(adjusted.p.value < threshold) %>% select(K, Gene, test.type, Topic) %>% group_by(K, Gene, test.type) %>% summarize(topic.count = n())
## edgeR.test.df <- merge(edgeR.DE, perturbation.test.stat.df, by="Gene")


## MSigDB Pathway Enrichment 
## Number of GO enrichment per model


## Notes 210518
## also output pdf
## plot top enriched MSigDB pathways for each topic and K
## also consider eps file
## make sure that p.adjust is calculated over all topics for a model

##########################################################################################
## GSEA plots
##################################################
## labels and parameters for the for loops
ranking.types <- c("zscore", "raw")
GSEA.types <- c("GOEnrichment", "ByWeightGSEA", "GSEA")
GSEA.type.labels <- c("GO Term Enrichment\n(Top 300 Program Genes by Hypergeometric Test)", "Gene Set Enrichment\n(All Genes)", "Gene Set Enrichment\n(Top 300 Program Genes by Hypergeometric Test)")
##################################################
## function to output plot in the same format
plotGSEA <- function(toplot, nPathwayMetric, nPathwayMetricLabel) {
    p <- toplot %>% ggplot(aes(x=K, y=get(nPathwayMetric))) + geom_line(size=0.5) + geom_point(size=0.5) + mytheme +
        xlab("K") + ylab(nPathwayMetricLabel) #+ #ggtitle(paste0(GSEA.type.label)) +
        ## scale_x_continuous("K", labels = as.character(K), breaks = K)
    print(p)
    return(p)
}
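## (illustrative usage, not part of the original script) the loop below calls this as, e.g.:
## plotGSEA(toplot, "nPathways", "# Total Pathways")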

##################################################
for (GSEA.type.i in 1:length(GSEA.types)) {
    GSEA.type <- GSEA.types[GSEA.type.i]
    GSEA.type.label <- GSEA.type.labels[GSEA.type.i]

##################################################
    ## process the GSEA data here
    tmp.df <- get(paste0("clusterProfiler.", GSEA.type, ".df")) %>%
        subset(p.adjust < fdr.thr) %>%
        group_by(K, type) %>%
        mutate(nPathways = n()) %>%
        select(K, type, nPathways, ID, Description) %>%
        unique %>%
        mutate(nUniquePathways = n(),
               normalizedNPathways = nPathways / K,
               normalizedNUniquePathways = nUniquePathways / K) %>%
        as.data.frame
    summary.df <- tmp.df %>%
        select(-ID, -Description) %>%
        unique %>%
        `rownames<-`(paste0(.$type, "_K", .$K))
    K <- summary.df %>% pull(K) %>% unique() # K for x tick labels

##################################################
    for (ranking.type in ranking.types) {
        toplot <- summary.df %>%
            subset(type == ranking.type) ## subset to selected program gene ranking type (zscore or raw)

        nPathwayMetrics <- toplot %>% select(-K, -type) %>% colnames
        nPathwayMetricLabels <- c("# Total Pathways", "# Unique Pathways", "Average Pathways\nper Program", "Average Unique Pathways\nper Program")

        plotFilename <- paste0(FIGDIR, GSEA.type, "_", ranking.type)
        pdf(paste0(plotFilename, ".pdf"), width=3, height=3)
        plot.list <- list() ## create a list to store all metrics
        for (i in 1:length(nPathwayMetrics)) { ## make a line plot 
            nPathwayMetric <- nPathwayMetrics[i]
            nPathwayMetricLabel <- nPathwayMetricLabels[i]
            plot.list[[nPathwayMetric]] <- plotGSEA(toplot, nPathwayMetric, nPathwayMetricLabel)
        }
        p <- plot_grid(plotlist = plot.list, nrow=length(plot.list), align="v", axis="lr") 
        p <- annotate_figure(p, top = text_grob(paste0(GSEA.type.label, "\n", ranking.type), size=8))
        print(p)
        eval(parse(text = paste0("p.", GSEA.type, ".", ranking.type, " <- p")))
        dev.off()
    }
}

## combine all plots in one panel
p.all.GSEA <- plot_grid(p.GOEnrichment.zscore, p.GOEnrichment.raw,
               p.GSEA.zscore, p.GSEA.raw,
               p.ByWeightGSEA.zscore, p.ByWeightGSEA.raw,
               nrow = 3, ncol = 2, align="hv", axis="tblr")
plotFilename <- paste0(FIGDIR, "/All_GSEA")
pdf(paste0(plotFilename, ".pdf"), width=6, height=10)
print(p.all.GSEA)
ggsave(paste0(plotFilename, ".eps"))
dev.off()

##################################################
## End of GSEA plots
##################################################


## Notes:
## why is there a sharp change from K=19 to K=21?


## Fraction of topics that has at least one significant (metric) versus K
num.program.genes <- 300
enrichment.thr <- 1
## plot number of TF enriched per K for {enhancers, promoters}
for (ep.type in c("promoter", "enhancer")) {
    ep.type.label <- ifelse(ep.type == "promoter", "Promoter", "Enhancer")
    pdf(file=paste0(FIGDIR, "/TF.motif.enrichment.", ep.type, ".fdr.thr", as.character(fdr.thr), ".pdf"), width=3, height=3)
    ## promoters
    toplot <- get(paste0("all.", ep.type, ".ttest.df")) %>%
        mutate(significant = two.sided.p.adjust < fdr.thr & enrichment > enrichment.thr) %>%
        group_by(K) %>%
        summarise(total=significant %>% as.numeric %>% sum) %>%
        mutate(average.per.topic = total / K) %>%
        as.data.frame
    ## K <- toplot %>% pull(K) %>% unique() # K for x tick labels
    title <- paste0("Transcription Factors Enriched in \n", ep.type.label, " of Program Genes ")
    ## total number of significant TF plots
    p1 <- toplot %>% ggplot(aes(x=K, y=total)) + geom_line(size=0.5) + geom_point(size=0.5) +
    ## p1 <- toplot %>% ggplot(aes(x=K, y=total)) + geom_col(color="gray35") +
        xlab("K") + ylab(paste0("# Transcription Factors\n (FDR < ", fdr.thr, ")")) + mytheme +
        ## scale_x_continuous("K", labels = as.character(K), breaks = K) +
        ggtitle(title)
    print(p1)
    p2 <- toplot %>% ggplot(aes(x=K, y=average.per.topic)) + geom_line(size=0.5) + geom_point(size=0.5) +
    ## p2 <- toplot %>% ggplot(aes(x=K, y=average.per.topic)) + geom_col(width=0.5, color="gray35") +
        xlab("K") + ylab(paste0("Average # Transcription Factors per Program\n(FDR < ", fdr.thr, ")")) + mytheme +
        ## scale_x_continuous("K", labels = as.character(K), breaks = K) +
        ggtitle(paste0("Transcription Factors Enriched in \n", ep.type.label, " of the Top ", num.program.genes, " Genes of Each Program"))
    print(p2)

    ## plot number of unique TF enriched per K for promoters
    toplot.unique <- get(paste0("all.", ep.type, ".ttest.df")) %>%
        mutate(significant = two.sided.p.adjust < fdr.thr & enrichment > enrichment.thr) %>%
        select(K, motif, significant) %>%
        unique() %>%
        group_by(K) %>%
        summarise(total=significant %>% as.numeric %>% sum) %>%
        mutate(average.per.topic = total / K) %>%
        as.data.frame
    K <- toplot.unique %>% pull(K) %>% unique() # K for x tick labels
    ## total number of significant TF plots
    p3 <- toplot.unique %>% ggplot(aes(x=K, y=total)) + geom_line(size=0.5) + geom_point(size=0.5) +
        xlab("K") + ylab(paste0("Number of Transcription Factors with adjusted p-value < ", fdr.thr)) + mytheme #+
        ## scale_x_continuous("K", labels = as.character(K), breaks = K)
    print(p3)
    p4 <- toplot.unique %>% ggplot(aes(x=K, y=average.per.topic)) + geom_line(size=0.5) + geom_point(size=0.5) +
        xlab("K") + ylab(paste0("Average Number of Unique Transcription Factors per Program \n with adjusted p-value < ", fdr.thr)) + mytheme #+
        ## scale_x_continuous("K", labels = as.character(K), breaks = K) +
        ## ggtitle(paste0("Unique Transcription Factors Enriched in \nPromoters of the Top 100 Genes of Each Program")) ## not attached to p4; the chain above ends at mytheme
    print(p4)

    p <- plot_grid(p1 + ggtitle("") + ylab("# TFs"), p2 + ggtitle("") + ylab("# TF per\nProgram"),
                   p3 + ggtitle("") + ylab("# Unique TFs"), p4 + ggtitle("") + ylab("# Unique TF\nper Program"), nrow=4, align = "hv", axis="tblr")
    ## title = ep.type.label
    eval(parse(text = paste0("p.", ep.type, " <- annotate_figure(p, top = text_grob(label=title, face='bold', size=8))")))
    print(get(paste0("p.", ep.type)))    
    dev.off()
}

## put together all motif enrichment plot in one panel
p.all.TFMotifEnrichment <- plot_grid(p.promoter, p.enhancer, nrow=1)
plotFilename <- paste0(FIGDIR, "/All_TFMotifEnrichment")
pdf(paste0(plotFilename, ".pdf"), width=4, height=4)
print(p.all.TFMotifEnrichment)
ggsave(paste0(plotFilename, ".eps"))
dev.off()


## ## cluster theta.zscore across topics
## ## old 211115
## theta.zscore.df.wide <- theta.zscore.df %>% mutate(K_Factor = paste0("K",K,"_",Factor), Gene = rownames(.)) %>% select(Gene, weight, K_Factor) %>% spread(key = "K_Factor", value = "weight")
## write.table(theta.zscore.df.wide, paste0(OUTDIR, "topic.zscore.Pearson.corr.txt"), row.names=F, quote=F, sep="\t")
## theta.zscore.df.wide.mtx <- theta.zscore.df.wide %>% `rownames<-`(.$Gene) %>% select(-Gene) %>% as.matrix()
## d <- cor(theta.zscore.df.wide.mtx, method="pearson")
## m <- as.matrix(d)

d <- cor(theta.zscore.df, method="pearson")
m <- as.matrix(d)

## Function for plotting heatmap  # new version (adjusted font size)
plotHeatmap <- function(mtx, labCol, title, margins=c(12,6), ...) { #original
  heatmap.2(
    mtx %>% t(), 
    Rowv=T, 
    Colv=T,
    trace='none',
    key=T,
    col=palette,
    labCol=labCol,
    margins=margins, 
    cex.main=0.8, 
    cexCol=0.01 * ncol(mtx), cexRow=0.01 * ncol(mtx), #4.8/sqrt(nrow(mtx))
    ## cexCol=1/(ncol(mtx)^(1/3)), cexRow=1/(ncol(mtx)^(1/3)), #4.8/sqrt(nrow(mtx))
    main=title,
    ...
  )
}


pdf(file=paste0(FIGDIR,"/cluster.topic.zscore.by.Pearson.corr.pdf"), width=30, height=30)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic zscore clustering by Pearson Correlation"))
dev.off()
png(file=paste0(FIGDIR, "/cluster.topic.zscore.by.Pearson.corr.png"), width=3000, height=3000)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic zscore clustering by Pearson Correlation"))
dev.off()

## higher values of K
index <- which((theta.zscore.df %>% colnames() %>% strsplit(split="_") %>% sapply("[[",1) %>% gsub("K","",.) %>% as.numeric) >= 30) ## compare K numerically rather than as strings
d <- cor(theta.zscore.df[, index], method="pearson") ## subset programs (columns) only; rows are genes
m <- as.matrix(d)
pdf(file=paste0(FIGDIR,"/cluster.topic.zscore.by.Pearson.corr.K30_higher.pdf"), width=30, height=30)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic zscore clustering by Pearson Correlation"))
dev.off()
png(file=paste0(FIGDIR, "/cluster.topic.zscore.by.Pearson.corr.K30_higher.png"), width=3000, height=3000)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic zscore clustering by Pearson Correlation"))
dev.off()

## set correlation < 0.5 to zero to expand the color range
m[m<0.5] <- 0
pdf(file=paste0(FIGDIR,"/cluster.topic.zscore.by.Pearson.corr.K30_higher.threshold_cor_0.5.pdf"), width=30, height=30)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic zscore clustering by Pearson Correlation"))
dev.off()
png(file=paste0(FIGDIR, "/cluster.topic.zscore.by.Pearson.corr.K30_higher.threshold_cor_0.5.png"), width=3000, height=3000)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic zscore clustering by Pearson Correlation"))
dev.off()



## topic defined by raw weights
d <- cor(theta.raw.df, method="pearson")
m <- as.matrix(d)

pdf(file=paste0(FIGDIR,"/cluster.topic.raw.by.Pearson.corr.pdf"), width=30, height=30)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic raw weight clustering by Pearson Correlation"))
dev.off()
png(file=paste0(FIGDIR, "/cluster.topic.raw.by.Pearson.corr.png"), width=3000, height=3000)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic raw weight clustering by Pearson Correlation"))
dev.off()


## higher values of K
index <- which((theta.raw.df %>% colnames() %>% strsplit(split="_") %>% sapply("[[",1) %>% gsub("K","",.) %>% as.numeric) >= 30) ## compare K numerically rather than as strings
d <- cor(theta.raw.df[, index], method="pearson") ## subset programs (columns) only; rows are genes
m <- as.matrix(d)
pdf(file=paste0(FIGDIR,"/cluster.topic.raw.by.Pearson.corr.K30_higher.pdf"), width=30, height=30)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic raw weight clustering by Pearson Correlation"))
dev.off()
png(file=paste0(FIGDIR, "/cluster.topic.raw.by.Pearson.corr.K30_higher.png"), width=3000, height=3000)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic raw weight clustering by Pearson Correlation"))
dev.off()

## set correlation < 0.5 to zero to expand the color range
m[m<0.5] <- 0
pdf(file=paste0(FIGDIR,"/cluster.topic.raw.by.Pearson.corr.K30_higher.threshold_cor_0.5.pdf"), width=30, height=30)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic raw weight clustering by Pearson Correlation"))
dev.off()
png(file=paste0(FIGDIR, "/cluster.topic.raw.by.Pearson.corr.K30_higher.threshold_cor_0.5.png"), width=3000, height=3000)
plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, topic raw weight clustering by Pearson Correlation"))
dev.off()


##########################################################################################
## Variance Explained Plots
##########################################################################################
toplot <- varianceExplainedByModel.df
p <- toplot %>% ggplot(aes(x=K, y=Total)) + geom_point(size=0.5) + geom_line(size=0.5) + mytheme +
    xlab("K") + ylab("Fraction of Variance\nExplained by the Model")
filename <- paste0(FIGDIR, "/variance.explained.by.model")
pdf(paste0(filename, ".pdf"), width=3, height=1.5)
print(p)
dev.off()
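The clustering blocks above compare programs across choices of K by correlating the columns of a gene-by-program weight matrix (theta.zscore.df or theta.raw.df) and passing the correlation matrix to gplots::heatmap.2. Below is a self-contained sketch of the same idea on random data; the matrix size, program names, and reuse of the blue-white-red palette are illustrative assumptions, not values from the pipeline:

suppressPackageStartupMessages(library(gplots))

set.seed(1)
## toy gene x program matrix standing in for theta.zscore.df
toy.theta <- matrix(rnorm(200 * 6), nrow = 200,
                    dimnames = list(paste0("gene", 1:200),
                                    c("K3_1", "K3_2", "K3_3", "K5_1", "K5_2", "K5_3")))

## pairwise Pearson correlation between programs, then a clustered heatmap
toy.m <- cor(toy.theta, method = "pearson")
heatmap.2(t(toy.m), Rowv = TRUE, Colv = TRUE, trace = "none", key = TRUE,
          col = colorRampPalette(c("#38b4f7", "white", "red"))(100),
          labCol = rownames(toy.m), margins = c(6, 6),
          main = "toy program x program Pearson correlation")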
packages <- c("optparse","dplyr", "ggplot2", "reshape2", "ggrepel", "conflicted", "Seurat", "SeuratObject")
## library(Seurat)
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")


## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")

option.list <- list(
  make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/figures/all_genes/", help="Figure directory"),
  make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/", help="Output directory"),
  make_option("--inputSeuratObject", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2105_FT005_Analysis/outputs/FT005_gex/withUMAP.SeuratObject.RDS", help="Path to the Seurat Object"),
  # make_option("--olddatadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"),
  make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/", help="Input 10x data directory"),
  # make_option("--topic.model.result.dir", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210625_snakemake_output/top3000VariableGenes_acrossK/2kG.library/", help="Topic model results directory"),
  make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
  # make_option("--sep", type="logical", default=F, help="Whether to separate replicates or samples"),
  # make_option("--K.list", type="character", default="2,3,4,5,6,7,8,9,10,11,12,13,14,15,17,19,21,23,25", help="K values available for analysis"),
  make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
  make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
  make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guides per perturbation (greater than the input number)"),
  # make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
  make_option("--density.thr", type="character", default="0.2", help="consensus cluster threshold, 2 for no filtering"),
  # make_option("--raw.mtx.dir",type="character",default="stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.filtered.txt", help="input matrix to cNMF pipeline"),
  # make_option("--raw.mtx.RDS.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.RDS", help="input matrix to cNMF pipeline"), # the first lane: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.expandedMultiTargetGuide.RDS"
  # make_option("--subsample.type", type="character", default="", help="Type of cells to keep. Currently only support ctrl"),
  # make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/barcodes.tsv", help="barcodes.tsv for all cells"),
  # make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),

  ## fisher motif enrichment
  ## make_option("--outputTable", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputTableBinary", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.binary.k_14.df_0_2.txt", help="Output directory"),
  ## make_option("--outputEnrichment", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.fisher.enrichment.k_14.df_0_2.txt", help="Output directory"),
  # make_option("--motif.promoter.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv", help="All promoter's motif matches"),
  # make_option("--motif.enhancer.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv", help="All enhancer's motif matches specific to {no,plus}_IL1B"),
  # make_option("--enhancer.fimo.threshold", type="character", default="1.0E-4", help="Enhancer fimo motif match threshold"),

  #summary plot parameters
  make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
  make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
  make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute")

)
opt <- parse_args(OptionParser(option_list=option.list))
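## (illustrative, not part of the original script) the defaults above are typically overridden on the
## command line, e.g.: Rscript <path to this script> --sampleName 2kG.library --K.val 60 --density.thr 0.2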

## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.val <- 60
## opt$inputSeuratObject <- paste0(opt$datadir,"/", SAMPLE, "_SeuratObjectUMAP.rds")


mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold"))

SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
# STATIC.SAMPLE=c("Telo_no_IL1B_T200_1", "Telo_no_IL1B_T200_2", "Telo_plus_IL1B_T200_1", "Telo_plus_IL1B_T200_2", "no_IL1B", "plus_IL1B",  "pooled")
# DATADIR=opt$olddatadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
## TMDIR=opt$topic.model.result.dir
## SEP=opt$sep
# K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIR=opt$figdir
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
FGSEADIR=paste0(OUTDIRSAMPLE,"/fgsea/")
FGSEAFIG=paste0(FIGDIRSAMPLE,"/fgsea/")

## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr

## ## directories for factor motif enrichment
## FILENAME=opt$filename


## ## modify motif.enhancer.background input directory ##HERE: perhaps do a for loop for all the desired thresholds (use strsplit on enhancer.fimo.threshold)
## opt$motif.enhancer.background <- paste0(opt$motif.enhancer.background, opt$enhancer.fimo.threshold, "/fimo.formatted.tsv")


# create dir if not already
check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE,"/"), paste0(FIGDIR,SAMPLE,"/K",k,"/"), paste0(OUTDIR,SAMPLE,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE, FGSEADIR, FGSEAFIG)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)
## selected.gene <- c("EDN1", "NOS3", "TP53", "GOSR2", "CDKN1A")
# ABC genes
# gene.set <- c("INPP5B", "SF3A3", "SERPINH1", "NR2C1", "FGD6", "VEZT", "SMAD3", "AAGAB", "GOSR2", "ATP5G1", "ANGPTL4", "SRBD1", "PRKCE", "DAGLB") # ABC_0.015_CAD_pp.1_genes #200 gene library

# # cell cycle genes
# ## need to update these for 2kG library
# gene.list.three.groups <- read.delim(paste0("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/ptbd.genes_three.groups.txt"), header=T, stringsAsFactors=F)
# enhancer.set <- gene.list.three.groups$Gene[grep("E_at_", gene.list.three.groups$Gene)]
# CAD.focus.gene.set <- gene.list.three.groups %>% subset(Group=="CAD_focus") %>% pull(Gene) %>% append(enhancer.set)
# EC.pos.ctrl.gene.set <- gene.list.three.groups %>% subset(Group=="EC_pos._ctrls") %>% pull(Gene)

# cell.count.thr <- opt$cell.count.thr # greater than this number, filter to keep the guides with greater than this number of cells
# guide.count.thr <- opt$guide.count.thr # greater than this number, filter to keep the perturbations with greater than this number of guides

# guide.design = read.delim(file=paste0(DATADIR, "/200607_ECPerturbSeqMiniPool.design.txt"), header=T, stringsAsFactors = F)


# ## add GO pathway log2FC
# GO <- read.delim(file=paste0("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/GO.Pathway.table.brief.txt"), header=T, check.names=FALSE)
# GO.list <- read.delim(file="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/GO.Pathway.list.brief.txt", header=T, check.names=F)
# colnames(GO)[1] <- "Gene"
# colnames(GO.list)[1] <- "Gene"
# ## load all sample, K, topic's top 100 genes (by TopFeatures() KL-score measure)
# ## allGeneKtopic100 <- read.delim(paste0(TMDIR, "no.plus.pooled.top100.topicStats.txt"), header=T)
# # load non-expressed control gene list
# non.expressed.genes <- read.delim(file="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/non.expressed.ctrl.genes.txt", header=F, stringsAsFactors=F) %>% unlist %>% as.character() %>% sort()

# # perturbation type list
# gene.set.type.df <- data.frame(Gene=guide.design %>% pull(guideSet) %>% unique(),
#                                type=rep("other", guide.design %>% pull(guideSet) %>% unique() %>% length())) 
# gene.set.type.df$Gene <- gene.set.type.df$Gene %>% as.character()
# gene.set.type.df$type <- gene.set.type.df$type %>% as.character()
# gene.set.type.df$type[which(gene.set.type.df$Gene %in% non.expressed.genes)] <- "non-expressed"
# gene.set.type.df$type[which(gene.set.type.df$Gene %in% CAD.focus.gene.set)] <- "CAD focus"
# gene.set.type.df$type[grepl("^safe|^negative", gene.set.type.df$Gene)] <- "negative-control"
# gene.set.type.df$Gene[which(gene.set.type.df$Gene == "negative_control")] <- "negative-control"
# gene.set.type.df$Gene[which(gene.set.type.df$Gene == "safe_targeting")] <- "safe-targeting"
# # gene.set.type.df$type[which(gene.set.type.df$Gene %in% gene.set)] <- "ABC"

# gene.set.type.df.200 <- gene.set.type.df

# # reference table
# ref.table <- read_xlsx(opt$reference.table, sheet="2000_gene_library_annotated") 
# gene.set.type.df <- ref.table %>% select(Symbol, `Class(es)`) %>% `colnames<-`(c("Gene", "type"))
# gene.set.type.df$type[grepl("EC_ctrls", gene.set.type.df$type)] <- "EC_ctrls"
# gene.set.type.df$type[grepl("NonExpressed", gene.set.type.df$type)] <- "non-expressed"
# gene.set.type.df$type[grepl("abc.015", gene.set.type.df$type)] <- "ABC"
# gene.set.type.df <- rbind(gene.set.type.df, c("negative-control", "negative-control"), c("safe-targeting", "safe-targeting"))
# non.expressed.genes <- gene.set.type.df %>% subset(type == "non-expressed") %>% pull(Gene)
# # ABC genes
# gene.set <- gene.set.type.df %>% subset(grepl("ABC", type)) %>% pull(Gene)

# ## add GWAS classification
# modified.ref.table <- ref.table %>% mutate(GWAS.classification="")
# CAD.index <- which(grepl("CAD_Loci",ref.table$`Class(es)`))
# EC_ctrls.index <- which(grepl("^EC_ctrls",ref.table$`Class(es)`))
# ABC_linked.index <- which(grepl("MIG_etc",ref.table$`Class(es)`))
# IBD.index <- which(grepl("Non-CAD_loci_IBD",ref.table$`Class(es)`))
# non.expressed.index <- which(grepl("NonExpressed",ref.table$`Class(es)`))
# poorly.annotated.9p21.index <- which(grepl("9p21",ref.table$`Class(es)`))
#                                         # length(CAD.index) + length(EC_ctrls.index) + length(ABC_linked.index) + length(IBD.index) + length(non.expressed.index) + length(poorly.annotated.9p21.index)
# modified.ref.table$GWAS.classification[ABC_linked.index] <- "ABC"
# modified.ref.table$GWAS.classification[IBD.index] <- "IBD"
# modified.ref.table$GWAS.classification[non.expressed.index] <- "NonExpressed"
# modified.ref.table$GWAS.classification[poorly.annotated.9p21.index] <- "9p21.poorly.annotated"
# modified.ref.table$GWAS.classification[EC_ctrls.index] <- "EC_ctrls"
# modified.ref.table$GWAS.classification[CAD.index] <- "CAD"

# modified.ref.table <- modified.ref.table %>% group_by(GWAS.classification) %>% mutate(gene.count.per.GWAS.category = n())
# ref.table <- modified.ref.table

# ## add TSS distance to SNP
# modified.ref.table <- ref.table %>% mutate(TSS.dist.to.SNP = abs(`TSS v. SNP loc`))
# not.in.SNP.index <- which(is.na(modified.ref.table$`TSS v. SNP loc`))
# modified.ref.table$TSS.dist.to.SNP[not.in.SNP.index] <- NA
# ref.table <- modified.ref.table %>% ungroup()

# ## add closest gene to top GWAS loci ranking
# modified.ref.table <- ref.table %>%
#     group_by(`Top SNP ID`) %>% # per SNP metrics
#     arrange(abs(`TSS v. SNP loc`)) %>%
#     mutate(TSS.v.SNP.ranking = 1:n(),
#            total.gene.in.this.loci = n()) %>% ungroup() %>%
#     group_by(`Top SNP ID`, GWAS.classification) %>% # per SNP per GWAS class (CAD, IBD, NonExpressed, ABC, 9p21.poorly.annotated) 
#     arrange(abs(`TSS v. SNP loc`)) %>%
#     mutate(TSS.v.SNP.ranking.in.GWAS.category = 1:n(),
#            total.gene.in.this.loci.in.GWAS.category = n()) %>% ungroup()
# not.in.SNP.index <- which(is.na(modified.ref.table$`TSS v. SNP loc`))
# modified.ref.table$TSS.v.SNP.ranking.in.GWAS.category[not.in.SNP.index] <- NA
# modified.ref.table$TSS.v.SNP.ranking[not.in.SNP.index] <- NA
# ref.table <- modified.ref.table

# ## add gene count per distance ranking per GWAS loci
# modified.ref.table <- ref.table
# modified.ref.table <- modified.ref.table %>%
#     group_by(TSS.v.SNP.ranking) %>% # per ranking, not considering which GWAS category the gene is from
#     mutate(total.TSS.v.SNP.ranking.count = n()) %>% ungroup() %>%
#     group_by(GWAS.classification, TSS.v.SNP.ranking.in.GWAS.category) %>% # per GWAS category and per ranking
#     mutate(total.TSS.v.SNP.ranking.count.per.GWAS.classification = n()) %>% ungroup()
# not.in.SNP.index <- which(is.na(modified.ref.table$TSS.v.SNP.ranking))
# modified.ref.table$total.TSS.v.SNP.ranking.count[not.in.SNP.index] <- NA
# modified.ref.table$total.TSS.v.SNP.ranking.count.per.GWAS.classification[not.in.SNP.index] <- NA
# ref.table <- modified.ref.table

# write.table(ref.table, file=paste0(opt$datadir, "/ref.table.txt"), row.names=F, quote=F, sep="\t")

# ## ref.table ranking count summary table
# ref.table.gene.to.SNP.dist.ranking.count.summary.allGWAS <- ref.table %>% select(TSS.v.SNP.ranking, total.TSS.v.SNP.ranking.count) %>% mutate(GWAS.classification="all") %>% unique()
# ref.table.gene.to.SNP.dist.ranking.count.summary.indGWAS <- ref.table %>% select(TSS.v.SNP.ranking.in.GWAS.category, total.TSS.v.SNP.ranking.count.per.GWAS.classification, GWAS.classification) %>% `colnames<-`(c("TSS.v.SNP.ranking", "total.TSS.v.SNP.ranking.count", "GWAS.classification")) %>% unique()
# ref.table.gene.to.SNP.dist.ranking.count.summary <- rbind(ref.table.gene.to.SNP.dist.ranking.count.summary.allGWAS, ref.table.gene.to.SNP.dist.ranking.count.summary.indGWAS)
# ref.table.summary.na.index <- which(is.na(ref.table.gene.to.SNP.dist.ranking.count.summary$TSS.v.SNP.ranking))
# ref.table.gene.to.SNP.dist.ranking.count.summary <- ref.table.gene.to.SNP.dist.ranking.count.summary[-ref.table.summary.na.index,]
# rm(ref.table.summary.na.index)


# # convert enhancer SNP rs number to enhancer target gene name # need 2kG library version
# enh.snp.to.gene <- read.delim(paste0(DATADIR, "/enhancer.SNP.to.gene.name.txt"), header=T, stringsAsFactors = F) %>% mutate(Enhancer_name=gsub("_","-", Enhancer_name))

# # gene corresponding pathway
# gene.def.pathways <- read_excel(paste0(DATADIR,"topic.gene.definition.pathways.xlsx"), sheet="Gene_Pathway")

# ## Gavin's new list
# gene.classes.ranked <- read.table(paste0(opt$datadir, "Gene_Classes_Ranked_for_CAD_n_EC.txt"), header=T, stringsAsFactors = F)
# summaries <- read.delim(paste0(opt$datadir, "Gene_Summaries_n_Classes.txt"), sep="\t", header=T, stringsAsFactors = F)
# gene.summaries <- read_xlsx(paste0(opt$datadir, "Gene_Summaries.xlsx"), sheet="uniprot_summaries")


# ## Perturbation name and 10X gene name conversion table
# ptb.10X.name.conversion <- read_xlsx(paste0(opt$datadir, "Perturbation 10X names.xlsx"))

# ## EdgeR log2fcs and p-values
# log2fc.edgeR <- read.table(paste0(opt$datadir, "/EdgeR/ALL_log2fcs_dup4_s4n3.99x.txt"), header=T, stringsAsFactors=F)
# p.value.edgeR <- read.table(paste0(opt$datadir, "/EdgeR/ALL_Pvalues_dup4_s4n3.99x.txt"), header=T, stringsAsFactors=F)

# print("loaded all prerequisite data")




######################################################################
## Process topic model results
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
} else {
    warning(paste0(cNMF.result.file, " does not exist"))
}
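## Note: the cNMF_results RData is expected to provide the objects used below
## (e.g. the gene x topic matrices theta / theta.zscore and the per-cell topic
## weights omega); the exact contents depend on the upstream cNMF aggregation step.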

## load ann.omega
file.name <- paste0(OUTDIRSAMPLE,"/cNMFAnalysis.",SUBSCRIPT,".RData")
print(file.name)
if(file.exists(file.name)) {
    print(paste0("loading ", file.name))
    load(file.name)
}
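## Note: the cNMFAnalysis RData is expected to supply ann.omega (per-cell topic
## weights annotated with perturbation identity), which the commented-out
## perturbation analyses below refer to as ann.omega / ann.omega.filtered.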



# ## load statistical test data
# toSave.features <- read.delim(paste0(OUTDIRSAMPLE, "/topic.KL.score_K", k, ".dt_", DENSITY.THRESHOLD, ".txt"), header=T, stringsAsFactors=F)
# all.test <- read.delim(file=paste0(OUTDIRSAMPLE, "/all.test.", SUBSCRIPT, ".txt"), header=T, stringsAsFactors=F) ##here
# realPvals.df <- read.delim(file=paste0(OUTDIRSAMPLE, "/all.expressed.genes.pval.fdr.",SUBSCRIPT,".txt"), header=T, stringsAsFactors=F)
# all.test.guide.w <- all.test %>% subset(test.type==opt$test.type)
# realPvals.df.guide.w <- realPvals.df %>% subset(test.type==opt$test.type)
# fdr.thr <- 0.1
# topFeatures.raw.weight <- theta.zscore %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene", variable.name="topic", value.name="scores") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:50)


## load Seurat Object with UMAP


# ## load full matrix
# file.name=paste0(opt$raw.mtx.RDS.dir, "_modified.multiTargetGuide.cell.names.RDS")
# fc.file.name=paste0(opt$raw.mtx.RDS.dir, "_FC_modified.multiTargetGuide.cell.names.RDS")
# log2fc.file.name=paste0(opt$raw.mtx.RDS.dir, "_log2FC_modified.multiTargetGuide.cell.names.RDS")
# if(file.exists(file.name) & file.exists(fc.file.name) & file.exists(log2fc.file.name)) {
#     print(paste0("Loading ", file.name))
#     X.full = readRDS(file.name)
#     print(paste0("Loading ", fc.file.name))
#     fc.X.full = readRDS(fc.file.name)
#     print(paste0("Loading ", log2fc.file.name))
#     log2fc.X.full = readRDS(log2fc.file.name)##here210813
# } else {
#     warning(paste0("full scRNA-seq matrix does not exist"))
# }


# ## modify X.full
# ## old 210819
# tokeep.index <- which(rownames(X.full) %in% ann.omega.filtered$long.CBC) ## need to expand X.full to match ann.omega.filtered due to guides that target two promoters
# fc.X.full <- fc.X.full[tokeep.index,] 
# X.full <- X.full[tokeep.index,] 
# colnames(X.full) <- colnames(fc.X.full) <- colnames(fc.X.full)  %>% strsplit(., split=":") %>% sapply("[[",1)

# ## ## adjust colnames, remove ENSG number
# ann.X.full.filtered <- X.full
# ## add back ENSG names?
# ## tmp <- colnames(ann.X.full.filtered) %>% strsplit(., split=":") %>% sapply("[[",1)
# ## tmpp <- data.frame(table(tmp)) %>% subset(Freq > 1)  # keep row names that have duplicated gene names but different ENSG names
# ## tmp.copy <- tmp
# ## tmp.copy[grepl(paste0(tmpp$tmp,collapse="|"),tmp)] <- colnames(ann.X.full.filtered)[grepl(paste0(tmpp$tmp,collapse="|"), colnames(ann.X.full.filtered))]
# ## colnames(ann.X.full.filtered) <- tmp.copy # the above section takes a while
# ## end of old 210819

# ## get ctrl log2 transformed expression
# tokeep.index <- which(grepl("control|targeting",rownames(X.full)))
# ## ctrl.X <- ann.X.full.filtered %>% subset(grepl("control|targeting",Gene))
# ctrl.X <- X.full[tokeep.index,]
# fc.ctrl.X <- fc.X.full[tokeep.index,]
# ## get ctrl topic weights
# ctrl.ann.omega <- ann.omega.filtered %>% subset(grepl("control|targeting",Gene)) %>% `rownames<-`(.$long.CBC)
# X.gene.names <- rownames(X.full) %>% strsplit(., split=":") %>% sapply("[[",1) %>% gsub("_multiTarget|-TSS2","",.)

## load motif enrichment results
# file.name <- paste0(OUTDIRSAMPLE,"/cNMFAnalysis.factorMotifEnrichment.",SUBSCRIPT.SHORT,".RData")
# print(file.name)
# if(file.exists((file.name))) { 
#     load(file.name)
#     print(paste0("loading ", file.name))
# }
# motif.enrichment.variables <- c("all.enhancer.fisher.df", "all.promoter.fisher.df", 
#                                 "promoter.wide", "enhancer.wide", "promoter.wide.binary", "enhancer.wide.binary",
#                                 "enhancer.wide.10en6", "enhancer.wide.binary.10en6", "all.enhancer.fisher.df.10en6",
#                                 "promoter.wide.10en6", "promoter.wide.binary.10en6", "all.promoter.fisher.df.10en6",
#                                 "all.promoter.ttest.df", "all.promoter.ttest.df.10en6", "all.enhancer.ttest.df", "all.enhancer.ttest.df.10en6")
# motif.enrichment.variables.missing <- (!(motif.enrichment.variables %in% ls())) %>% as.numeric %>% sum 
# if ( motif.enrichment.variables.missing > 0 ) {
#     warning(paste0(motif.enrichment.variables[!(motif.enrichment.variables %in% ls())], " not available"))
# }

# ## load count.by.GWAS
# count.by.GWAS <- read.delim(file=paste0(OUTDIRSAMPLE,"/count.by.GWAS.classes_p.adj.",p.value.thr %>% as.character,"_",SUBSCRIPT,".txt"), header=T, stringsAsFactors = F)
# count.by.GWAS.withTopic <- read.delim(file=paste0(OUTDIRSAMPLE,"/count.by.GWAS.classes.withTopic_p.adj.",p.value.thr %>% as.character,"_",SUBSCRIPT,".txt"), header=T, stringsAsFactors=F)


# ## load UMAP data
# file.name <- paste0(OUTDIRSAMPLE, "gene.score.SeuratObject.RDS")
# if(file.exists(file.name)) {
#     s.gene.score <- readRDS(file.name)
# } else {
#     warning(paste0(file.name, " does not exist"))
# } 

## full UMAP
file.name <- opt$inputSeuratObject
# file.name <- paste0(opt$datadir,"/", SAMPLE, "_SeuratObjectUMAP.rds") ## todo: change to calcUMAP output
## subset.file.name <- paste0(opt$datadir,"/", SAMPLE, "_subset_SeuratObjectUMAP.rds")
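## raise future.globals.maxSize to ~1 GB so large objects (e.g. the Seurat object)
## can be exported to future-based parallel workers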
options(future.globals.maxSize=1000*1024^2)
s <- readRDS(file.name)
## s@meta.data <- s@meta.data[,-which(grepl("topic",colnames(s@meta.data)))]
s <- AddMetaData(s, metadata = omega, col.name = paste0("K",k,"_",colnames(omega)))
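## The per-cell topic weights are now available as metadata columns named
## paste0("K", k, "_", <topic>); e.g. (illustrative only, not part of the pipeline):
## FeaturePlot(s, features = paste0("K", k, "_", colnames(omega)[1]))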


## End of data loading



##########################################################################
## Plots


# ##########################################################################
# ## topic gene z-score list
# pdf(file=paste0(FIGDIRTOP,"top50GeneInTopics.zscore.pdf"), width=4, height=6)
# topFeatures.raw.weight <- theta.zscore %>% as.data.frame() %>% mutate(Gene=rownames(.)) %>% melt(id.vars="Gene", variable.name="topic", value.name="scores") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:50)
# for ( t in 1:dim(theta)[2] ) {
#     toPlot <- data.frame(Gene=topFeatures.raw.weight %>% subset(topic == t) %>% pull(Gene),
#                          Score=topFeatures.raw.weight %>% subset(topic == t) %>% pull(scores))
#     p <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*100) ) + geom_col() + theme_minimal()
#     p <- p + coord_flip() + xlab("Top 50 Genes") + ylab("z-score (Specificity)") + ggtitle(paste(SAMPLE, ", Topic ", t, sep="")) + mytheme
#     print(p)
# }
# dev.off()


## ##########################################################################
## ## top expressed genes per topic by KL specificity score list
## pdf(file=paste0(FIGDIRTOP, "topGeneInTopics.KL.pdf"), width=4, height=6)
## topFeatures <- toSave.features %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:50)
## for ( t in 1:dim(theta)[2] ) {
##     toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
##                          Score=topFeatures %>% subset(topic == t) %>% pull(scores))
##     p <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*100) ) + geom_col() + theme_minimal()
##     p <- p + coord_flip() + xlab("Top 50 Genes") + ylab("KL score (gene specific to this topic)") + ggtitle(paste(SAMPLE, ", K = ", k, ", Topic ", t, sep="")) + mytheme
##     print(p)
## }
## dev.off()


# ##########################################################################
# ## Topic's top gene list, ranked by raw weight
# pdf(file=paste0(FIGDIRTOP,"top50GeneInTopics.rawWeight.pdf"), width=4, height=6)
# topFeatures <- theta %>% as.data.frame() %>% mutate(genes=rownames(.)) %>% melt(id.vars="genes",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:50)
# for ( t in 1:dim(theta)[2] ) {
#     toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
#                          Score=topFeatures %>% subset(topic == t) %>% pull(scores))
#     p <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*100) ) + geom_col() + theme_minimal()
#     p <- p + coord_flip() + xlab("Top 50 Genes") + ylab("Raw Score (gene's weight in topic)") + ggtitle(paste(SAMPLE, ", Topic ", t, sep="")) + mytheme
#     print(p)
# }
# dev.off()


## ##########################################################################
## ## KL score list with annotation
## pdf(file=paste0(FIGDIRTOP,"topGeneInTopics.annotated.KL.pdf"), width=4.5, height=5)
## topFeatures <- toSave.features %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10)
## for ( t in 1:dim(theta)[2] ) {
##     toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
##                          Score=topFeatures %>% subset(topic == t) %>% pull(scores)) %>%
##         merge(., gene.def.pathways, by="Gene", all.x=T)
##     toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
##     p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*100, fill=Pathway) ) + geom_col(width=0.5) + theme_minimal()
##     p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("KL score (gene specific to this topic)") +
##         mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ", K = ", k, ", Topic ", t))
##     print(p4)
## }
## dev.off()


# ##########################################################################
# ## raw program TPM list with annotation
# pdf(file=paste0(FIGDIRTOP,"top10GeneInTopics.TPM.pdf"), width=4.5, height=5)
# topFeatures <- theta %>% as.data.frame() %>% mutate(genes=rownames(.)) %>% melt(id.vars="genes",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10)
# for ( t in 1:dim(theta)[2] ) {
#     toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
#                          Score=topFeatures %>% subset(topic == t) %>% pull(scores)) # %>%
#         ## merge(., gene.def.pathways, by="Gene", all.x=T)
#     ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
#     p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*1000000) ) + geom_col(width=0.5, fill="#38b4f7") + theme_minimal()
#     p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("Raw Weight (in TPM)") +
#         mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ", K = ", k, ", Topic ", t))
#     print(p4)
# }
# dev.off()



# ##########################################################################
# ## raw program zscore list (top 10)  (can potentially include annotation)
# pdf(file=paste0(FIGDIRTOP,"top10GeneInTopics.zscore.pdf"), width=4.5, height=5)
# topFeatures <- theta.zscore %>% as.data.frame() %>% mutate(genes=rownames(.)) %>% melt(id.vars="genes",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10)
# for ( t in 1:dim(theta)[2] ) {
#     toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
#                          Score=topFeatures %>% subset(topic == t) %>% pull(scores)) # %>%
#         ## merge(., gene.def.pathways, by="Gene", all.x=T)
#     ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
#     p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="#38b4f7") + theme_minimal()
#     p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("z-score") +
#         mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ", K = ", k, ", Topic ", t))
#     print(p4)
# }
# dev.off()



## ##########################################################################
## ## raw program zscore list without annotation
## pdf(file=paste0(FIGDIRTOP,"topGeneInTopics.shortList.zscore.pdf"), width=3.5, height=4)
## topFeatures <- theta.zscore %>% as.data.frame() %>% mutate(genes=rownames(.)) %>% melt(id.vars="genes",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10)
## for ( t in 1:dim(theta)[2] ) {
##     toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
##                          Score=topFeatures %>% subset(topic == t) %>% pull(scores))
##     p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(width=0.5, fill="#38b4f7") + theme_minimal()
##     p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("z-score") +
##         mytheme + theme(legend.position="bottom", legend.direction="vertical", text=element_text(size=16), plot.title=element_text(size=12)) + ggtitle(paste0(SAMPLE, ", K = ", k, ", Topic ", t))
##     print(p4)
## }
## dev.off()





# ##########################################################################
# ## Perturbation zscore list with annotation
# pdf(file=paste0(FIGDIRTOP,"Perturbation_zscore.annotated.pdf"), width=4.5, height=5)
# topFeatures <- ptb.zscore %>% as.data.frame() %>% mutate(genes=rownames(.)) %>% melt(id.vars="genes",value.name="scores", variable.name="topic") %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10) %>% mutate(topic = gsub("topic_","", topic))
# for ( t in 1:dim(theta)[2] ) {
#     toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
#                          Score=topFeatures %>% subset(topic == t) %>% pull(scores)) %>%
#         merge(., gene.def.pathways, by="Gene", all.x=T)
#     toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
#     p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score, fill=Pathway) ) + geom_col(width=0.5) + theme_minimal()
#     p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("Perturbation z-score") +
#         mytheme + theme(legend.position="bottom", legend.direction="vertical") + ggtitle(paste0(SAMPLE, ", K = ", k, ", Topic ", t))
#     print(p4)
# }
# dev.off()

# pdf(file=paste0(FIGDIRTOP, "Perturbation.zscore.sig.list.pdf"), width=6, height=6)
# ptb.zscore.long <- ptb.zscore %>% as.data.frame %>% mutate(Gene=rownames(.)) %>% melt(id.vars = "Gene", value.name = "perturbation.zscore", variable.name = "Topic")
# for ( topic in colnames(ptb.zscore) ) {
#     t <- gsub("topic_","",topic)
#     toPlot.all.test <- all.test %>% subset(test.type=="per.cell.wilcoxon" & Topic==topic)
#     toPlot.fdr <- realPvals.df %>% subset(test.type=="per.cell.wilcoxon" & Topic == topic) %>% select(Gene,fdr)##here210809
#                                         # assemble toPlot
#     toPlot <- ptb.zscore.long %>% subset(Topic == topic) %>% merge(.,toPlot.all.test,by=c("Gene","Topic"), all.x=T) %>%
#         merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#         merge(.,gene.set.type.df,by="Gene", all.x=T) %>% ##here210809
#         ## merge(.,gene.def.pathways, by="Gene", all.x=T) %>% 
#         merge(., ref.table %>% select("Symbol", "TSS.dist.to.SNP", "GWAS.classification"), by.x="Gene", by.y="Symbol", all.x=T) %>%
#         mutate(EC_ctrl_text = ifelse(.$GWAS.classification == "EC_ctrls", "(+)", "")) %>%
#         mutate(GWAS.class.text = ifelse(grepl("CAD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb"),
#                                  ifelse(grepl("IBD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb_IBD"), ""))) %>%
#         mutate(ann.Gene = paste0(Gene, GWAS.class.text, EC_ctrl_text))

#     toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(adjusted.p.value)), "", "*"))

#     toPlot.top <- toPlot %>% arrange(desc(perturbation.zscore)) %>% slice(1:25)
#     toPlot.bottom <- toPlot %>% arrange(perturbation.zscore) %>% slice(1:25)
#     toPlot.extreme <- rbind(toPlot.top, toPlot.bottom) %>%
#         mutate(color=ifelse(grepl("CAD", type), "red",
#                      ifelse(type=="non-expressed", "gray",
#                      ifelse(type=="EC_ctrls", "blue", "black"))))  %>%
#         mutate(color=ifelse(is.na(type), "black", color))
#     ## colors <- toPlot.extreme$color[order(toPlot.extreme %>% arrange(desc(perturbation.zscore)) %>% pull(color))]
#     toPlot.extreme <- toPlot.extreme %>% arrange(perturbation.zscore)
#     ## add gene distance to CAD
#     toPlot.extreme$ann.Gene <- factor(toPlot.extreme$ann.Gene, levels = toPlot.extreme$ann.Gene)
#     p <- toPlot.extreme %>%  #mutate(Gene = paste0("<span style = 'color: ", color, ";'>", Gene, "</span>")) %>%
#         ggplot(aes(x=ann.Gene, y=perturbation.zscore, fill=significant)) + geom_col() + theme_minimal() +
#         coord_flip() + xlab("Most Extreme Gene (Perturbation)") + ylab("Perturbation z-score") + ggtitle(paste(SAMPLE, " perturbations, ", topic)) +
#         scale_fill_manual(values=c("grey", "#38b4f7")) +
#         geom_text(aes(label = significant)) +
#         theme(legend.position = "none", axis.text.y = element_text(colour = toPlot.extreme$color))
#     print(p)
# }
# dev.off()




# pdf(file=paste0(FIGDIRTOP, "Perturbation.zscore.sig.shortList.pdf"), width=6, height=6)
# ptb.zscore.long <- ptb.zscore %>% as.data.frame %>% mutate(Gene=rownames(.)) %>% melt(id.vars = "Gene", value.name = "perturbation.zscore", variable.name = "Topic")
# for ( topic in colnames(ptb.zscore) ) {
#     t <- gsub("topic_","",topic)
#     toPlot.all.test <- all.test %>% subset(test.type=="per.cell.wilcoxon" & Topic==topic)
#     toPlot.fdr <- realPvals.df %>% subset(test.type=="per.cell.wilcoxon" & Topic == topic) %>% select(Gene,fdr)##here210809
#                                         # assemble toPlot
#     toPlot <- ptb.zscore.long %>% subset(Topic == topic) %>% merge(.,toPlot.all.test,by=c("Gene","Topic"), all.x=T) %>%
#         merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#         merge(.,gene.set.type.df,by="Gene", all.x=T) %>% ##here210809
#         ## merge(.,gene.def.pathways, by="Gene", all.x=T) %>% 
#         merge(., ref.table %>% select("Symbol", "TSS.dist.to.SNP", "GWAS.classification"), by.x="Gene", by.y="Symbol", all.x=T) %>%
#         mutate(EC_ctrl_text = ifelse(.$GWAS.classification == "EC_ctrls", "(+)", "")) %>%
#         mutate(GWAS.class.text = ifelse(grepl("CAD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb"),
#                                  ifelse(grepl("IBD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb_IBD"), ""))) %>%
#         mutate(ann.Gene = paste0(Gene, GWAS.class.text, EC_ctrl_text))

#     toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(adjusted.p.value)), "", "*"))

#     toPlot.top <- toPlot %>% arrange(desc(perturbation.zscore)) %>% slice(1:10)
#     toPlot.bottom <- toPlot %>% arrange(perturbation.zscore) %>% slice(1:10)
#     toPlot.extreme <- rbind(toPlot.top, toPlot.bottom) %>%
#         mutate(color=ifelse(grepl("CAD", type), "red",
#                      ifelse(type=="non-expressed", "gray",
#                      ifelse(type=="EC_ctrls", "blue", "black"))))  %>%
#         mutate(color=ifelse(is.na(type), "black", color))
#     ## colors <- toPlot.extreme$color[order(toPlot.extreme %>% arrange(desc(perturbation.zscore)) %>% pull(color))]
#     toPlot.extreme <- toPlot.extreme %>% arrange(perturbation.zscore)
#     ## add gene distance to CAD
#     toPlot.extreme$ann.Gene <- factor(toPlot.extreme$ann.Gene, levels = toPlot.extreme$ann.Gene)
#     p <- toPlot.extreme %>%  #mutate(Gene = paste0("<span style = 'color: ", color, ";'>", Gene, "</span>")) %>%
#         ggplot(aes(x=ann.Gene, y=perturbation.zscore, fill=significant)) + geom_col() + theme_minimal() +
#         coord_flip() + xlab("Most Extreme Gene (Perturbation)") + ylab("Perturbation z-score") + ggtitle(paste(SAMPLE, " perturbations, ", topic)) +
#         scale_fill_manual(values=c("grey", "#38b4f7")) +
#         geom_text(aes(label = significant)) +
#         theme(legend.position = "none", axis.text.y = element_text(colour = toPlot.extreme$color))
#     print(p)
# }
# dev.off()



# ## volcano plots
# volcano.plot <- function(toplot, ep.type, ranking.type, label.type="") {
#     if( label.type == "pos") {
#         label <- toplot %>% subset(-log10(p.adjust) > 1 & enrichment.log2fc > 0) %>% mutate(motif.toshow = gsub("HUMAN.H11MO.", "", motif))
#     } else {
#         label <- toplot %>% subset(-log10(p.adjust) > 1) %>% mutate(motif.toshow = gsub("HUMAN.H11MO.", "", motif))
#     }
#     t <- gsub("topic_", "", toplot$topic[1])
#     p <- toplot %>% ggplot(aes(x=enrichment.log2fc, y=-log10(p.adjust))) + geom_point(size=0.5) + mytheme +
#         ggtitle(paste0(SAMPLE[1], " Topic ", t, " Top 100 ", ranking.type," ", ifelse(ep.type=="promoter", "Promoter", "Enhancer"), " Motif Enrichment")) + xlab("Motif Enrichment (log2FC)") + ylab("-log10(adjusted p-value)") +
#         geom_text_repel(data=label, box.padding = 0.5,
#                         aes(label=motif.toshow), size=5,
#                         color="black") + theme(text=element_text(size=16), axis.title=element_text(size=16), axis.text=element_text(size=16), plot.title=element_text(size=14))
#     print(p)
#     p <- toplot %>% ggplot(aes(x=enrichment.log2fc, y=-log10(p.value))) + geom_point(size=0.5) + mytheme +
#         ggtitle(paste0(SAMPLE[1], " Topic ", t, " Top 100 ", ranking.type," ", ifelse(ep.type=="promoter", "Promoter", "Enhancer"), " Motif Enrichment")) + xlab("Motif Enrichment (log2FC)") + ylab("-log10(p-value)") +
#         geom_text_repel(data=label, box.padding = 0.5,
#                         aes(label=motif.toshow), size=5,
#                         color="black") + theme(text=element_text(size=16), axis.title=element_text(size=16), axis.text=element_text(size=16), plot.title=element_text(size=14))
#     return(p)
# }

# ## function for all volcano plots
# all.volcano.plots <- function(all.fisher.df, ep.type, ranking.type, label.type="") {
#     for ( t in 1:k ){
#         toplot <- all.fisher.df %>% subset(topic==paste0("topic_",t))
#         volcano.plot(toplot, ep.type, ranking.type, label.type) %>% print()
#     }
# }

# if(opt$subsample.type!="ctrl") {

# ##########################################################################
# ## q-q plot
# pdf(file=paste0(FIGDIRTOP,"p-value.qqplot.pdf"), width=8, height=8)
# for (test.name in unique(all.test$test.type)) {
#     toPlot <- all.test %>% subset(test.type==test.name)
#     toPlot$p.value <- -1*log10(toPlot$p.value)
#   min.value <- min(toPlot %>% subset(gene.type=="expressed") %>% select(p.value) %>% unlist() %>% as.numeric(),
#                    toPlot %>% subset(gene.type=="non-expressed") %>% select(p.value) %>% unlist() %>% as.numeric())
#   max.value <- 10
#   toPlot$p.value[toPlot$p.value < 10^-10] <- 10^-10
#   exp.gene.count <- toPlot %>% subset(gene.type=="expressed") %>% select(Gene) %>% unique() %>% unlist() %>% length()
#   nonexp.gene.count <- toPlot %>% subset(gene.type=="non-expressed") %>% select(Gene) %>% unique() %>% unlist() %>% length()


#   match.df <- toPlot %>% subset(adjusted.p.value < 0.1 & gene.type=="expressed")
#   my.qqplot(y = toPlot %>% subset(gene.type=="expressed") %>% select(p.value) %>% unlist() %>% as.numeric(),
#             x = toPlot %>% subset(gene.type=="non-expressed") %>% select(p.value) %>% unlist() %>% as.numeric(),
#             xlimit = c(min.value, max.value), ylimit = c(min.value, max.value),
#             ylab = paste0("Expressed Genes (-log10 p-value, n=[", exp.gene.count, " genes])"),
#             xlab = paste0("Control - Non-Expressed Genes (-log10 p-value, n=[", nonexp.gene.count, " genes])"),
#             main = paste0(SAMPLE, ", K = ", k, ", ", test.name),
#             match=T, match.y=T, match.df=match.df
#   )

# }
# dev.off()


# ##########################################################################
# ## empirical fdr vs p.adjust (BH) plot
# pdf(paste0(FIGDIRTOP,"empirical.fdr.vs.p.adjust.pdf"), width=8, height=6)
# for (j in 1:length(unique(realPvals.df$test.type))) {
#   test.name<-unique(all.test$test.type)[j]
#   toPlot <- realPvals.df %>% subset(test.type==test.name)
#   p <- toPlot %>% ggplot(aes(x=adjusted.p.value, y=fdr)) + geom_point(size=0.1) + geom_abline(color="red") +
#     mytheme + xlab("Adjusted p-value") + ylab("Empirical False Discovery Rate") + ggtitle(paste0(SAMPLE, ", ", test.name)) + coord_fixed()
#   print(p)
# }
# dev.off()


# ##########################################################################
# ## empirical FDR heatmaps
# for(fdr.method in c("empirical.fdr", "p.adjust")) {
# pdf(file=paste0(FIGDIRTOP, fdr.method, ".sig.ptbd.gene_fill.log2fc_heatmap.pdf"), width=12, height=6)
# for (emp.fdr.thr in c(0.05, 0.1, 0.25)) {
#   for (current.test.type in realPvals.df$test.type %>% unique()) {
#       if(fdr.method == "empirical.fdr") {
#           test.list <- realPvals.df %>% subset(test.type==current.test.type) %>% mutate(adjusted.p.value=fdr)
#       } else {
#           test.list <- realPvals.df %>% subset(test.type==current.test.type)
#       }
#     genes.toInclude <- test.list %>% subset(adjusted.p.value < emp.fdr.thr) %>% pull(Gene) %>% unique()
#     toPlot <- gene.score %>% subset(., rownames(gene.score) %in% genes.toInclude)
#     toPlot <- add.snp.gene.info(toPlot, type="rownames")
#     # plot heatmap
#     cols <- rep('black', nrow(toPlot))
#     # color row labels: EC controls in blue, CAD loci in red
#     cols[row.names(toPlot) %in% (gene.set.type.df %>% subset(grepl("EC_ctrls", type)) %>% pull(Gene))] <- "blue"
#     cols[row.names(toPlot) %in% (gene.set.type.df %>% subset(grepl("CAD_Loci_all", type)) %>% pull(Gene))] <- "red"
#     # cols[row.names(toPlot) %in% gene.set] <- "red"
#     rownames(toPlot)[row.names(toPlot) %in% gene.set] <- paste0("[ ", rownames(toPlot)[which(row.names(toPlot) %in% gene.set)], " ]")

#     if(nrow(toPlot) > 1) {
#         ## plotHeatmap( toPlot, cellNote=NULL, rownames(toPlot), title=title, colCol=cols)

#         toHighlight.asterisk <- merge.score.with.test(gene.score, test.list, test.col.name="p.value", p.value.thr=1, adj.p.value.thr=emp.fdr.thr, fill.all=T, fill="", overlay=T)$score.mtx
#         title=paste0(SAMPLE,", K = ", k, ", ", current.test.type, ", \n", ifelse(fdr.method == "empirical.fdr", "empirical fdr", "BH adjusted p-value"), " < ", emp.fdr.thr, ", number of significant genes = ", nrow(toPlot)) 
#       plotHeatmap( toPlot, cellNote=toHighlight.asterisk, rownames(toPlot), title=title, colCol=cols)

#       toHighlight.value <- merge.score.with.test(gene.score, test.list, test.col.name="p.value", p.value.thr=1, adj.p.value.thr=emp.fdr.thr, fill.all=T, fill="", overlay=F, num.thr = 1)$score.mtx
#       plotHeatmap( toPlot, cellNote=toHighlight.value, rownames(toPlot), title=title, colCol=cols)

#       toHighlight.value <- merge.score.with.test(gene.score, test.list, test.col.name="p.value", p.value.thr=1, adj.p.value.thr=emp.fdr.thr, fill.all=T, fill="", overlay=F)$score.mtx
#       plotHeatmap( toPlot, cellNote=toHighlight.value, rownames(toPlot), title=title, colCol=cols)
#     }
#   }
# }
# dev.off()
# }


#     ## ##########################################################################
#     ## full heatmap
#     pdf(file = paste0(FIGDIRTOP,"Gene.full.heatmap.pdf"), width=36, height=6) 

#     current.test.type <- "per.guide.wilcoxon"
#     toPlot <- gene.score
#     toPlot <- add.snp.gene.info(toPlot, type="rownames")
#     cols <- rep('black', nrow(toPlot))
#                                         # color row labels: EC controls in blue, CAD loci in red
#     cols[row.names(toPlot) %in% (gene.set.type.df %>% subset(grepl("EC_ctrls", type)) %>% pull(Gene))] <- "blue"
#     cols[row.names(toPlot) %in% (gene.set.type.df %>% subset(grepl("CAD_Loci_all", type)) %>% pull(Gene))] <- "red"

#     rownames(toPlot)[row.names(toPlot) %in% gene.set] <- paste0("[ ", rownames(toPlot)[which(row.names(toPlot) %in% gene.set)], " ]")

#     title=paste0(SAMPLE,", K = ", k)
#     plotHeatmap(toPlot, cellNote=NULL, rownames(toPlot), title=title, colCol=cols)

#     for(emp.fdr.thr in c(0.05, 0.1, 0.25)){
#                                         # add asterisks for significant perturbation/topic
#                                         # test.list <- realPvals.df %>% subset(test.type==current.test.type) %>% mutate(adjusted.p.value=fdr)
#                                         # toHighlight.asterisk <- merge.score.with.test(gene.score, test.list, test.col.name="p.value", p.value.thr=1, adj.p.value.thr=emp.fdr.thr, fill.all=T, fill="", overlay=T)$score.mtx
#                                         # title=paste0(SAMPLE,", K = ", k, ", ", current.test.type, ", \nempirical fdr < ", emp.fdr.thr)
#                                         # plotHeatmap(toPlot, cellNote=toHighlight.asterisk, rownames(toPlot), title=title, colCol=cols)
#     }
#     dev.off()




#     ## ##########################################################################
#     ## Lists of genes or perturbations to understand topics
#     all.test.guide.w <- all.test %>% subset(test.type=="per.guide.wilcoxon")
#     realPvals.df.guide.w <- realPvals.df %>% subset(test.type=="per.guide.wilcoxon")
#                                         # make a plot for each topic, refer to TopFeatures list code
#                                         # Genes with the lowest empirical FDR
#     pdf(file=paste0(FIGDIRTOP, "top.perturbation.list_empirical.fdr.pdf"), width=4, height=6)
#     for ( topic in realPvals.df.guide.w$Topic %>% unique() ) {
#         toPlot <- realPvals.df.guide.w %>% subset(Topic == topic) %>% arrange(fdr) %>% slice(1:50) %>% mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene))
#         p <- toPlot %>% ggplot(aes(x=reorder(Gene, -fdr), y=fdr) ) + geom_col() + theme_minimal() +
#             coord_flip() + xlab("Top 50 Gene (Perturbation)") + ylab("Empirical FDR") + ggtitle(paste(SAMPLE, " perturbations, ", topic))
#         print(p)
#     }
#     dev.off()


# }

# ##########################################################################
# ## most extreme log2FC (use omega)
# ## # add ABC to gene.set.type.df for this particular plot
# ## gene.set.type.df$type[which(gene.set.type.df$Gene %in% gene.set)] <- "ABC"

# for (test.type.here in c("per.cell.wilcoxon", "per.guide.wilcoxon")) {
#     all.test.subset <- all.test %>% subset(test.type==test.type.here)
#     realPvals.df.subset <- realPvals.df %>% subset(test.type==test.type.here)
#     pdf(file=paste0(FIGDIRTOP, "top.perturbation.list_log2FC_highlight.", test.type.here, ".pdf"), width=4, height=6)
#     for ( topic in colnames(gene.score) ) { 
#         toPlot.all.test <- all.test.subset %>% subset(Topic==topic)
#         toPlot.fdr <- realPvals.df.subset %>% subset(Topic == topic) %>% select(Gene,fdr)
#         ## assemble toPlot
#         toPlot <- gene.score %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>% merge(.,toPlot.all.test,by="Gene", all.x=T) %>%
#             merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#             merge(.,gene.set.type.df,by="Gene") %>%
#             mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene))
#         colnames(toPlot)[which(colnames(toPlot)==topic)] <- "log2FC"
#         toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(fdr)), "", "*"))
#         toPlot.top <- toPlot %>% arrange(desc(log2FC)) %>% slice(1:25)
#         toPlot.bottom <- toPlot %>% arrange(log2FC) %>% slice(1:25)
#         toPlot.extreme <- rbind(toPlot.top, toPlot.bottom) %>%
#             mutate(color=ifelse(grepl("CAD",type), "red",
#                          ifelse(type=="non-expressed", "grey",
#                          ifelse(type=="other", "blue", "black")))) %>%
#             mutate(Gene = ifelse(type=="ABC", paste0("[ ", Gene, " ]"), Gene))
#         colors <- toPlot.extreme$color[order(toPlot.extreme %>% arrange(desc(log2FC)) %>% pull(color))]
#         p <- toPlot.extreme %>% arrange(desc(log2FC)) %>%  #mutate(Gene = paste0("<span style = 'color: ", color, ";'>", Gene, "</span>")) %>%
#             ggplot(aes(x=reorder(Gene, log2FC), y=log2FC, fill=significant)) + geom_col() + theme_minimal() +
#             coord_flip() + xlab("Most Extreme Gene (Perturbation)") + ylab("log2 Fold Change") + ggtitle(paste(SAMPLE, " perturbations, ", topic)) +
#             scale_fill_manual(values=c("grey", "#38b4f7")) +
#             geom_text(aes(label = significant)) +
#             theme(legend.position = "none", axis.text.y = element_text(colour = colors))
#         print(p)#here
#     }
#     dev.off()
# }


# ##########################################################################
# ## most extreme log2FC (use omega) short list
# ## # add ABC to gene.set.type.df for this particular plot
# ## gene.set.type.df$type[which(gene.set.type.df$Gene %in% gene.set)] <- "ABC"
# for (test.type.here in c("per.cell.wilcoxon", "per.guide.wilcoxon")) {
#     all.test.subset <- all.test %>% subset(test.type==test.type.here)
#     realPvals.df.subset <- realPvals.df %>% subset(test.type==test.type.here)
#     pdf(file=paste0(FIGDIRTOP, "top.perturbation.shortList_log2FC_highlight.", test.type.here, ".pdf"), width=4, height=4)
#     for ( topic in colnames(gene.score) ) {
#         toPlot.all.test <- all.test.subset %>% subset(Topic==topic)
#         toPlot.fdr <- realPvals.df.subset %>% subset(Topic == topic) %>% select(Gene,fdr)
#                                         # assemble toPlot
#         toPlot <- gene.score %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>% merge(.,toPlot.all.test,by="Gene", all.x=T) %>%
#             merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#             merge(.,gene.set.type.df,by="Gene") %>%
#             mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene))
#         colnames(toPlot)[which(colnames(toPlot)==topic)] <- "log2FC"
#         toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(fdr)), "", "*"))
#         toPlot.top <- toPlot %>% arrange(desc(log2FC)) %>% slice(1:11)
#         toPlot.bottom <- toPlot %>% arrange(log2FC) %>% slice(1:11)
#         toPlot.extreme <- rbind(toPlot.top, toPlot.bottom) %>%
#             mutate(color=ifelse(grepl("CAD",type), "red",
#                          ifelse(type=="non-expressed", "grey",
#                          ifelse(type=="other", "blue", "black"))))
#         colors <- toPlot.extreme$color[order(toPlot.extreme %>% arrange(desc(log2FC)) %>% pull(color))]
#         p <- toPlot.extreme %>% arrange(desc(log2FC)) %>%  #mutate(Gene = paste0("<span style = 'color: ", color, ";'>", Gene, "</span>")) %>%
#             ggplot(aes(x=reorder(Gene, log2FC), y=log2FC, fill=significant)) + geom_col() + theme_minimal() +
#             coord_flip() + xlab("Most Extreme Gene (Perturbation)") + ylab("log2 Fold Change") + ggtitle(paste(SAMPLE, " perturbations, ", topic)) +
#             scale_fill_manual(values=c("grey", "#38b4f7")) +
#             geom_text(aes(label = significant)) +
#             theme(legend.position = "none", text=element_text(size=14), plot.title=element_text(size=12))
#         print(p)#here
#     }
#     dev.off()
# }



# ##########################################################################
# ## enrichment by GWAS classification
# pdf(file=paste0(FIGDIRTOP,"_GWAS.class.enrichment.pdf"))
# test.condition.names <- count.by.GWAS$test.type %>% unique()
# for( i in 1:length(test.condition.names) ){
#     test.condition <- test.condition.names[i]
#     toPlot <- count.by.GWAS %>% subset(test.type==test.condition) %>% select(GWAS.class.enrichment,GWAS.classification,gene.count.per.GWAS.category) %>% unique()
#     p <- toPlot %>% ggplot(aes(x=GWAS.classification, y=GWAS.class.enrichment)) + geom_bar(stat="identity",width=0.5,fill = "#38b4f7") + mytheme +
#         ggtitle(paste0("Fraction of genes with significant topics,", test.condition)) + xlab("GWAS Classification") + ylab("Fraction of Perturbations") 
#     print(p)
#     p <- toPlot %>% ggplot(aes(x=GWAS.classification, y=gene.count.per.GWAS.category)) + geom_bar(stat="identity",width=0.5) + mytheme +
#         ggtitle(paste0("Number of genes with significant topics,", test.condition)) + xlab("GWAS Classification") + ylab("Number of Perturbations") 
#     print(p)
# }
# dev.off()


# ##########################################################################
# ## GWAS gene ranking plot
# pdf(file=paste0(FIGDIRTOP,"_GWAS.gene.rank.barplot.pdf"), width=8, height=4)
# test.condition.names <- count.by.GWAS$test.type %>% unique()
# GWAS.type.list <- c("CAD","IBD")
# p.list <- vector("list", length(test.condition.names) * length(GWAS.type.list))
# for( i in 1:length(test.condition.names) ){
#     for (GWAS.type.index in 1:length(GWAS.type.list)) {
#         test.condition <- test.condition.names[i]
#         GWAS.type <- GWAS.type.list[GWAS.type.index]

#         ## select the columns and rows we need for this plot
#         toPlot.tmp <- count.by.GWAS %>% subset(test.type==test.condition & GWAS.classification==GWAS.type) %>% select(TSS.v.SNP.ranking.in.GWAS.category, passed.filter.ranking.count, GWAS.classification,total.TSS.v.SNP.ranking.count.per.GWAS.classification) %>% unique()
#         toPlot <- toPlot.tmp %>% mutate(ranking.fraction = passed.filter.ranking.count / total.TSS.v.SNP.ranking.count.per.GWAS.classification)
#         p <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.fraction)) + geom_bar(stat="identity", width=0.5) + mytheme +
#             ggtitle(paste0("Fraction of ", GWAS.type, " genes per distance ranking \n with significant topics, ", test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Fraction of Perturbations")
#         print(p)

#         ## concatenate all genes ranked as 5+
#         closest.ranking.boolean <- toPlot.tmp$TSS.v.SNP.ranking.in.GWAS.category < 5
#         toPlot.closer.genes <- toPlot.tmp[which(closest.ranking.boolean),] %>% as.data.frame()
#         toPlot.farther.genes <- toPlot.tmp[which(!closest.ranking.boolean),] %>% ungroup() %>% select(-GWAS.classification, -TSS.v.SNP.ranking.in.GWAS.category) %>% apply(2,sum) %>% t() %>% as.data.frame() %>% mutate(TSS.v.SNP.ranking.in.GWAS.category = "5+", .before = passed.filter.ranking.count) %>% mutate(GWAS.classification = GWAS.type, .after = passed.filter.ranking.count)

#         toPlot <- rbind(toPlot.closer.genes, toPlot.farther.genes) %>% mutate(ranking.fraction = passed.filter.ranking.count / total.TSS.v.SNP.ranking.count.per.GWAS.classification)
#         p <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.fraction)) + geom_bar(stat="identity", width=0.5) + mytheme +
#             ggtitle(paste0("Fraction of ", GWAS.type, " genes per distance ranking \n with significant topics, ", test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Fraction of Perturbations")
#         print(p)

#         p.list[[(i-1)*2 + GWAS.type.index]] <- p
#         toPlot <- toPlot %>% select(-ranking.fraction) %>% melt(id.vars=c("TSS.v.SNP.ranking.in.GWAS.category","GWAS.classification"), variable.name="ranking.type", value.name="ranking.count")
#         toPlot$GWAS.classification <- factor(toPlot$GWAS.classification)
#         toPlot$ranking.type <- factor(toPlot$ranking.type)
#         p <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.count, fill=ranking.type)) + geom_bar(position="dodge",stat="identity",width=0.5) + mytheme +
#             ggtitle(paste0("Number of ", GWAS.type, " genes with significant topics, ", test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Number of Perturbations") + scale_fill_manual(values=wes_palette(n=length(unique(toPlot$ranking.type)), name="Darjeeling2"), name = "Type", labels = c("Significant Genes", "All Genes"))
#         print(p)

#         ## plot 4
#         toPlot <- count.by.GWAS %>% subset(test.type==test.condition & GWAS.classification==GWAS.type) %>% select(TSS.v.SNP.ranking.in.GWAS.category, GWAS.classification, passed.filter.ranking.count, total.TSS.v.SNP.ranking.count.per.GWAS.classification) %>% unique() %>% melt(id.vars=c("TSS.v.SNP.ranking.in.GWAS.category","GWAS.classification"), variable.name="ranking.type", value.name="ranking.count")
#         p <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.count, fill=ranking.type)) + geom_bar(position="dodge",stat="identity",width=0.5) + mytheme +
#             ggtitle(paste0("Number of ", GWAS.type, " genes with significant topics, ", test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Number of Perturbations") + scale_fill_manual(values=wes_palette(n=length(unique(toPlot$ranking.type)), name="Darjeeling2"), name = "Type", labels = c("Significant Genes", "All Genes"))
#         print(p)
#     }
# }
# dev.off()


# ##########################################################################
# ## GWAS gene ranking by topic ## debug
# pdf(file=paste0(FIGDIRTOP,"_GWAS.gene.rank.barplot.by.topic.pdf"), width=8, height=4) ##todo:210812:
# test.condition.names <- count.by.GWAS.withTopic$test.type %>% unique()
# for (GWAS.type.index in 1:length(GWAS.type.list)) {
#     GWAS.type <- GWAS.type.list[GWAS.type.index]
#     for( i in 1:length(test.condition.names) ){
#         test.condition <- test.condition.names[i]
#         toPlot.allTopics <- count.by.GWAS.withTopic %>% subset(test.type==test.condition & GWAS.classification == GWAS.type) %>%
#             ungroup() %>%
#             select(Topic,
#                    TSS.v.SNP.ranking.in.GWAS.category,
#                    passed.filter.ranking.count,
#                    GWAS.classification,
#                    total.TSS.v.SNP.ranking.count.per.GWAS.classification
#                    ) %>% unique()
#         toPlot.allTopics$GWAS.classification <- factor(toPlot.allTopics$GWAS.classification)
#         for( t in sort(unique(toPlot.allTopics$Topic %>% gsub("topic_","",.))) ) {
#             toPlot.tmp <- toPlot.allTopics %>% subset(Topic == paste0("topic_",t)) 
#             toPlot <- toPlot.tmp %>% mutate(ranking.fraction = passed.filter.ranking.count / total.TSS.v.SNP.ranking.count.per.GWAS.classification)
#             p1 <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.fraction)) + geom_bar(stat="identity", width=0.5) + mytheme +
#                 ggtitle(paste0("Fraction of ", GWAS.type, " significant genes per distance ranking \n in topic ", t, ", ",  test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Fraction of Perturbations")
#             ## print(p1)

#             ## concatenate all genes ranked as 5+
#             closest.ranking.boolean <- toPlot.tmp$TSS.v.SNP.ranking.in.GWAS.category < 5
#             toPlot.closer.genes <- toPlot.tmp[which(closest.ranking.boolean),] %>% as.data.frame() %>% select(-Topic)
#             toPlot.farther.genes <- toPlot.tmp[which(!closest.ranking.boolean),] %>% ungroup() %>% select(-Topic, -GWAS.classification, -TSS.v.SNP.ranking.in.GWAS.category) %>% apply(2,sum) %>% t() %>% as.data.frame() %>% mutate(TSS.v.SNP.ranking.in.GWAS.category = "5+", .before = passed.filter.ranking.count) %>% mutate(GWAS.classification = GWAS.type, .after = passed.filter.ranking.count)

#             toPlot <- rbind(toPlot.closer.genes, toPlot.farther.genes) %>% mutate(ranking.fraction = passed.filter.ranking.count / total.TSS.v.SNP.ranking.count.per.GWAS.classification)
#             p2 <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.fraction)) + geom_bar(stat="identity", width=0.5) + mytheme +
#                 ggtitle(paste0("Fraction of ", GWAS.type, " significant genes per distance ranking \n in topic ", t, ",  ", test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Fraction of Perturbations")
#             ## print(p2)

#             toPlot <- toPlot %>% select(-ranking.fraction) %>% melt(id.vars=c("TSS.v.SNP.ranking.in.GWAS.category","GWAS.classification"), variable.name="ranking.type", value.name="ranking.count")
#             toPlot$GWAS.classification <- factor(toPlot$GWAS.classification)
#             toPlot$ranking.type <- factor(toPlot$ranking.type)
#             p3 <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.count, fill=ranking.type)) + geom_bar(position="dodge",stat="identity",width=0.5) + mytheme +
#                 ggtitle(paste0("Number of ", GWAS.type, " genes DE in topic ", t, ", ", test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Number of Perturbations") + scale_fill_manual(values=wes_palette(n=length(unique(toPlot$ranking.type)), name="Darjeeling2"), name = "Type", labels = c("Significant Genes", "All Genes"))
#             ## print(p3)

#             ## plot 4
#             toPlot <- count.by.GWAS.withTopic %>% ungroup() %>% subset(Topic == paste0("topic_",t) & test.type==test.condition & GWAS.classification==GWAS.type) %>% select(TSS.v.SNP.ranking.in.GWAS.category, GWAS.classification, passed.filter.ranking.count, total.TSS.v.SNP.ranking.count.per.GWAS.classification) %>% unique() %>% melt(id.vars=c("TSS.v.SNP.ranking.in.GWAS.category","GWAS.classification"), variable.name="ranking.type", value.name="ranking.count")
#             p4 <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking.in.GWAS.category, y=ranking.count, fill=ranking.type)) + geom_bar(position="dodge",stat="identity",width=0.5) + mytheme +
#                 ggtitle(paste0("Number of ", GWAS.type, " significant genes in topic ", t, ", ", test.condition)) + xlab("Rank of Distance to the Closest SNP") + ylab("Number of Perturbations") + scale_fill_manual(values=wes_palette(n=length(unique(toPlot$ranking.type)), name="Darjeeling2"), name = "Type", labels = c("Significant Genes", "All Genes"))
#             ## print(p4)

#             p <- ggarrange(p4 + ggtitle("") + xlab(""), p1 + ggtitle("") + xlab(""), p3 + ggtitle("") + xlab(""), p2 + ggtitle("") + xlab(""), nrow=2, ncol=2, common.legend=T)
#             p <- annotate_figure(p, top = text_grob(paste0("Significant ", GWAS.type, " Genes in Topic ", t, ", ", test.condition), face="bold", size=14),
#                                  bottom = "Rank of Distance to the Closest SNP")                
#             print(p)            
#             ## toPlot <- toPlot %>% subset(Topic==paste0("topic_",t) & test.type = test.condition)
#             ## toPlot$GWAS.classification <- factor(toPlot$GWAS.classification)
#             ## p <- toPlot %>% ggplot(aes(x=TSS.v.SNP.ranking, y=ranking.count, fill=GWAS.classification)) + geom_bar(position="dodge",stat="identity",width=0.5) + mytheme +
#             ##     ggtitle(paste0("Number of genes with significant topic ", t, ", ", test.condition)) + xlab("Distance Ranking to the Closest SNP") + ylab("Number of Perturbations") + scale_fill_manual(values=wes_palette(n=length(unique(toPlot$GWAS.classification)), name="Darjeeling2")) +
#             ##     theme(legend.title = element_blank())
#             ## print(p)
#         }
#     }
# }
# dev.off()


## ## GWAS ranking by distance in kb
## pdf(file=paste0(FIGDIRTOP,"_GWAS.gene.distance.by.bp.cdf.pdf"), width=8, height=4)
## test.condition.names <- count.by.GWAS$test.type %>% unique()
## GWAS.type.list <- c("CAD","IBD")
## toPlot.list <- vector("list", length(test.condition.names) * length(GWAS.type.list))
## for( i in 1:length(test.condition.names) ){
##     for (GWAS.type.index in 1:length(GWAS.type.list)) {
##         test.condition <- test.condition.names[i]
##         GWAS.type <- GWAS.type.list[GWAS.type.index]

##         toPlot <- count.by.GWAS %>% subset(test.type==test.condition & GWAS.classification==GWAS.type) %>% select(TSS.dist.to.SNP, GWAS.classification)
##         toPlot$GWAS.classification <- factor(toPlot$GWAS.classification)
##         p <- toPlot %>% ggplot(aes(x=TSS.dist.to.SNP)) + stat_ecdf() + mytheme +
##             ggtitle(paste0(GWAS.type, " genes with significant topics, ", test.condition)) + xlab("Distance to the Closest SNP (in bp)") + ylab("Fraction of Significant Perturbed Genes") 
##         print(p)
##         toPlot.list[[(i-1)*2 + GWAS.type.index]] <- toPlot %>% mutate(test.type=test.condition)
##     }
## }
## dev.off()

## Need to double check test.type
## toPlot.all <- do.call(rbind, toPlot.list) %>% mutate(GWAS.classification = paste0(GWAS.classification, ".significant")) %>% rbind(ref.table %>% ungroup() %>% subset(GWAS.classification %in% GWAS.type.list) %>% select(TSS.dist.to.SNP, GWAS.classification) %>% mutate(test.type = "None"))
## for( i in 1:length(test.condition.names) ){
##     test.condition <- test.condition.names[i]
##     toPlot <- toPlot.all %>% subset(test.type %in% c("None", test.condition))
##     p <- toPlot %>% ggplot(aes(x=TSS.dist.to.SNP, color = GWAS.classification)) + stat_ecdf() + mytheme +
##         xlim(0,1000000) +
##         ggtitle(paste0("Test by ", test.condition)) + xlab("Distance to the Closest SNP (in bp)") + ylab("Fraction of Perturbed Genes")
##     print(p)




## heatmap of ptb correlation for factor values##here210810



## heatmap of ptb correlation for factor values threshold with BH adjusted.p.value < 0.1

## heatmap of factor correlation by expressed gene raw weights

## heatmap of factor correlation by expressed gene zscore
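## A minimal sketch (not part of the original pipeline) for the last placeholder above:
## correlate factors by their per-gene z-score weights and draw a tile heatmap.
## Assumes theta.zscore (genes x factors z-score weight matrix), FIGDIRTOP, SAMPLE,
## k, and mytheme are defined as elsewhere in this script; only ggplot2/reshape2 are used.
factor.zscore.cor <- cor(theta.zscore, method = "pearson") ## K x K factor-by-factor correlation
factor.zscore.cor.long <- factor.zscore.cor %>%
    reshape2::melt(varnames = c("Factor.x", "Factor.y"), value.name = "correlation")
p.factor.cor <- factor.zscore.cor.long %>%
    ggplot(aes(x = factor(Factor.x), y = factor(Factor.y), fill = correlation)) +
    geom_tile() + mytheme +
    scale_fill_gradient2(low = "blue", mid = "white", high = "red", limits = c(-1, 1)) +
    xlab("Factor") + ylab("Factor") +
    ggtitle(paste0(SAMPLE, ", K=", k, ", factor correlation by gene z-score weights"))
pdf(file = paste0(FIGDIRTOP, "factor.correlation.by.zscore.weight.pdf"), width = 8, height = 7)
print(p.factor.cor)
dev.off()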


# ##########################################################################
# ## UMAP based on factor log2FC values (to see how perturbations cluster)
# pdf(file=paste0(FIGDIRTOP,"perturbation.UMAP.based.on.factor.log2FC.pdf"))
# DimPlot(s.gene.score, reduction = "umap", label=TRUE) %>% print()
# dev.off()


# ##########################################################################
# ## motif enrichment plot
# ep.names <- c("enhancer", "promoter")
# for (ep.type in ep.names) {
#     pdf(file=paste0(FIGDIRTOP,"zscore.",ep.type,".motif.enrichment.pdf"))
#     all.volcano.plots(get(paste0("all.",ep.type,".fisher.df")), ep.type, ranking.type="z-score")##here210812
#     dev.off()
#     pdf(file=paste0(FIGDIRTOP, "zscore.",ep.type,".motif.enrichment_motif.thr.10e-6.pdf"))
#     all.volcano.plots(get(paste0("all.",ep.type,".fisher.df.10en6")), ep.type, ranking.type="z-score")##here210812
#     dev.off()
#     pdf(file=paste0(FIGDIRTOP, "zscore.",ep.type,".motif.enrichment.by.count.ttest.pdf"))
#     all.volcano.plots(get(paste0("all.",ep.type,".ttest.df")) %>% subset(top.gene.mean != 0 & !grepl("X.NA.",motif)), ep.type, ranking.type="z-score")
#     dev.off() 
#     pdf(file=paste0(FIGDIRTOP, "zscore.",ep.type,".motif.enrichment.by.count.ttest_motif.thr.10e-6.pdf"))
#     all.volcano.plots(get(paste0("all.",ep.type,".ttest.df.10en6")) %>% subset(top.gene.mean != 0 & !grepl("X.NA.",motif)), ep.type, ranking.type="z-score")
#     dev.off() 
#     pdf(file=paste0(FIGDIRTOP, "zscore.",ep.type,".motif.enrichment.by.count.ttest.labelPos.pdf"), width=6, height=6)
#     all.volcano.plots(get(paste0("all.",ep.type,".ttest.df")) %>% subset(top.gene.mean != 0 & !grepl("X.NA.",motif)), ep.type, ranking.type="z-score", label.type="pos")
#     dev.off() 
#     pdf(file=paste0(FIGDIRTOP, "zscore.",ep.type,".motif.enrichment.by.count.ttest_motif.thr.10e-6.labelPos.pdf"))
#     all.volcano.plots(get(paste0("all.",ep.type,".ttest.df.10en6")) %>% subset(top.gene.mean != 0 & !grepl("X.NA.",motif)), ep.type, ranking.type="z-score", label.type="pos")
#     dev.off() 
# }



## ##########################################################################
## ## GSEA

## ## check if files already exist
## check.file <- paste0(FGSEADIR,"/fgsea_all_pathways_df_", c("raw.score", "z.score"), "_", SUBSCRIPT, ".RData") 
## if( all(file.exists(check.file)) && !opt$recompute ) {
##     for (i in 1:length(check.file)) {
##         print(paste0("Loading ", check.file[i]))
##         load(check.file[i])
##     }
## } else {
##     warning(paste0(check.file, " does not exist"))
## } ## bracket for checking files


## pdf(paste0(FGSEAFIG, "/msigdb.all.sig.pathways.pdf"))
## p <- toplot %>% ggplot(aes(x=topic, y=sig.pathway.count)) + geom_bar(stat="identity", fill ="#38b4f7") + mytheme +
##     ggtitle(paste0(SAMPLE, " msigdb all, GSEA, number of significant pathway per topic")) + xlab("Topic") + ylab("Number of Significant Pathways")
## print(p)
## dev.off()


# ##########################################################################
# ## Pairwise Pearson correlation for perturbation's topic expression
# d <- cor(gene.score %>% t(), method="pearson")
# m <- as.matrix(d)
# pdf(file=paste0(FIGDIRTOP,"cluster.ptb.by.Pearson.corr.pdf"), width=75, height=75)
# plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, K=", k, ", ", SAMPLE, " topic clustering by Pearson correlation"))
# dev.off()
# png(file=paste0(FIGDIRTOP, "cluster.ptb.by.Pearson.corr.png"), width=1500, height=1500)
# plotHeatmap(m, labCol=rownames(m), margins=c(12,12), title=paste0("cNMF, K=", k, ", ", SAMPLE, " topic clustering by Pearson correlation"))
# dev.off()


######################################################################
## Seurat UMAPs

## plot UMAPs
pdf(paste0(FIGDIRTOP, "Factor.Expression.UMAP.pdf"))
DimPlot(s, reduction = "umap", label=TRUE) %>% print()
meta.data.names <- colnames(s[[]])
plot.features <- paste0("K",k,"_",colnames(omega))
for (feature.name in plot.features) {
    feature.vec <- s@meta.data %>% select(all_of(feature.name))
    if(feature.vec[1,1] %>% is.numeric()) { # numeric
        FeaturePlot(s, reduction = "umap", features=feature.name) %>% print()
    } else { # discrete values / categories
        Idents(s) <- feature.name ## set identities from this metadata column (Seurat accepts a meta.data column name; select(feature.name) would pass a data frame and trigger a tidyselect warning)
        DimPlot(s, reduction = "umap", label=TRUE) %>% print()
    }
}
dev.off()
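## Note (illustrative sketch, not part of the pipeline): the UMAP loop above expects the
## per-cell factor weights to already be stored as metadata columns named "K{k}_{factor}".
## The pipeline attaches these elsewhere; assuming omega is a cells x factors matrix whose
## rownames match colnames(s), they could be added as below (kept commented out here so the
## metadata is not added twice):
## omega.meta <- omega %>% as.data.frame() %>% `colnames<-`(paste0("K", k, "_", colnames(omega)))
## s <- AddMetaData(s, metadata = omega.meta[colnames(s), , drop = FALSE])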

## ## 10x lane sample label UMAP ## checking for batch effects
## CBC.sample <- colnames(s) %>% gsub("^.*-", "", .) %>% gsub("scRNAseq_2kG_","",.)
## CBC.sample.short <- CBC.sample %>% gsub("_.*$","",.)
## s <- AddMetaData(s, CBC.sample, col.name="sample.label.10X")
## s <- AddMetaData(s, CBC.sample.short, col.name="sample.label.short")
## pdf(file=paste0(FIGDIRTOP, "10X.lane.sample.label.UMAP.pdf"), width=10, height=6)
## Idents(s) <- s$sample.label.10X
## DimPlot(s, reduction = "umap", label=F, group.by="sample.label.10X") %>% print()
## Idents(s) <- s$sample.label.short
## DimPlot(s, reduction = "umap", label=F) %>% print()
## DimPlot(s, reduction = "umap", label=F, split.by="sample.label.short") %>% print()
## dev.off()






# ##########################################################################
# ##### SUMMARY PLOTS #####


# ##########################################################################
# ## functions
# get.average.ptb.gene.expression.based.on.ctrl <- function(ptb, ptb.df, ctrl.df, mode="per.guide") {
#   # ptb: column that has gene expression or topic weight, (e.g. "SWAP70" or "topic_4")
#   if (mode=="per.guide") { # first average by guide then average by perturbation
#       e.name <- ptb.df %>% rownames() %>% gsub("_multiTarget|-TSS2","",.) %>% strsplit(., split=":") %>% sapply("[[", 1) %>% unique() 
#       ptb.gene.column.index <- which(grepl(paste0("^",ptb,"$"), colnames(ptb.df)))
#       expression.of.ptb <- ptb.df[,ptb.gene.column.index] %>% as.data.frame() %>% `rownames<-`(rownames(ptb.df)) %>% `colnames<-`(ptb) %>% mutate(long.CBC=rownames(.)) %>% separate(., col=long.CBC, into=c("Gene", "Guide", "CBC"), remove=F, sep=":") %>% group_by(Guide) %>% summarise(Gene.expression=mean(get(ptb))) %>% mutate(Gene=e.name)
#     ## expression.of.ptb <- ptb.df %>% select(c(all_of(ptb), colnames(guideCounts))) %>%
#     ##   group_by(Guide) %>% summarise(Gene.expression=mean(get(ptb))) %>% mutate(Gene=e.name)

#       expression.ctrl.ptb <- ctrl.df[,ptb.gene.column.index] %>% as.data.frame() %>% `rownames<-`(rownames(ctrl.df)) %>% `colnames<-`(ptb) %>% mutate(long.CBC=rownames(.)) %>% separate(., col=long.CBC, into=c("Gene", "Guide", "CBC"), remove=F, sep=":") %>% group_by(Guide) %>% summarise(Gene.expression=mean(get(ptb))) %>% mutate(Gene="Control")# %>% select(c(all_of(ptb), colnames(guideCounts))) %>% group_by(Guide) %>% summarise(Gene.expression=mean(get(ptb))) %>% mutate(Gene="Control")
#       toCalculate <- rbind(expression.of.ptb, expression.ctrl.ptb)
#       toCalculate$Gene.expression <- toCalculate$Gene.expression / (toCalculate %>% subset(Gene=="Control") %>% pull(Gene.expression) %>% mean()) * 100
#     toCalculate <- toCalculate %>% group_by(Gene) %>% mutate(Gene = paste0(gsub("topic_","Topic ", Gene), "\n(n=", n(), ")"))
#     toPlot <- toCalculate %>% group_by(Gene) %>% summarise(mean.expression=mean(Gene.expression), error.bar=1.96 * sd(Gene.expression)/sqrt(n()), count=n()) # 1.96 * sd(vals)/sqrt(length(vals))
#   } else { # directly average by cell
#       e.name <- ptb.df %>% rownames() %>% gsub("_multiTarget|-TSS2","",.) %>% strsplit(., split=":") %>% sapply("[[", 1) %>% unique() 
#     # e.name <- ptb.df$Gene %>% unique()
#     ptb.gene.column.index <- which(grepl(paste0("^",ptb,"$"), colnames(ptb.df)))
#     expression.of.ptb <- ptb.df[,ptb.gene.column.index] %>% as.data.frame() %>% `rownames<-`(rownames(ptb.df)) %>% `colnames<-`(ptb) %>% mutate(long.CBC=rownames(.)) %>% separate(., col=long.CBC, into=c("Gene", "Guide", "CBC"), remove=F, sep=":") %>% mutate(Gene=e.name)
#     expression.ctrl.ptb <- ctrl.df[,ptb.gene.column.index] %>% as.data.frame() %>% `rownames<-`(rownames(ctrl.df)) %>% `colnames<-`(ptb) %>% mutate(long.CBC=rownames(.)) %>% separate(., col=long.CBC, into=c("Gene", "Guide", "CBC"), remove=F, sep=":") %>% mutate(Gene="Control")
#     ## expression.of.ptb <- ptb.df %>% select(c(all_of(ptb), colnames(guideCounts))) %>% mutate(Gene=e.name)
#     ## expression.ctrl.ptb <- ctrl.df %>% select(c(all_of(ptb), colnames(guideCounts))) %>% mutate(Gene="Control")
#     toCalculate <- rbind(expression.of.ptb, expression.ctrl.ptb) %>% select(all_of(ptb),Guide,Gene) %>% mutate(Gene.expression = get(ptb))
#     toCalculate$Gene.expression <- toCalculate$Gene.expression / (toCalculate %>% subset(Gene=="Control") %>% pull(Gene.expression) %>% mean()) * 100
#     toCalculate <- toCalculate %>% group_by(Gene) %>% mutate(Gene = paste0(gsub("topic_","Topic ", Gene), "\n(n=", n(), ")"))
#     toPlot <- toCalculate %>% group_by(Gene) %>% summarise(mean.expression=mean(Gene.expression), error.bar=1.96 * sd(Gene.expression)/sqrt(n()), count=n())
#     # 1.96 * sd(vals)/sqrt(length(vals))
#   }
#   return(list(toCalculate=toCalculate, toPlot=toPlot))
# }

# lm_eqn <- function(df){
#   m <- lm(y ~ x, df);
#   eq <- substitute(italic(y) == a + b %.% italic(x)*","~~italic(r)^2~"="~r2,
#                    list(a = format(unname(coef(m)[1]), digits = 2),
#                         b = format(unname(coef(m)[2]), digits = 2),
#                         r2 = format(summary(m)$r.squared, digits = 3)))
#   as.character(as.expression(eq));
# }
# lm_eqn_manual <- function(a, b){
#   eq <- substitute(italic(y) == p + q %.% italic(x),
#                    list(p = format(unname(a), digits = 2),
#                         q = format(unname(b), digits = 2)))
#   as.character(as.expression(eq)) %>% return()
# }
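# ## Minimal usage sketch for lm_eqn()/lm_eqn_manual() above (hypothetical toy data, not from
# ## this pipeline): the returned string is a plotmath expression, so pass it to annotate()
# ## with parse = TRUE, as done for the topic-vs-KD scatter plots further below.
# ## toy <- data.frame(x = 1:10, y = 2 * (1:10) + rnorm(10))
# ## ggplot(toy, aes(x = x, y = y)) + geom_point() +
# ##     geom_smooth(method = "lm", se = FALSE) +
# ##     annotate("text", x = 3, y = max(toy$y), label = lm_eqn(toy), parse = TRUE)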



# normalize.by.ctrl.avg <- function(ptb, ptb.df, ctrl.df, mode="per.guide"){
#   e.name <- ptb.df$Gene %>% gsub("_multiTarget|-TSS2","",.) %>% strsplit(., split=":") %>% sapply("[[", 1) %>% unique()
#   ptb.col.index <- which(grepl(paste0("^",ptb,"$"), colnames(ctrl.df)))
#   ctrl.df.tmp <- ctrl.df[,ptb.col.index] %>% as.data.frame() %>% `rownames<-`(rownames(ctrl.df)) %>% `colnames<-`(ptb)
#   ctrl.df.tmp <- ctrl.df.tmp %>% mutate(long.CBC=rownames(.)) %>% separate(., col=long.CBC, into=c("Gene", "Guide", "CBC"), sep=":", remove=F)

#   if (mode=="per.guide") { # first average by guide then average by perturbation
#       expression.ctrl.ptb.tmp <- ctrl.df.tmp %>% group_by(Guide) %>% summarise(Gene.expression=mean(get(ptb))) %>% mutate(Gene="Control")
#       ## expression.ctrl.ptb.tmp <- ctrl.df %>% select(c(all_of(ptb), colnames(guideCounts))) %>%
#       ## group_by(Guide) %>% summarise(Gene.expression=mean(get(ptb))) %>% mutate(Gene="Control")
#     # ptb.df with mean as % control and errorbar
#       expression.of.ptb <- ptb.df %>%
#       mutate(Gene.expression = get(ptb) / (expression.ctrl.ptb.tmp$Gene.expression %>% mean()) * 100) %>%
#       group_by(Guide) %>%
#       summarise(Gene.error.bar = 1.96 * sd(Gene.expression)/sqrt(n()),
#                 Gene.expression=mean(Gene.expression)) %>% mutate(Gene=e.name) # summarise or mutate? # need to keep CBC
#     # control with error bar
#       expression.ctrl.ptb <- ctrl.df.tmp %>%
#       mutate(Gene.expression = get(ptb) / (expression.ctrl.ptb.tmp$Gene.expression %>% mean()) * 100) %>%
#       group_by(Guide) %>%
#       summarise(Gene.error.bar = 1.96 * sd(Gene.expression)/sqrt(n()),
#                 Gene.expression=mean(Gene.expression)) %>% mutate(Gene="Control")
#       toCalculate <- rbind(expression.of.ptb, expression.ctrl.ptb)

#   } else { # directly average by cell
#     expression.ctrl.ptb <- ctrl.df.tmp %>% mutate(Gene="Control")
#     expression.of.ptb <- ptb.df %>% mutate(Gene=e.name)
#     toCalculate <- rbind(expression.of.ptb %>% select(all_of(ptb),Guide,Gene,long.CBC), expression.ctrl.ptb %>% select(all_of(ptb),Guide,Gene,long.CBC)) %>% mutate(Gene.expression = get(ptb))
#   }
#   toCalculate <- toCalculate %>% group_by(Gene) %>% mutate(Gene.count = paste0(gsub("topic_","Topic ", Gene), " (n=", n(), ")"))
#   return(toCalculate)
# }



# giant.summary.plot <- function(ptb, ptb.expressed.name, expressed.gene, mode.selection, enhancer=F) { 
#   # toPlot.per.guide.DE.test.wilcox <- per.guide.DE.test.wilcox %>% subset(Topic==ptb)
#   # if (enhancer) {
#   #   array <- ann.X.full.filtered$Gene[grep("E_at_", ann.X.full.filtered$Gene)] %>% unique()
#   #   e.name <- array[grep(ptb,array)]
#   # } else {
#   #   e.name <- ptb
#   # }


#   # tmp <- ptb # "E_at_" "-no"
#   # ptb <- expressed.gene # "GOSR2"
#   # e.name <- tmp

#     all.test.guide.w <- all.test %>% subset(test.type==paste0(mode.selection,".wilcoxon"))
#     realPvals.df.guide.w <- realPvals.df %>% subset(test.type==paste0(mode.selection,"per.guide.wilcoxon"))


#   #TODO: add rep1rep2 to ctrl.X %>% subset()
#   # for SEP=T, subset ctrl.X to the right sample's control
#   if(SEP) {
#     label.here <- strsplit(ptb, split="-") %>% unlist() %>% nth(2) %>% paste0("-",.)
#     ctrl.X.here <- ctrl.X %>% subset(grepl(label.here,Gene))
#     ctrl.ann.omega.here <- ctrl.ann.omega %>% subset(grepl(label.here,Gene))
#   } else  {
#     ctrl.X.here <- ctrl.X
#     ctrl.ann.omega.here <- ctrl.ann.omega
#   }

#     subset.index <- which(X.gene.names == ptb) 
#     ## ann.X.full.filtered.df <- 
#     toPlot.list <- get.average.ptb.gene.expression.based.on.ctrl(expressed.gene, ann.X.full.filtered[subset.index,], ctrl.X.here, mode=mode.selection)
#     toPlot1 <- toPlot.list[["toPlot"]]
#     order.toPlot.Gene <- function(df) {
#         df$Gene <- factor(x=df$Gene, levels=df$Gene[c(which(grepl("^Control", df$Gene)) %>% min(), which(!grepl("^Control", df$Gene))%>% min())])
#         return(df)
#     }
#     order.toPlot <- function(df, column) {
#         array <- df %>% pull(all_of(column))
#         df <- df %>% mutate(!!column:=factor(x=array, levels=array[c(which(grepl("^Control", array)) %>% min(), which(!grepl("^Control", array))%>% min())]))
#         return(df)
#     }
#     toPlot1 <- order.toPlot(toPlot1, column="Gene")

#   KD.per.guide <- toPlot.list[["toCalculate"]]
#   colnames(KD.per.guide)[colnames(KD.per.guide)=="Gene.expression"] <- "mean.expression"

#   # plot 5 data
#     subset.index <- which(X.gene.names == ptb)
#     ## ptb.gene.col.index <- which(grepl(ptb.expressed.name, colnames(ann.X.full.filtered)))
#     ptb.gene.col.index <- which(colnames(ann.X.full.filtered) == ptb.expressed.name)
#     ann.X.ptb <- ann.X.full.filtered[subset.index,ptb.gene.col.index] %>% as.data.frame() %>% `colnames<-`(ptb.expressed.name) %>% mutate(long.CBC = rownames(.)) %>% separate(., col=long.CBC, into=c("Gene", "Guide", "CBC"), sep=":", remove=F)# select ptb expression column
#   ## ann.X.ptb <- ann.X.full.filtered %>% subset(Gene==ptb) %>% select(colnames(guideCounts), all_of(expressed.gene)) # select ptb expression column
#   normalized.ann.X.ptb.by.guide <- normalize.by.ctrl.avg(expressed.gene, ann.X.ptb, ctrl.X.here, mode.selection) ##210823:debug expressed.gene or ptb
#   normalized.ann.X.ptb.by.guide[is.na(normalized.ann.X.ptb.by.guide)] <- 0
#   # topic
#   ptb.omega.filtered <- ann.omega.filtered %>% subset(Gene == ptb)


#   if (dim(ptb.omega.filtered)[1] > 0){  # if the guide didn't cause a fitness effect
#     if(!dir.exists(paste0(FIGDIRSAMPLE,"/gene.by.topic/"))) dir.create(paste0(FIGDIRSAMPLE,"/gene.by.topic/"), recursive=T)
#     if(!enhancer) pdf(file=paste0(FIGDIRSAMPLE,"/gene.by.topic/",ptb, ".", mode.selection, ".dt_", DENSITY.THRESHOLD, ".pdf"), width=16, height=8)
#     # Do gRNAs at the promoter reduce gene expression?
#     # old plot
#     p1 <- toPlot1 %>% ggplot(aes(x=Gene,y=mean.expression)) + geom_bar(stat='identity', width=0.5, fill="#38b4f7") +
#       geom_errorbar(data = toPlot1, aes(x=Gene, ymin=mean.expression-error.bar, ymax=mean.expression+error.bar), width=.15) +
#       # ylim(min(KD.per.guide$mean.expression - 25, 0), max(KD.per.guide$mean.expression + 50, 150)) +
#       ylab(paste0(expressed.gene, " RNA Expression\n(% vs control)")) + xlab("Guides") + mytheme +
#       # scale_y_continuous(limits = c(0, max(KD.per.guide$mean.expression + 50, 150)), breaks = round(seq(0, max(KD.per.guide$mean.expression + 50, 150), 20),20) ) +
#       geom_hline(yintercept = 100, color="grey", linetype="dashed", size=0.5) +
#       geom_jitter(data=KD.per.guide, size=0.25, width=0.15, color="red") +
#               ggdist::stat_halfeye(
#                     data=KD.per.guide,
#                     ## custom bandwidth
#                     adjust = .5, 
#                     ## adjust height
#                     width = .3, 
#                     ## move geom to the right
#                     justification = -1,
#                     ## remove slab interval
#                     .width = 0, 
#                     point_colour = NA,
#                     ## change violin color 
#                     fill = "#a0dafa"
#                 ) 


#     ## ## referenced https://www.cedricscherer.com/2021/06/06/visualizing-distributions-with-raincloud-plots-with-ggplot2/
#     ## p1 <- KD.per.guide %>% ggplot(aes(x=Gene, y=mean.expression)) +
#     ##     ## add half-violin from {ggdist} package 
#     ##     ggdist::stat_halfeye(
#     ##                 ## custom bandwidth
#     ##                 adjust = .5, 
#     ##                 ## adjust height
#     ##                 width = .3, 
#     ##                 ## move geom to the right
#     ##                 justification = -.4, 
#     ##                 ## remove slab interval
#     ##                 .width = 0, 
#     ##                 point_colour = NA
#     ##             ) +
#     ##     geom_boxplot(
#     ##         width = .1, 
#     ##         ## remove outliers
#     ##         outlier.color = NA ## `outlier.shape = NA` works as well
#     ##     ) +
#     ##     ## add justified jitter from the {gghalves} package
#     ##     gghalves::geom_half_point(
#     ##                   ## control point size
#     ##                   size = 0.5,
#     ##                   ## draw jitter on the left
#     ##                   side = "l", 
#     ##                   ## control range of jitter
#     ##                   range_scale = .4, 
#     ##                   ## add some transparency
#     ##                   alpha = .3
#     ##               ) +
#     ##     coord_cartesian(xlim = c(1.2, NA), clip = "off") +
#     ##     mytheme +
#     ##     geom_hline(
#     ##         yintercept = 100,
#     ##         linetype = "dashed",
#     ##         color = "#38b4f7",
#     ##         size = 0.5
#     ##     ) +
#     ##     ylab(paste0(expressed.gene, " RNA Expression\n(% vs control)")) +
#     ##     xlab("Perturbation") 



#     for (t in 1:dim(omega)[2]) { #:dim(omega)[2]
#       topic <- paste0("topic_", t)
#       toPlot.all.test <- all.test.guide.w %>% subset(Topic==topic)
#       toPlot.fdr <- realPvals.df.guide.w %>% subset(Topic == topic) %>% select(Gene,fdr)

#       # Which topics does the gene regulate?
#       toPlot.list <- get.average.ptb.gene.expression.based.on.ctrl(topic, ptb.omega.filtered %>% `rownames<-`(ptb.omega.filtered$long.CBC), ctrl.ann.omega.here %>% `rownames<-`(ctrl.ann.omega.here$long.CBC), mode=mode.selection)
#       toPlot2 <- toPlot.list[["toPlot"]]
#       # levels(toPlot2$Gene)[grep(topic,levels(toPlot2$Gene))] <- gsub("_", " ", toPlot2$Gene[grep(topic,levels(toPlot2$Gene))]) # change topic_x to "topic x"
#       toPlot2 <- order.toPlot.Gene(toPlot2)

#       toCalculate <- toPlot.list[["toCalculate"]]
#       colnames(toCalculate)[colnames(toCalculate)=="Gene.expression"] <- "mean.expression"

#       p2.y.step <- (max(toCalculate$mean.expression/5) %/% 20) * 20
#       p2.y.step <- ifelse(p2.y.step==0, 20, p2.y.step)
#       p2.y.step <- ifelse(p2.y.step > 100, 100, p2.y.step)

#       ## ## old plot
#       ## p2 <- toPlot2 %>% ggplot(aes(x=Gene,y=mean.expression)) + geom_bar(stat='identity', width=0.5, fill="#38b4f7") +
#       ##   geom_errorbar(aes(ymin=mean.expression-error.bar, ymax=mean.expression+error.bar), width=.15) +
#       ##   # ylim(0, max(toPlot2$mean.expression + toPlot2$error.bar + 50, 150)) +
#       ##   ylab(paste0(gsub("topic_","Topic ", topic), " Expression\n(% vs control)")) + xlab("Guides") + mytheme +
#       ##   # scale_y_continuous(limits = c(0, max(toCalculate$mean.expression + 50, 150)), breaks = seq(0, max(toCalculate$mean.expression + 50, 150), p2.y.step)) +
#       ##   geom_hline(yintercept = 100, color="grey", linetype="dashed", size=0.5) +
#       ##   geom_jitter(data=toCalculate, size=0.25, width=0.15, color="red")

#        p2 <- toPlot2 %>% ggplot(aes(x=Gene,y=mean.expression)) + geom_bar(stat='identity', width=0.5, fill="#38b4f7") +
#           ## geom_jitter(data=toCalculate, aes(x=Gene, y=mean.expression), size=0.25, width=0.15, color="red") +
#           geom_errorbar(aes(ymin=mean.expression-error.bar, ymax=mean.expression+error.bar), width=.15) +
#         # ylim(0, max(toPlot2$mean.expression + toPlot2$error.bar + 50, 150)) +
#         ylab(paste0(gsub("topic_","Topic ", topic), " Expression\n(% vs control)")) + xlab("Guides") + mytheme +
#         # scale_y_continuous(limits = c(0, max(toCalculate$mean.expression + 50, 150)), breaks = seq(0, max(toCalculate$mean.expression + 50, 150), p2.y.step)) +
#         geom_hline(yintercept = 100, color="grey", linetype="dashed", size=0.5) +
#         geom_jitter(data=toCalculate, size=0.25, width=0.15, color="red") +
#         ggdist::stat_halfeye(
#                     data=toCalculate,
#                     ## custom bandwidth
#                     adjust = .5, 
#                     ## adjust height
#                     width = .3, 
#                     ## move geom to the right
#                     justification = -1,
#                     ## remove slab interval
#                     .width = 0, 
#                     point_colour = NA,
#                     ## change violin color 
#                     fill = "#a0dafa"
#                 ) 

#       ## ## referenced https://www.cedricscherer.com/2021/06/06/visualizing-distributions-with-raincloud-plots-with-ggplot2/
#       ## p2 <- toCalculate %>% ggplot(aes(x=Gene, y = mean.expression)) +
#       ##     ggdist::stat_halfeye(
#       ##                 ## custom bandwidth
#       ##                 adjust = .5, 
#       ##                 ## adjust height
#       ##                 width = .3, 
#       ##                 ## move geom to the right
#       ##                 justification = -.4, 
#       ##                 ## remove slab interval
#       ##                 .width = 0, 
#       ##                 point_colour = NA
#       ##             ) +
#       ##     geom_boxplot(
#       ##         width = .1, 
#       ##         ## remove outliers
#       ##         outlier.color = NA ## `outlier.shape = NA` works as well
#       ##     ) +
#       ##     ## add justified jitter from the {gghalves} package
#       ##     gghalves::geom_half_point(
#       ##                   ## control point size
#       ##                   size = 0.5,
#       ##                   ## draw jitter on the left
#       ##                   side = "l", 
#       ##                   ## control range of jitter
#       ##                   range_scale = .4, 
#       ##                   ## add some transparency
#       ##                   alpha = .3
#       ##               ) +
#       ##     coord_cartesian(xlim = c(1.2, NA), clip = "off") +
#       ##     mytheme +
#       ##     geom_hline(
#       ##         yintercept = 100,
#       ##         linetype = "dashed",
#       ##         color = "#38b4f7",
#       ##         size = 0.5
#       ##     ) +
#       ##     ylab(paste0(gsub("topic_","Topic ", topic), " Expression\n(% vs control)")) +
#       ##     xlab("Perturbation")






#       # ## alternative to p2, but basically the same thing
#       # ## stat_summary removes some values before averaging and causes the control not to be at 100%
#       # p3 <- toPlot.list[["toCalculate"]] %>% ggplot(aes(x=Gene,y=Gene.expression)) +
#       #   geom_bar(stat = "summary", fun = "mean", width=0.5, fill="#38b4f7") +
#       #   stat_summary(fun.data = mean_cl_normal,
#       #                geom = "errorbar", width=0.15) +
#       #   # ylim(0, max(toPlot2$mean.expression + toPlot2$error.bar + 50, 150)) +
#       #   ylab(paste0(gsub("_"," ", topic), " Expression\n(% vs control)")) + xlab("") + mytheme +
#       #   scale_y_continuous(limits = c(0, max(toPlot2$mean.expression + toPlot2$error.bar + 50, 150)), breaks = seq(0, max(toPlot2$mean.expression + toPlot2$error.bar + 50, 150), p2.y.step)) +
#       #   geom_hline(yintercept = 100, color="grey", linetype="dashed") +
#       #   geom_jitter(size=0.25, width=0.15, color="red")

#       ## #TODO: fix KL specificity score
#       ## # What cellular program does this topic represent?
#       ## toPlot <- data.frame(Gene=topFeatures %>% subset(topic == t) %>% pull(genes),
#       ##                      Score=topFeatures %>% subset(topic == t) %>% pull(scores)) %>%
#       ##   merge(., gene.def.pathways, by="Gene", all.x=T)
#       ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
#       ## p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*100, fill=Pathway) ) + geom_col(width=0.5) + theme_minimal()
#       ## p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("KL score (gene specific to this topic)") +
#       ##   mytheme + theme(legend.position="bottom", legend.direction="vertical")

#       # raw weight version
#       ## hand annotated files by Helen
#       toPlot <- data.frame(Gene=topFeatures.raw.weight %>% subset(topic == t) %>% pull(Gene),
#                            Score=topFeatures.raw.weight %>% subset(topic == t) %>% pull(scores)) %>%
#         merge(., gene.def.pathways, by="Gene", all.x=T) %>% arrange(desc(Score)) %>% slice(1:10)
#       ## 210804 use Gavin's new summary annotations
#       ## toPlot <- data.frame(Gene=topFeatures.raw.weight %>% subset(topic == t) %>% pull(Gene),
#       ##                      Score=topFeatures.raw.weight %>% subset(topic == t) %>% pull(scores)) %>%
#       ##   merge(., summaries %>% select(Symbol, top_class), by.x="Gene", by.y="Symbol", all.x=T) %>% unique %>% `colnames<-`(c("Gene", "Score", "Pathway")) %>% arrange(desc(Score)) %>% slice(1:10)
#       toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
#       toPlot$Pathway[toPlot$Pathway == "unclassified"] <- "Other/Unclassified"
#       p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*100, fill=Pathway) ) + geom_col(width=0.5) + theme_minimal()
#       p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("Specificity Score (z-score)") +
#         mytheme + theme(legend.position="bottom", legend.direction="vertical")



#       # plot 4
#       # add ABC to gene.set.type.df for this particular plot
#       ## gene.set.type.df$type[which(gene.set.type.df$Gene %in% gene.set)] <- "ABC"
#       # assemble toPlot
#         toPlot <- gene.score %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>%
#             merge(.,toPlot.all.test,by="Gene", all.x=T) %>%
#             merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#             ## merge(.,gene.set.type.df,by="Gene") %>%
#             mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene)) %>%
#             merge(., ref.table %>% select("Symbol", "TSS.dist.to.SNP", "GWAS.classification"), by.x="Gene", by.y="Symbol", all.x=T) %>%
#             mutate(EC_ctrl_text = ifelse(.$GWAS.classification == "EC_ctrls", "(+)", "")) %>%
#             mutate(GWAS.class.text = ifelse(grepl("CAD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb"),
#                                             ifelse(grepl("IBD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb_IBD"), ""))) %>%
#             mutate(ann.Gene = paste0(Gene, GWAS.class.text, EC_ctrl_text))
#         colnames(toPlot)[which(colnames(toPlot)==topic)] <- "log2FC"
#         toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(adjusted.p.value)), "", "*")) %>%
#             arrange(log2FC) %>%
#             mutate(x = seq(nrow(.)))
#         label <- toPlot %>% subset(x <= 3 | x > (nrow(toPlot)-3) | Gene == ptb | adjusted.p.value < fdr.thr)
#       ## toPlot <- gene.score %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>% merge(.,toPlot.all.test,by="Gene", all.x=T) %>%
#       ##   merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#       ##   merge(.,gene.set.type.df,by="Gene") %>%
#       ##   mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene))
#       ## colnames(toPlot)[which(colnames(toPlot)==topic)] <- "log2FC"
#       ## toPlot <- toPlot %>% mutate(significant=ifelse((fdr >= fdr.thr | is.na(fdr)), "", "*")) %>%
#       ##   arrange(log2FC) %>%
#       ##   mutate(x = seq(nrow(.)))
#       ## label <- toPlot %>% subset(fdr < fdr.thr | x <= 3 | x > (nrow(toPlot)-3) | Gene == ptb | adjusted.p.value < fdr.thr)
#       mytheme <- theme_classic() + theme(axis.text = element_text(size = 13), axis.title = element_text(size = 15), plot.title = element_text(hjust = 0.5, face = "bold"))
#       p5 <- toPlot %>% ggplot(aes(x=reorder(Gene, log2FC), y=log2FC, color=significant)) + geom_point(size=0.75) + mytheme +
#         theme(axis.ticks.x=element_blank(), axis.text.x=element_blank()) + scale_color_manual(values = c("#E0E0E0", "#38b4f7")) +
#         xlab("Perturbed Genes") + ylab(paste0("Topic ", t, " Expression log2 Fold Change")) +
#         geom_text_repel(data=label, box.padding = 0.5,
#                         aes(label=ann.Gene), size=5,
#                         color="black") +
#         theme(legend.position = "none") # legend at the bottom?


#       # plot 5: topic expression vs KD efficacy
#       # data
#       ptb.omega.t <- ptb.omega.filtered %>% select(all_of(c("Gene","Gene.full.name","long.CBC","CBC","Guide")), all_of(topic))

#       normalized.ptb.omega.t.by.guide <- normalize.by.ctrl.avg(topic, ptb.omega.t, ctrl.ann.omega.here, mode=mode.selection) %>%
#         select(-Gene, -Gene.count) %>%  # remove redundant columns before merging
#           `colnames<-`(gsub("Gene", "Topic", colnames(.)))
#       if(mode.selection=="per.cell"){
#           merged.ptb.normalized.X.omega.t.by.guide <- merge(normalized.ann.X.ptb.by.guide %>% select(-Guide), normalized.ptb.omega.t.by.guide %>% select(-Guide), by="long.CBC")
#       } else {
#           merged.ptb.normalized.X.omega.t.by.guide <- merge(normalized.ann.X.ptb.by.guide, normalized.ptb.omega.t.by.guide, by="Guide") # GOSR2-plus should have 10 guides, but only 8 left
#           ##debug: rbind(deparse.level, ...):   numbers of columns of arguments do not match
#       }

#       toPlot <- merged.ptb.normalized.X.omega.t.by.guide %>% order.toPlot(., column="Gene.count")
#       toPlot.ptb <- toPlot %>% subset(Gene == ptb) # for linear regression
#       fit <- tryCatch(york(toPlot.ptb %>% select(Gene.expression, Gene.error.bar, Topic.expression, Topic.error.bar)),
#                       error=function(cond){
#                         return(data.frame(a = 0, b = 0))
#                       })
#       p6a <- toPlot %>% ggplot(aes(x=Gene.expression, y=Topic.expression, color = Gene.count)) +
#         geom_point(size=1) +
#         ylab(paste0("Topic ", t, " Expression\n(% vs control)")) + xlab(paste0(expressed.gene, " RNA Expression\n(% vs control)")) + mytheme +
#         scale_color_manual(values=c("grey", "red"), name="Guide")  + theme(legend.position="bottom") +
#         geom_smooth(data = toPlot.ptb, method="lm", se=F, fullrange=T, size=0.5) +
#         annotate("text", size = 5, x = min(toPlot$Gene.expression ) + 30, y = min(toPlot$Topic.expression - 15), hjust=0.2,
#                  label = lm_eqn(data.frame(x=toPlot.ptb$Gene.expression,
#                                            y=toPlot.ptb$Topic.expression)), parse = TRUE) # , parse = TRUE

#       p6 <- toPlot %>% ggplot(aes(x=Gene.expression, y=Topic.expression, color = Gene.count)) +
#           geom_point(size=0.3, color = "gray") +
#           geom_point(data=toPlot.ptb, size=0.8, color = "red") + ## put perturbation red datapoints on top
#         ylab(paste0("Topic ", t, " Expression\n(% vs control)")) + xlab(paste0(expressed.gene, " RNA Expression\n(% vs control)")) + mytheme +
#         scale_color_manual(values=c("grey", "red"), name="Guide")  + theme(legend.position="bottom") +
#         geom_abline(intercept = fit$a[1], slope = fit$b[1], col="red", size=0.5)
#       if(mode.selection=="per.cell"){
#           p6 <- p6 + annotate("text", size = 5, hjust=0.2, x=0, y=0-(max(toPlot$Topic.expression)-min(toPlot$Topic.expression))/20,
#                               label = lm_eqn_manual(fit$a[1], fit$b[1]), parse = TRUE)
#       } else {
#           p6 <- p6 + annotate("text", size = 5, x = min(toPlot$Gene.expression ) + 30, y = min(toPlot$Topic.expression - 15), hjust=0.2,
#                               label = lm_eqn_manual(fit$a[1], fit$b[1]), parse = TRUE)
#       }


#       # set xlim and ylim to fit annotation!

#       ## TF motif enrichment volcano plot
#       toplot <- all.promoter.ttest.df %>% subset(topic==paste0("topic_",t) & top.gene.mean != 0)
#       volcano.plot <- function(toplot, EP.string, label.type="") {
#           if( label.type == "pos") {
#               label <- toplot %>% subset(-log10(p.adjust) > 1 & enrichment.log2fc > 0) %>% mutate(motif.toshow = gsub("HUMAN.H11MO.", "", motif))
#           } else {
#               label <- toplot %>% subset(-log10(p.adjust) > 1) %>% mutate(motif.toshow = gsub("HUMAN.H11MO.", "", motif))
#           }
#           t <- gsub("topic_", "", toplot$topic[1])
#           p <- toplot %>% ggplot(aes(x=enrichment.log2fc, y=-log10(p.adjust))) + geom_point(size=0.5) + mytheme +
#               ggtitle(paste0(EP.string, " Motif Enrichment")) + xlab("Motif Enrichment (log2FC)") + ylab("-log10(adjusted p-value)") +
#               geom_text_repel(data=label, box.padding = 0.5,
#                               aes(label=motif.toshow), size=5,
#                               max.overlaps=25,
#                               color="black")
#           return(p)
#       }
#       p.promoter.motif <- volcano.plot(toplot, "Promoter", label.type="pos")
#       toplot <- all.enhancer.ttest.df.10en6 %>% subset(topic==paste0("topic_",t) & top.gene.mean != 0)
#       p.enhancer.motif <- volcano.plot(toplot, "Enhancer", label.type="pos")


#       ## edgeR p-value vs ptb effect RNA expression for top 100 specific gene in the topic
#       ##use.edgeR.results
#       edgeR.expr.gene.names <- rownames(log2fc.edgeR) %>% strsplit(split=":") %>% sapply("[[",1)
#       ptb.colindex <- which(grepl(ptb, colnames(log2fc.edgeR)))
#       theta.zscore.t <- theta.zscore[,t] %>% as.data.frame %>% `colnames<-`("topic.zscore.weight") %>% mutate(genes=rownames(.)) %>% arrange(desc(topic.zscore.weight)) %>% mutate(gene.rank = 1:n(), top100=ifelse(gene.rank <= 100, paste0("Top 100 in Topic ", t), paste0("Not in Top 100")))
#       ptb.pval.log2fc <- merge(p.value.edgeR[,c(1,ptb.colindex)] %>% `colnames<-`(c("genes", "pval")) %>% mutate(genes = strsplit(genes, split=":") %>% sapply("[[",1)), log2fc.edgeR[,c(1,ptb.colindex)] %>% `colnames<-`(c("genes", "log2fc")) %>% mutate(genes = strsplit(genes, split=":") %>% sapply("[[",1)), by="genes") %>% merge(theta.zscore.t, by="genes") %>% mutate(nlog10pval = -log10(pval))
#       ptb.pval.log2fc$nlog10pval[ptb.pval.log2fc$nlog10pval > 15] <- 15

#       ptb.pval.log2fc$top100 <- factor(ptb.pval.log2fc$top100) %>% ordered(levels=c(paste0("Top 100 in Topic ", t), "Not in Top 100"))

#       label <- ptb.pval.log2fc %>% subset(top100!="Not in Top 100")
#       ptb.10X <- ptb.10X.name.conversion$`Name used by CellRanger`[which(grepl(ptb, ptb.10X.name.conversion$Symbol))]
#       label.self <- ptb.pval.log2fc %>% subset(genes == ptb.10X)
#       p.density <- ptb.pval.log2fc %>% ggplot(aes(x=log2fc, fill=top100)) + geom_density(alpha=0.4) + mytheme + xlab("RNA Expression\n(log2FC vs Control)") + scale_fill_manual(values=c("red", "gray"), name = "") + theme(legend.position = "none") # + guides(fill=guide_legend(nrow=2, byrow=T))
#       p.edgeR <- ggplot(ptb.pval.log2fc, aes(x=log2fc, y=nlog10pval)) + geom_point(size=0.1, color = "gray") +
#           geom_point(data = label, size=0.1, color = "red") +
#           mytheme +
#           geom_text_repel(data=label.self, box.padding = 0.5,
#                           max.overlaps=30,
#                           aes(label=genes), size=4,
#                           color="blue") +
#           geom_text_repel(data=label, box.padding = 0.5,
#                           max.overlaps=30,
#                           aes(label=genes), size=4,
#                           color="black") +
#           scale_color_manual(values=c("gray", "red")) +
#           geom_vline(xintercept=0, col = "#38b4f7", lty=3) +
#           theme(legend.position="bottom") + 
#           ggtitle(paste0("Topic ", t, " Perturbation ", ptb)) +
#           xlab("Average RNA Expression (log2FC vs control)") + ylab("p-value (-log10)") # +
#       ## inset_element(p.density, left = 0.1, bottom = 0.75, right=0.9, top = 0.95)
#       p.edgeR.combined <- cowplot::plot_grid(p.density, p.edgeR, align="v", ncol=1, rel_heights=c(0.15, 0.85))

#       ## ## old
#       ## p.top.left <- ggarrange(p1, p2, p6, nrow=1, widths=c(1.5,1.5,2))
#       ## ## p.bottom.left <- ggarrange(p6a, p6, nrow=1)
#       ## p.bottom.left <- ggarrange(p.promoter.motif, p.enhancer.motif, nrow=1)
#       ## p.left <- ggarrange(p.top.left, p.bottom.left, nrow=2)
#       ## p <- ggarrange(p.left, p4, p5, ncol=3, nrow=1, widths=c(2,1,1))
#       ## # p <- ggarrange(p1, p2, p4, p5, p6, ncol=4, nrow=1, widths=c(1,1,1.5,1.5))

#       p.top.left <- ggarrange(p1, p2, p6, nrow=1, widths=c(1.5,1.5,2))
#       p.bottom.left <- ggarrange(p.promoter.motif, p.enhancer.motif, nrow=1)
#       p.left <- ggarrange(p.top.left, p.bottom.left, nrow=2)
#       p.mid <- ggarrange(p4, p.edgeR, nrow=2, heights=c(1.5,1))
#       p <- ggarrange(p.left, p.mid, p5, ncol=3, nrow=1, widths=c(2,1,1))

#       print(p)
#     }
#     if(!enhancer) dev.off()
#   }
# }


##########################################################################
## slide 26 for all perturbation x topic pairs
# function for plot 5

## # plot 3 data
## toSave.features <- read.delim(paste0(OUTDIRSAMPLE, "/topic.KL.score_K", k,ifelse(SEP, ".sep", ""), ".txt"),header=T, stringsAsFactors=F) ### commented out because it's not necessary in this pipeline
## topFeatures <- toSave.features %>% group_by(topic) %>% arrange(desc(scores)) %>% slice(1:10)
# plot 4 data


## ## FGSEA results
## ## load data
## type <- "z.score"
## fgsea.df <- read.delim(file=paste0(FGSEADIR, "/fgsea_", type, "_", SUBSCRIPT, ".txt"), header=T, stringsAsFactors=F) ##HERE
## fgsea.df.GO <- fgsea.df %>% subset(database == "msigdb.c3")
## fgsea.df.all <- fgsea.df %>% subset(database == "msigdb.all")


# # ptb.array <- c("GOSR2", "PRKCE", "PHACTR1", "EIF2B2")
#     ptb.array <- c("RHOA", "PECAM1", "RAP1A", "KLF4", "MEF2C", "EGR1", "CDC42EP2", "CDH5", "KIAA1429","GOSR2","TP53","MAT2A","SKI","EDN1","SMAD3","PHACTR1","EDN1","GGCX","CDKN1A","EGFL7","ELOF1","GPANK1","YLPM1","MESDC1","ITGA5","SKIV2L","LST1","R3HCC1L","UPF2","MEAF6", "CCM2", "KRIT1", "ITGB1BP","HEG1")
#     ## ptb.array <- c("TP53", "SWAP70", "GOSR2", "PRKCE", "PHACTR1", "EIF2B2", "PPIF", "DMRTA1", "ADAMTS7", "VEZT", "MEAF6", "CDH5")
# # ptb.array <- guideCounts$Gene %>% unique()
# # ptb.array <- enhancer.set
# ## ptb.array <- append(ptb.array, all.test.guide.w$Gene %>% unique()) %>% append(CAD.focus.gene.set) %>% append(gene.set.type.df %>% subset(grepl("CAD_Loci", type))  %>% pull(Gene) %>% sort())
# ptb.array <- append(ptb.array, all.test.guide.w %>%subset(adjusted.p.value < 0.1) %>% pull(Gene) %>% unique())
# # ptb.array <- CAD.focus.gene.set[grepl("E_at_", CAD.focus.gene.set)]

# for (mode.selection in c("per.guide", "per.cell")){

#     ## new!
#     for (ptb in ptb.array) { # make a separate one for enhancers, clear up cases like E_at_AAGAB,IQCH,LOC102723493&SMAD3
#         print(paste0(ptb, "\n\n"))
#         if(grepl("E_at_", ptb)) {
#             target.gene <- strsplit(gsub("E_at_","",ptb %>% strsplit(split="-") %>% unlist() %>% nth(1)), split=",|&") %>% unlist() %>% as.character()
#             if((target.gene %in% colnames(ann.X.full.filtered)) %>% as.numeric() %>% sum() > 0 &  # target gene must be expressed
#                (ptb %in% ann.X.full.filtered$Gene) %>% as.numeric() %>% sum() > 0) { # enhancer must have enough cells and guides
#                 pdf(file=paste0(FIGDIRSAMPLE,"/gene.by.topic/",ptb, ".", mode.selection, ".pdf"), width=15, height=8)
#                 for(expressed.gene in target.gene) {
#                     if(expressed.gene %in% colnames(ann.X.full.filtered)){
#                         print(paste0("Enhancer of ", expressed.gene, "\n\n"))
#                         gene.index <- which(grepl(expressed.gene, ptb.10X.name.conversion$Symbol))
#                         if(length(gene.index) == 1) {
#                             ptb.expression.name <- ptb.10X.name.conversion$`Name used by CellRanger`[gene.index]
#                         } else {
#                             ptb.expression.name <- expressed.gene
#                         }
#                         giant.summary.plot(ptb, ptb.expression.name, expressed.gene, mode.selection, enhancer=T)
#                     }
#                 }
#                 dev.off()
#             }
#         } else{
#             expressed.gene <- ptb %>% strsplit(split="-") %>% unlist() %>% nth(1)
#             gene.index <- which(grepl(ptb, ptb.10X.name.conversion$Symbol))
#             if(length(gene.index) == 1) {
#                 ptb.expressed.name <- ptb.10X.name.conversion$`Name used by CellRanger`[gene.index]
#                 expressed.gene <- ptb.expressed.name
#             } else {
#                 ptb.expressed.name <- expressed.gene
#             }
#             message(paste0("Expressed gene: ", expressed.gene, ", Perturbation: ", ptb, ", with expressed name: ", ptb.expressed.name))
#             if((colnames(X.full) == expressed.gene) %>% as.numeric() %>% sum() > 0) giant.summary.plot(ptb, ptb.expressed.name, expressed.gene, mode.selection, enhancer=F)
#         }
#     }
#     ##HERE
# } # if(opt$subsample.type!="ctrl")



# ##scratch:210818
# KDefficacyFIGDIR <- paste0(FIGDIRSAMPLE, "/KD.efficacy/")
# if(!dir.exists(KDefficacyFIGDIR)) dir.create(KDefficacyFIGDIR)

# ptb.array <- c("RHOA", "PECAM1", "RAP1A", "KLF4", "MEF2C", "EGR1", "CDC42EP2", "TSPAN14", "NFATC2") ##210823 list for topic 15
# ptb.array <- ref.table$Symbol %>% unique()

# ## per-guide KD efficacy
# for (ptb in ptb.array) {
#     subset.index <- which(X.gene.names == ptb) 
#     gene.index <- which(grepl(ptb, ptb.10X.name.conversion$Symbol))
#     if(length(gene.index) == 1) {
#         expressed.gene <- ptb.10X.name.conversion$`Name used by CellRanger`[gene.index]
#     } else {
#         expressed.gene <- ptb
#     }
#     ##TODO: add rep1rep2 to ctrl.X %>% subset()
#     ## for SEP=T, subset ctrl.X to the right sample's control
#     if(SEP) {
#         label.here <- strsplit(ptb, split="-") %>% unlist() %>% nth(2) %>% paste0("-",.)
#         ctrl.X.here <- ctrl.X %>% subset(grepl(label.here,Gene))
#         ctrl.ann.omega.here <- ctrl.ann.omega %>% subset(grepl(label.here,Gene))
#     } else  {
#         ctrl.X.here <- ctrl.X
#         ctrl.ann.omega.here <- ctrl.ann.omega
#     }
#     ## construct table for plotting
#     toPlot.list <- get.average.ptb.gene.expression.based.on.ctrl(expressed.gene, ann.X.full.filtered[subset.index,], ctrl.X.here, mode=mode.selection)
#     toPlot1 <- toPlot.list[["toPlot"]]
#     order.toPlot.Gene <- function(df) {
#         df$Gene <- factor(x=df$Gene, levels=df$Gene[c(which(grepl("^Control", df$Gene)) %>% min(), which(!grepl("^Control", df$Gene))%>% min())])
#         return(df)
#     }
#     order.toPlot <- function(df, column) {
#         array <- df %>% pull(all_of(column))
#         df <- df %>% mutate(!!column:=factor(x=array, levels=array[c(which(grepl("^Control", array)) %>% min(), which(!grepl("^Control", array))%>% min())]))
#         return(df)
#     }
#     toPlot1 <- order.toPlot(toPlot1, column="Gene")
#     ## data points
#     KD.per.guide <- toPlot.list[["toCalculate"]]
#     colnames(KD.per.guide)[colnames(KD.per.guide)=="Gene.expression"] <- "mean.expression"
#     ## plot
#     p.KD.efficacy.barplot <- toPlot1 %>% ggplot(aes(x=Gene,y=mean.expression)) + geom_bar(stat='identity', width=0.5, fill="#38b4f7") +
#         geom_errorbar(data = toPlot1, aes(x=Gene, ymin=mean.expression-error.bar, ymax=mean.expression+error.bar), width=.15) +
#                                         # ylim(min(KD.per.guide$mean.expression - 25, 0), max(KD.per.guide$mean.expression + 50, 150)) +
#         ylab(paste0(expressed.gene, " RNA Expression\n(% vs control)")) + xlab("Guides") + mytheme +
#                                         # scale_y_continuous(limits = c(0, max(KD.per.guide$mean.expression + 50, 150)), breaks = round(seq(0, max(KD.per.guide$mean.expression + 50, 150), 20),20) ) +
#         geom_hline(yintercept = 100, color="grey", linetype="dashed", size=0.5) +
#         geom_jitter(data=KD.per.guide, size=0.25, width=0.15, color="red")

#     pdf(file=paste0(KDefficacyFIGDIR, "ptb.", ptb, "_KD.efficacy.pdf"), width=4, height=6)
#     print(p.KD.efficacy.barplot)
#     dev.off()
# }




## commented out 211013


# ## topic 29 expression log2FC versus CDH5 RNA expression log2FC ##todo:210812
# for (t in 1:k) {
#     ## t <- 29 ##scratch
#     topic <- paste0("topic_",t)
#     ## plot location
#     perFACTORFIGDIR <- paste0(FIGDIRSAMPLE, "factor.summary/factor", t, "/")
#     if(!dir.exists(perFACTORFIGDIR)) dir.create(perFACTORFIGDIR)
#     perFACTORFIGDIR <- paste0(FIGDIRSAMPLE, "factor.summary/factor", t, "/", SAMPLE,"_K",k, "_dt_", DENSITY.THRESHOLD,"_factor", t, "_")

#     ## ptb.name <- "CDH5" ##scratch top perturbations
#     top.ptb.name <- gene.score %>% as.data.frame %>% mutate(Gene = rownames(.)) %>% select(all_of(topic), Gene) %>% arrange(desc(get(topic))) %>% slice(1:10) %>% pull(Gene)
#     for( ptb.name in top.ptb.name) {
#         topic.log2fc.here <- log2fc.omega %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>% subset(grepl(ptb.name, Gene))  ## get Topic expression for Gene G
#         topic.fc.here <- fc.omega %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>% subset(grepl(ptb.name, Gene))

#         ## inf.index <- topic.log2fc.here %>% pull(all_of(topic)) %>% is.infinite %>% which  ## get cell index for zero Topic expression entries
#         ## neg.inf.index <- (topic.fc.here %>% pull(all_of(topic)) == 0) %>% which
#         ## pos.inf.index <- which(!(inf.index %in% neg.inf.index))
#         ## topic.log2fc.here[[topic]][inf.index] <- 0
#         ## min.log2fc.here <- topic.log2fc.here %>% pull(all_of(topic)) %>% min
#         ## max.log2fc.here <- topic.log2fc.here %>% pull(all_of(topic)) %>% max
#         ## topic.log2fc.here[[topic]][neg.inf.index] <- floor(min.log2fc.here - 3)
#         ## topic.log2fc.here[[topic]][pos.inf.index] <- ceiling(max.log2fc.here + 3)

#         remove.inf <- function(topic.log2fc.here, topic.fc.here, topic) {
#             topic.index <- which(grepl(topic, colnames(topic.log2fc.here)))
#             inf.index <- topic.log2fc.here[,topic.index] %>% is.infinite %>% which
#             neg.inf.index <- (topic.fc.here[,topic.index] == 0) %>% which

#             ## inf.index <- topic.log2fc.here %>% pull(all_of(topic)) %>% is.infinite %>% which  ## get cell index for zero Topic expression entries
#             ## neg.inf.index <- (topic.fc.here %>% pull(all_of(topic)) == 0) %>% which

#             pos.inf.index <- which(!(inf.index %in% neg.inf.index))
#             topic.log2fc.here[inf.index,topic.index] <- 0
#             min.log2fc.here <- topic.log2fc.here[,topic.index] %>% min
#             max.log2fc.here <- topic.log2fc.here[,topic.index] %>% max
#             ## min.log2fc.here <- topic.log2fc.here %>% pull(all_of(topic)) %>% min
#             ## max.log2fc.here <- topic.log2fc.here %>% pull(all_of(topic)) %>% max
#             topic.log2fc.here[neg.inf.index,topic.index] <- floor(min.log2fc.here - 3)
#             topic.log2fc.here[pos.inf.index,topic.index] <- ceiling(max.log2fc.here + 3)
#             return(topic.log2fc.here)
#         }

#         topic.log2fc.here <- remove.inf(topic.log2fc.here, topic.fc.here, topic)


#         ## select RNA expression
#         ## todo: function to extract df
#         X.full.here <- X.full %>% subset(grepl(ptb.name, rownames(.))) %>% `colnames<-`(gsub(":.*$", "", colnames(.)))
#         fc.X.full.here <- fc.X.full %>% subset(grepl(ptb.name, rownames(.))) %>% `colnames<-`(gsub(":.*$", "", colnames(.)))
#         log2fc.X.full.here <- log2fc.X.full %>% subset(grepl(ptb.name, rownames(.))) %>% `colnames<-`(gsub(":.*$", "", colnames(.)))

#         get.topicFC.vs.RNAexpFC.toPlot <- function(ptb.log2fc.df, expressed.ptb.name, RNA.full.here) {
#             ## ##### function to combine per cell topic data and gene expression data
#             ## INPUT
#             ## ptb.log2fc.df: topic log2FC df subset to one perturbation only
#             ## expressed.ptb.name: the expressed gene
#             ## X.full.here: matrix subset to one perturbation only, the values could be RNA expression count, FC, or log2FC.
#             ## 
#             ## OUTPUT
#             ## toPlot: a dataframe that has ptb only cells and one topic column and one expressed gene column
#             ##
#             expressed.gene.index <- which(grepl(paste0("^", paste0(expressed.ptb.name, collapse="$|^"), "$"), colnames(RNA.full.here)))
#             expressed.gene.RNA.here <- RNA.full.here[,expressed.gene.index] %>% as.data.frame %>% `colnames<-`(expressed.ptb.name) 
#             toPlot <- merge(ptb.log2fc.df, expressed.gene.RNA.here, by.x="Gene", by.y=0)
#             return(toPlot)
#         }


#         ## ptb.name.index <- which(grepl(ptb.name, colnames(log2fc.X.full.here)))
#         ## log2fc.RNA.expression.here <- log2fc.X.full.here[,ptb.name.index] %>% as.data.frame %>% `colnames<-`(ptb.name)
#         ## RNA.expression.here <- X.full.here[,ptb.name.index] %>% as.data.frame %>% `colnames<-`(ptb.name)
#         ## topic.RNA.expression.here <- merge(log2fc.here, RNA.expression.here, by.x="Gene", by.y=0)

#         if(ptb.name=="MESDC1") ptb.name <- "TLNRD1"
#         if(ptb.name %in% colnames(X.full.here)) {
#             log2fc.X.full.here <- remove.inf(log2fc.X.full.here, fc.X.full.here, ptb.name)
#             topic.RNA.expression.here <- get.topicFC.vs.RNAexpFC.toPlot(topic.log2fc.here, ptb.name, X.full.here)
#         } else {
#             warning(paste0(ptb.name, " is not in the 10X gene list"))
#         }

#         ## plot topic FC vs RNA expression
#         pdf(file=paste0(perFACTORFIGDIR, "ptb.", ptb.name, "_FC.vs.RNA.expression.pdf"))##here210818
#         p <- topic.RNA.expression.here %>% ggplot(aes(x=get(ptb.name), y=get(topic))) + geom_point() + mytheme +
#             ggtitle(paste0(ptb.name, " Perturbation")) + 
#             xlab(paste0(ptb.name, " RNA Expression")) + ylab(paste0("Topic ", gsub("topic_","",topic), " Expression (log2FC)"))
#         print(p)
#         dev.off()

#         ## EDN1 (or top genes in topic29) expression log2FC in CDH5 KD cells versus topic 29 log2FC in CDH5 KD cells ##todo:210812
#         ## get list of top perturbations
#         top.expressed.gene.in.topic <- theta.zscore[,t] %>% sort(decreasing=T) %>% head(num_ptb) %>% names
#         ## concatenate all top gene in topic RNA expression
#         topic.expressed.gene.RNA.df <- get.topicFC.vs.RNAexpFC.toPlot(topic.log2fc.here, top.expressed.gene.in.topic, X.full.here)
#         topic.expressed.gene.RNA.fc.df <- get.topicFC.vs.RNAexpFC.toPlot(topic.log2fc.here, top.expressed.gene.in.topic, fc.X.full.here)
#         topic.expressed.gene.RNA.log2fc.df <- get.topicFC.vs.RNAexpFC.toPlot(topic.log2fc.here, top.expressed.gene.in.topic, log2fc.X.full.here)

#         ## melt the dfs
#         topic.expressed.gene.RNA.long <- topic.expressed.gene.RNA.df %>% melt(id.vars = c("Gene", topic), value.name = "RNA.expression.count", variable.name = "expressed.gene.name")
#         topic.expressed.gene.RNA.fc.long <- topic.expressed.gene.RNA.fc.df %>% melt(id.vars = c("Gene", topic), value.name = "RNA.expression.fc", variable.name = "expressed.gene.name")
#         topic.expressed.gene.RNA.log2fc.long <- topic.expressed.gene.RNA.log2fc.df %>% melt(id.vars = c("Gene", topic), value.name = "RNA.expression.log2fc", variable.name = "expressed.gene.name")

#         ## ## violin plot
#         ## p.violin <- topic.expressed.gene.RNA.long %>% ggplot(aes(x=expressed.gene.name, y=RNA.expression.count)) + mytheme +
#         ##     xlab(paste0("Top Specific Gene in Topic ", t)) + ylab("RNA expression count") + 
#         ##     ggtitle(paste0(ptb.name, " Perturbation, Topic ", t, " Top Specific Gene Expression")) +
#         ##     ggdist::stat_halfeye(
#         ##                 ## custom bandwidth
#         ##                 adjust = .5, 
#         ##                 ## adjust height
#         ##                 width = .3, 
#         ##                 ## move geom to the right
#         ##                 justification = -.4, 
#         ##                 ## remove slab interval
#         ##                 .width = 0, 
#         ##                 point_colour = NA
#         ##             ) +
#         ##     geom_boxplot(
#         ##         width = .1, 
#         ##         ## remove outliers
#         ##         outlier.color = NA ## `outlier.shape = NA` works as well
#         ##     ) +
#         ##      ## add justified jitter from the {gghalves} package
#         ##     gghalves::geom_half_point(
#         ##                   ## control point size
#         ##                   size = 0.5,
#         ##                   ## draw jitter on the left
#         ##                   side = "l", 
#         ##                   ## control range of jitter
#         ##                   range_scale = .4, 
#         ##                   ## add some transparency
#         ##                   alpha = .3
#         ##               ) +
#         ##     coord_cartesian(xlim = c(1.2, NA), clip = "off") 

#         ## p.fc.violin <- topic.expressed.gene.RNA.fc.long %>% ggplot(aes(x=expressed.gene.name, y=RNA.expression.fc)) + mytheme +
#         ##     xlab(paste0("Top Specific Gene in Topic ", t)) + ylab("RNA expression (FC)") +
#         ##     ggtitle(paste0(ptb.name, " Perturbation, Topic ", t, " Top Specific Gene Expression")) +
#         ##     ggdist::stat_halfeye(
#         ##                 ## custom bandwidth
#         ##                 adjust = .5, 
#         ##                 ## adjust height
#         ##                 width = .3, 
#         ##                 ## move geom to the right
#         ##                 justification = -.4, 
#         ##                 ## remove slab interval
#         ##                 .width = 0, 
#         ##                 point_colour = NA
#         ##             ) +
#         ##     geom_boxplot(
#         ##         width = .1, 
#         ##         ## remove outliers
#         ##         outlier.color = NA ## `outlier.shape = NA` works as well
#         ##     ) +
#         ##     ## add justified jitter from the {gghalves} package
#         ##     gghalves::geom_half_point(
#         ##                   ## control point size
#         ##                   size = 0.5,
#         ##                   ## draw jitter on the left
#         ##                   side = "l", 
#         ##                   ## control range of jitter
#         ##                   range_scale = .4, 
#         ##                   ## add some transparency
#         ##                   alpha = .3
#         ##               ) +
#         ##     coord_cartesian(xlim = c(1.2, NA), clip = "off") +
#         ##     geom_hline(
#         ##         yintercept = 1,
#         ##         linetype = "dashed",
#         ##         color = "#38b4f7",
#         ##         size = 0.5
#         ##     ) 
#         ## p.log2fc.violin <- topic.expressed.gene.RNA.log2fc.long %>% ggplot(aes(x=expressed.gene.name, y=RNA.expression.log2fc)) + mytheme +
#         ##     xlab(paste0("Top Specific Gene in Topic ", t)) + ylab("RNA expression (log2FC)") +
#         ##     ggtitle(paste0(ptb.name, " Perturbation, Topic ", t, " Top Specific Gene Expression")) +
#         ##         ggdist::stat_halfeye(
#         ##                 ## custom bandwidth
#         ##                 adjust = .5, 
#         ##                 ## adjust height
#         ##                 width = .3, 
#         ##                 ## move geom to the right
#         ##                 justification = -.4, 
#         ##                 ## remove slab interval
#         ##                 .width = 0, 
#         ##                 point_colour = NA
#         ##             ) +
#         ##     geom_boxplot(
#         ##         width = .1, 
#         ##         ## remove outliers
#         ##         outlier.color = NA ## `outlier.shape = NA` works as well
#         ##     ) +
#         ##     ## add justified jitter from the {gghalves} package
#         ##     gghalves::geom_half_point(
#         ##                   ## control point size
#         ##                   size = 0.5,
#         ##                   ## draw jitter on the left
#         ##                   side = "l", 
#         ##                   ## control range of jitter
#         ##                   range_scale = .4, 
#         ##                   ## add some transparency
#         ##                   alpha = .3
#         ##               ) +
#         ##     coord_cartesian(xlim = c(1.2, NA), clip = "off") +
#         ##     geom_hline(
#         ##         yintercept = 0,
#         ##         linetype = "dashed",
#         ##         color = "#38b4f7",
#         ##         size = 0.5
#         ##     ) 

#         ## pdf(file=paste0(perFACTORFIGDIR, "ptb.", ptb.name, "_top.gene.RNA.expression.violin.pdf"))
#         ## print(p.violin)
#         ## print(p.fc.violin)
#         ## print(p.log2fc.violin)
#         ## dev.off()


#         ## get RNA expression FC and topic FC df for KD efficacy plot and topic expression change plot
#         ## negative control df
#         ctrl.topic.fc <- fc.omega %>% select(all_of(topic)) %>% mutate(Gene = rownames(.)) %>% subset(grepl("negative|safe", Gene))
#         ctrl.topic.RNA.expression.here <- get.topicFC.vs.RNAexpFC.toPlot(ctrl.topic.fc, top.expressed.gene.in.topic, ctrl.X) %>% mutate(CBC=Gene, Gene="Negative Control")

#         ## perturbation df
#         ptb.topic.RNA.expression.here <- get.topicFC.vs.RNAexpFC.toPlot(topic.fc.here, top.expressed.gene.in.topic, X.full.here) %>% mutate(CBC=Gene, Gene=ptb.name)

#         ## combine perturbation and control dfs
#         ptb.with.ctrl <- rbind(ctrl.topic.RNA.expression.here, ptb.topic.RNA.expression.here)

#         ## plot gene expression (% vs control) distribution for perturbation and for control side-by-side (violin plot)
#         p.list <- vector("list",length(top.expressed.gene.in.topic))
#         for ( expr.gene.index in 1:length(top.expressed.gene.in.topic) ) {
#             expr.gene <- top.expressed.gene.in.topic[expr.gene.index]
#             toPlot <- ptb.with.ctrl %>% select(Gene, all_of(topic), all_of(expr.gene))
#             ptb.ary <- toPlot %>% subset(Gene == ptb.name) %>% pull(all_of(expr.gene))
#             ctr.ary <- toPlot %>% subset(Gene == "Negative Control") %>% pull(all_of(expr.gene))
#             toPlot <- toPlot %>% melt(id.vars=c("Gene", topic), variable.name = "expr.gene", value.name = "log2fc.RNA.expression")
#             p.value <- wilcox.test(ptb.ary, ctr.ary)$p.value
#             p <- toPlot %>% ggplot(aes(x=expr.gene, y=log2fc.RNA.expression, fill=Gene)) + geom_split_violin() + mytheme + 
#                 xlab(paste0("(p-value: ", format.pval(p.value, digits=4), ")")) + ylab("RNA Expression (count)")
#             p.list[[expr.gene.index]] <- p
#         }

#         num_plot_row <- ceiling(length(p.list)/5)
#         p <- ggarrange(plotlist = p.list, ncol = 5, nrow = num_plot_row, common.legend=T, legend="bottom")
#         annotate_figure(p, left = "RNA Expression (log2FC vs Control)")
#         pdf(paste0(perFACTORFIGDIR, "ptb.", ptb.name, "_top.gene.RNA.expression.violin.pdf"), width=10, height=3*num_plot_row)
#         print(p)
#         dev.off()

#         ## ## plot KD efficacy
#         ## p.KD.efficacy <- ptb.with.ctrl %>% ggplot(aes(x=Gene, y=get(ptb.name)*100)) +
#         ##             ggdist::stat_halfeye(
#         ##             ## custom bandwidth
#         ##             adjust = .5, 
#         ##             ## adjust height
#         ##             width = .3, 
#         ##             ## move geom to the right
#         ##             justification = -.4, 
#         ##             ## remove slab interval
#         ##             .width = 0, 
#         ##             point_colour = NA
#         ##         ) +
#         ## geom_boxplot(
#         ##     width = .1, 
#         ##     ## remove outliers
#         ##     outlier.color = NA ## `outlier.shape = NA` works as well
#         ## ) +
#         ## ## add justified jitter from the {gghalves} package
#         ## gghalves::geom_half_point(
#         ##               ## control point size
#         ##               size = 0.5,
#         ##               ## draw jitter on the left
#         ##               side = "l", 
#         ##               ## control range of jitter
#         ##               range_scale = 0.3,
#         ##               ## control vertical range of jitter
#         ##               transformation = position_jitter(height = 10),
#         ##               ## add some transparency
#         ##               alpha = .3
#         ##           ) +
#         ## coord_cartesian(xlim = c(1.2, NA), clip = "off") +
#         ## mytheme +
#         ## geom_hline(
#         ##     yintercept = 100,
#         ##     linetype = "dashed",
#         ##     color = "#38b4f7",
#         ##     size = 0.5
#         ## ) +
#         ## ylab(paste0(ptb.name, " RNA Expression\n(% vs control)")) +
#         ## xlab("Perturbation") 

#         ## ## copied from summary plots
#         ## if(SEP) {
#         ##     label.here <- strsplit(ptb, split="-") %>% unlist() %>% nth(2) %>% paste0("-",.)
#         ##     ctrl.X.here <- ctrl.X %>% subset(grepl(label.here,Gene))
#         ##     ctrl.ann.omega.here <- ctrl.ann.omega %>% subset(grepl(label.here,Gene))
#         ## } else  {
#         ##     ctrl.X.here <- ctrl.X
#         ##     ctrl.ann.omega.here <- ctrl.ann.omega
#         ## }

#         ## p.topic.expression.with.ctrl <- ptb.with.ctrl %>% ggplot(aes(x=Gene, y=get(topic)*100)) + 
#         ##             ggdist::stat_halfeye(
#         ##             ## custom bandwidth
#         ##             adjust = .5, 
#         ##             ## adjust height
#         ##             width = .3, 
#         ##             ## move geom to the right
#         ##             justification = -.4, 
#         ##             ## remove slab interval
#         ##             .width = 0, 
#         ##             point_colour = NA
#         ##         ) +
#         ## geom_boxplot(
#         ##     width = .1, 
#         ##     ## remove outliers
#         ##     outlier.color = NA ## `outlier.shape = NA` works as well
#         ## ) +
#         ## ## add justified jitter from the {gghalves} package
#         ## gghalves::geom_half_point(
#         ##               ## control point size
#         ##               size = 0.5,
#         ##               ## draw jitter on the left
#         ##               side = "l", 
#         ##               ## control range of jitter
#         ##               range_scale = 0.3,
#         ##               ## control vertical range of jitter
#         ##               transformation = position_jitter(height = 10),
#         ##               ## add some transparency
#         ##               alpha = .3
#         ##           ) +
#         ## coord_cartesian(xlim = c(1.2, NA), clip = "off") +
#         ## mytheme +
#         ## geom_hline(
#         ##     yintercept = 100,
#         ##     linetype = "dashed",
#         ##     color = "#38b4f7",
#         ##     size = 0.5
#         ## ) +
#         ## ylab(paste0("Topic ", t, " Expression\n(% vs control)")) +
#         ## xlab("Perturbation") 

#         ## pdf(file=paste0(perFACTORFIGDIR, "ptb.", ptb.name, "_topic", t, ".expression.vs.ctrl.pdf"))
#         ## print(p.topic.expression.with.ctrl)
#         ## dev.off()


#         pdf(file=paste0(perFACTORFIGDIR, "ptb.", ptb.name, "_log2FC.vs.top.gene.RNA.expression.pdf"))
#         for (expressed.gene.here in top.expressed.gene.in.topic) {
#             print(paste0("plotting ", expressed.gene.here))
#             ## expressed.gene.here <- "EDN1" # loop over top genes in topic
#             ## old (converted to get.topicFC.vs.RNAexpFC.toPlot())
#             ## expressed.ptb.name.index <- which(grepl(expressed.gene.here, colnames(log2fc.X.full.here)))
#             ## expressed.gene.RNA.here <- X.full.here[,expressed.ptb.name.index] %>% as.data.frame %>% `colnames<-`(expressed.gene.here)
#             ## topic.expressed.gene.RNA.here <- merge(topic.log2fc.here, expressed.gene.RNA.here, by.x="Gene", by.y=0)
#             topic.expressed.gene.RNA.here <- get.topicFC.vs.RNAexpFC.toPlot(topic.log2fc.here, expressed.gene.here, X.full.here)
#             topic.expressed.gene.RNA.fc.here <- get.topicFC.vs.RNAexpFC.toPlot(topic.log2fc.here, expressed.gene.here, fc.X.full.here)
#             topic.expressed.gene.RNA.log2fc.here <- get.topicFC.vs.RNAexpFC.toPlot(topic.log2fc.here, expressed.gene.here, log2fc.X.full.here)

#             p.raw.x <- topic.expressed.gene.RNA.here %>% ggplot(aes(y=get(expressed.gene.here), x=get(topic))) + geom_point() + mytheme +
#                 ggtitle(paste0(ptb.name, " Perturbation")) + 
#                 ylab(paste0(expressed.gene.here, " RNA Expression")) + xlab(paste0("Topic ", gsub("topic_","",topic), " Expression (log2FC)"))

#             p.fc.x <- topic.expressed.gene.RNA.fc.here %>% ggplot(aes(y=get(expressed.gene.here), x=get(topic))) + geom_point() + mytheme +
#                 ggtitle(paste0(ptb.name, " Perturbation")) + 
#                 ylab(paste0(expressed.gene.here, " RNA Expression (FC)")) + xlab(paste0("Topic ", gsub("topic_","",topic), " Expression (log2FC)"))

#             p.log2fc.x <- topic.expressed.gene.RNA.log2fc.here %>% ggplot(aes(y=get(expressed.gene.here), x=get(topic))) + geom_point() + mytheme +
#                 ggtitle(paste0(ptb.name, " Perturbation")) + 
#                 ylab(paste0(expressed.gene.here, " RNA Expression (log2FC)")) + xlab(paste0("Topic ", gsub("topic_","",topic), " Expression (log2FC)"))

#             ## output plots

#             print(p.raw.x) ## fit a line
#             print(p.fc.x)
#             print(p.log2fc.x)

#         }
#         dev.off()
#     }
# }

# ##end of scratch:210818



# ##scratch:210825

# t <- 15
# topic <- paste0("topic_",t)
# perFACTORFIGDIR <- paste0(FIGDIRSAMPLE, "factor.summary/factor", t, "/")
# if(!dir.exists(perFACTORFIGDIR)) dir.create(perFACTORFIGDIR)
# perFACTORFIGDIR <- paste0(FIGDIRSAMPLE, "factor.summary/factor", t, "/", SAMPLE,"_K",k, "_dt_", DENSITY.THRESHOLD,"_factor", t, "_")

# top.ptb.name <- gene.score %>% as.data.frame %>% mutate(Gene = rownames(.)) %>% select(all_of(topic), Gene) %>% arrange(desc(get(topic))) %>% slice(1:10) %>% pull(Gene)
# top.ptb.name <- gsub("WTAP2", "WTAP", top.ptb.name)
# ptb <- "MESDC1"

# pdf(paste0(perFACTORFIGDIR, "top.ptb_pval.log2fc.expr.gene.volcano.pdf"), width=5, height=5)
# for ( ptb in top.ptb.name ) {
#     ##use.edgeR.results
#     edgeR.expr.gene.names <- rownames(log2fc.edgeR) %>% strsplit(split=":") %>% sapply("[[",1)
#     ptb.colindex <- which(grepl(ptb, colnames(log2fc.edgeR)))
#     theta.zscore.t <- theta.zscore[,t] %>% as.data.frame %>% `colnames<-`("topic.zscore.weight") %>% mutate(genes=rownames(.)) %>% arrange(desc(topic.zscore.weight)) %>% mutate(gene.rank = 1:n(), top100=ifelse(gene.rank <= 100, paste0("Top 100 in Topic ", t), paste0("Not in Top 100")))
#     ptb.pval.log2fc <- merge(p.value.edgeR[,c(1,ptb.colindex)] %>% `colnames<-`(c("genes", "pval")) %>% mutate(genes = strsplit(genes, split=":") %>% sapply("[[",1)), log2fc.edgeR[,c(1,ptb.colindex)] %>% `colnames<-`(c("genes", "log2fc")) %>% mutate(genes = strsplit(genes, split=":") %>% sapply("[[",1)), by="genes") %>% merge(theta.zscore.t, by="genes") %>% mutate(nlog10pval = -log10(pval))
#     ptb.pval.log2fc$nlog10pval[ptb.pval.log2fc$nlog10pval > 15] <- 15

#     ptb.pval.log2fc$top100 <- factor(ptb.pval.log2fc$top100) %>% ordered(levels=c(paste0("Top 100 in Topic ", t), "Not in Top 100"))

#     label <- ptb.pval.log2fc %>% subset(top100!="Not in Top 100")
#     ptb.10X <- ptb.10X.name.conversion$`Name used by CellRanger`[which(grepl(ptb, ptb.10X.name.conversion$Symbol))]
#     label.self <- ptb.pval.log2fc %>% subset(genes == ptb.10X)
#     p.density <- ptb.pval.log2fc %>% ggplot(aes(x=log2fc, fill=top100)) + geom_density(alpha=0.4) + mytheme + xlab("RNA Expression\n(log2FC vs Control)") + coord_flip() + scale_fill_manual(values=c("red", "gray"), name = "") + theme(legend.position = "none") # + guides(fill=guide_legend(nrow=2, byrow=T))
#     p <- ggplot(ptb.pval.log2fc, aes(x=log2fc, y=nlog10pval)) + geom_point(size=0.1, color = "gray") +
#         geom_point(data = label, size=0.1, color = "red") +
#         mytheme +
#         geom_text_repel(data=label.self, box.padding = 0.5,
#                         max.overlaps=30,
#                         aes(label=genes), size=4,
#                         color="blue") +
#         geom_text_repel(data=label, box.padding = 0.5,
#                         max.overlaps=30,
#                         aes(label=genes), size=4,
#                         color="black") +
#         scale_color_manual(values=c("gray", "red")) +
#         geom_vline(xintercept=0, col = "#38b4f7", lty=3) +
#         theme(legend.position="bottom") + 
#         ggtitle(paste0("Topic ", t, " Perturbation ", ptb)) +
#         xlab("Average RNA Expression (log2FC vs control)") + ylab("p-value (-log10)") +
#         inset_element(p.density, left = 0.7, bottom = 0.05, right=0.95, top = 0.55)
#     print(p)
# }
# dev.off()

# ##end of use.edgeR.results


# ## ##redo.wilcoxon.test
# ## ptb.ctrl.rowindex <- which(grepl(ptb,X.gene.names) | grepl("negative|safe", X.gene.names))
# ## ptb.ctrl.X <- X.full[ptb.ctrl.rowindex,] 
# ## ptb.rowindex.new <- which(grepl(ptb, rownames(ptb.ctrl.X))) ## cells with ptb
# ## ctrl.rowindex.new <- which(!grepl(ptb, rownames(ptb.ctrl.X))) ## ctrl cells
# ## log2fc.ptb.ctrl.X <- log2fc.X.full[ptb.ctrl.rowindex,]
# ## df <- do.call(rbind, lapply(1:dim(ptb.ctrl.X)[2], function(expr.gene.index) {
# ##     expr.gene <- colnames(ptb.ctrl.X)[expr.gene.index]
# ##     p.value <- wilcox.test(ptb.ctrl.X[ptb.rowindex.new,expr.gene.index], ptb.ctrl.X[ctrl.rowindex.new,expr.gene.index])$p.value
# ##     return(data.frame(ptb = ptb,
# ##                       expr.gene = expr.gene,
# ##                       avg.log2fc.RNA.expr = log2fc.ptb.ctrl.X[ptb.rowindex.new, expr.gene.index] %>% mean,
# ##                       p.value = p.value))
# ## }))
# ## ##end of redo.wilcoxon.test


# ##end of scratch:210825

# #######################################################
# ## scatter plot of ( average log2FC of gene KD compared to control ) vs ( weight in the topic )
# ##scratch:210823
# t <- 15 ## loop over topics
# topic <- paste0("topic_",t)

# FACTORFIGDIR <- paste0(FIGDIRSAMPLE, "factor.summary/factor", t, "/")
# if(!dir.exists(FACTORFIGDIR)) dir.create(FACTORFIGDIR)
# FACTORFIGDIR <- paste0(FACTORFIGDIR, SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_factor", t, "_")

# pdf(paste0(FACTORFIGDIR, "top10.ptb.in.topic_RNA.expr.avg.log2fc_zscore.weight.pdf"))

# top.ptb <- gene.score %>% as.data.frame %>% select(all_of(topic)) %>% arrange(desc(get(topic)))
# top.gene.zscore <- theta.zscore[,t] %>% sort(decreasing=T) 
# top.gene.names <- names(top.gene.zscore)[1:100]
# top.gene.col.index <- which(grepl(paste0("^", paste0(top.gene.names, collapse="$|^"), "$"), colnames(fc.X.full)))

# num_ptb <- 10
# toPlot.list <- vector("list", num_ptb)
# for (ptb.index in 1:num_ptb) {
#     ptb <- rownames(top.ptb)[ptb.index]
#     ## ptb <- "MESDC1" ## loop based on top perturbation

#     ## subset RNA expression log2fc matrix
#     ptb.row.index <- X.gene.names %>% grepl(paste0("^",ptb,"$"), .) %>% which
#     ptb.fc.top100.topic.genes <- fc.X.full[ptb.row.index, top.gene.col.index]
#                                         # adjust for -Inf
#     ptb.mean.log2fc.top100.topic.genes <- ptb.fc.top100.topic.genes %>% apply(2, mean) %>% log2 ## apply log2 after the normalization and average
#     ## merge y-axis average expression with x-axis gene weight in topic
#     toPlot <- merge(data.frame(theta.zscore=top.gene.zscore[1:100],
#                                Gene=names(top.gene.zscore[1:100])),
#                     data.frame(log2fc.RNA.expr=ptb.mean.log2fc.top100.topic.genes,
#                                Gene=names(ptb.mean.log2fc.top100.topic.genes)),
#                     by="Gene")
#     toPlot.list[[ptb.index]] <- toPlot %>% mutate(Perturbation = ptb, log2fc.topic.expr=top.ptb %>% subset(rownames(.) == ptb) %>% pull(all_of(topic)))

#     ## scatter plot
#     p <- toPlot %>% ggplot(aes(x=theta.zscore, y=log2fc.RNA.expr)) + geom_point() + mytheme +
#         xlab(paste0("Topic ", t, " z-score (Specificity) Weight")) + ylab(paste0("Average RNA Expression (log2FC)")) +
#         ggtitle(paste0(ptb, " Perturbation Top 100 (by z-score) Gene Expression in Topic ", t)) +
#         geom_text_repel(data=toPlot, box.padding = 0.5,
#                         aes(label=Gene), size=4,
#                         color="black")


#     ## pdf(paste0(FACTORFIGDIR, "ptb.", ptb, "_RNA.expr.avg.log2fc_zscore.weight.pdf"))
#     print(p)
# }
# dev.off()

# toPlot <- do.call(rbind, toPlot.list)

# pdf(paste0(FACTORFIGDIR, "top10.ptb.in.topic_RNA.expr.avg.log2fc_zscore.weight.scratch.pdf"))
# p <- toPlot %>% ggplot(aes(x = log2fc.topic.expr, y = log2fc.RNA.expr, color = theta.zscore)) + geom_point(size=0.5) + mytheme +
#     xlab(paste0("Topic ", t, " Expression (log2FC)")) + ylab(paste0("Average RNA Expression (log2FC)"))
# print(p)
# p <- toPlot %>% ggplot(aes(x=theta.zscore, y=log2fc.RNA.expr, color=Perturbation)) + geom_point(size=0.1) + mytheme +
#     xlab(paste0("Topic ", t, " z-score (Specificity) Weight")) + ylab(paste0("Average RNA Expression (log2FC)")) +
#     ggtitle(paste0("Top 100 (by z-score) Gene Expression in Topic ", t))
# print(p)
# dev.off()

# ##end of scratch:210823


# ##scratch:210824
# ## violin plot of top 30 genes

# ##end of scratch:210824




# ## per factor summary plot by test
# FACTORFIGDIR <- paste0(FIGDIRSAMPLE, "factor.summary/")
# if(!dir.exists(FACTORFIGDIR)) dir.create(FACTORFIGDIR)
# FACTORFIGDIR <- paste0(FIGDIRSAMPLE, "factor.summary/", SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")

# ptb.zscore.long <- ptb.zscore %>% as.data.frame %>% mutate(Gene=rownames(.)) %>% melt(id.vars = "Gene", value.name = "perturbation.zscore", variable.name = "Topic") ## Perturbation z-score plot
# for(test.type.here in c("per.cell.wilcoxon", "per.guide.wilcoxon")) {
#     all.test.w <- all.test %>% subset(test.type==test.type.here)
#     realPvals.df.w <- realPvals.df %>% subset(test.type==test.type.here)
#     for (t in 1:dim(omega)[2]) { #:dim(omega)[2]
#         figure.path <- paste0(FACTORFIGDIR, "factor", t, "_with.sig.", test.type.here, ".0.1.perturbation")
#         ## raw weight version
#         topic <- paste0("topic_",t)
#         toPlot.all.test <- all.test.w %>% subset(Topic==topic)
#         toPlot.fdr <- realPvals.df.w %>% subset(Topic == topic) %>% select(Gene,fdr)
#         ## toPlot <- data.frame(Gene=topFeatures.raw.weight %>% subset(topic == t) %>% pull(Gene),
#         ##                      Score=topFeatures.raw.weight %>% subset(topic == t) %>% pull(scores)) %>%
#         ##     merge(., gene.def.pathways, by="Gene", all.x=T) %>% arrange(desc(Score)) %>% slice(1:10)
#         ## toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
#         toPlot <- data.frame(Gene=topFeatures.raw.weight %>% subset(topic == t) %>% pull(Gene),
#                              Score=topFeatures.raw.weight %>% subset(topic == t) %>% pull(scores)) %>%
#             merge(., summaries %>% select(Symbol, top_class), by.x="Gene", by.y="Symbol", all.x=T) %>% `colnames<-`(c("Gene", "Score", "Pathway")) %>% arrange(desc(Score)) %>% slice(1:10)
#         toPlot$Pathway[is.na(toPlot$Pathway)] <- "Other/Unclassified"
#         toPlot$Pathway[toPlot$Pathway == "unclassified"] <- "Other/Unclassified"
#         p4 <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score*100, fill=Pathway) ) + geom_col(width=0.5) + theme_minimal()
#         p4 <- p4 + coord_flip() + xlab("Top 10 Genes") + ylab("Raw Weights (z-score)") +
#             mytheme + theme(legend.position="bottom", legend.direction="vertical")


#                                         # plot 4
#                                         # add ABC to gene.set.type.df for this particular plot
#         ## gene.set.type.df$type[which(gene.set.type.df$Gene %in% gene.set)] <- "ABC"
#                                         # assemble toPlot
#         toPlot <- gene.score %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>%
#             merge(.,toPlot.all.test,by="Gene", all.x=T) %>%
#             merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#             ## merge(.,gene.set.type.df,by="Gene") %>%
#             mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene)) %>%
#             merge(., ref.table %>% select("Symbol", "TSS.dist.to.SNP", "GWAS.classification"), by.x="Gene", by.y="Symbol", all.x=T) %>%
#             mutate(EC_ctrl_text = ifelse(.$GWAS.classification == "EC_ctrls", "(+)", "")) %>%
#             mutate(GWAS.class.text = ifelse(grepl("CAD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb"),
#                                             ifelse(grepl("IBD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb_IBD"), ""))) %>%
#             mutate(ann.Gene = paste0(Gene, GWAS.class.text, EC_ctrl_text))
#         colnames(toPlot)[which(colnames(toPlot)==topic)] <- "log2FC"
#         toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(adjusted.p.value)), "", "*")) %>%
#             arrange(log2FC) %>%
#             mutate(x = seq(nrow(.)))
#         label <- toPlot %>% subset(x <= 3 | x > (nrow(toPlot)-3) | adjusted.p.value < fdr.thr)
#         mytheme <- theme_classic() + theme(axis.text = element_text(size = 13), axis.title = element_text(size = 15), plot.title = element_text(hjust = 0.5, face = "bold"))
#         p5 <- toPlot %>% ggplot(aes(x=reorder(Gene, log2FC), y=log2FC, color=significant)) + geom_point(size=0.75) + mytheme +
#             theme(axis.ticks.x=element_blank(), axis.text.x=element_blank()) + scale_color_manual(values = c("#E0E0E0", "#38b4f7")) +
#             xlab("Perturbed Genes") + ylab(paste0("Factor ", t, " Expression log2 Fold Change")) +
#             geom_text_repel(data=label, box.padding = 0.5,
#                             aes(label=ann.Gene), size=4,
#                             color="black") +
#             theme(legend.position = "none") # legend at the bottom?

#         ## top perturbation list
#         ## toPlot.all.test <- all.test.w 
#         ## toPlot.fdr <- realPvals.df.w %>% subset(Topic == topic) %>% select(Gene,adjusted.p.value)
#                                         # assemble toPlot


#         toPlot <- gene.score %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>% merge(.,toPlot.all.test,by="Gene", all.x=T) %>%
#             merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#             merge(.,gene.set.type.df,by="Gene") %>%
#             mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene)) %>%
#             merge(., ref.table %>% select("Symbol", "TSS.dist.to.SNP"), by.x="Gene", by.y="Symbol", all.x=T) %>%
#             mutate(EC_ctrl_text = ifelse(.$type == "EC_ctrls", "(+)", "")) %>%
#             mutate(GWAS.class.text = ifelse(grepl("CAD", type), paste0("_", floor(TSS.dist.to.SNP/1000),"kb"),
#                                             ifelse(grepl("IBD", type), paste0("_", floor(TSS.dist.to.SNP/1000),"kb_IBD"), ""))) %>%
#             mutate(ann.Gene = paste0(Gene, GWAS.class.text, EC_ctrl_text))
#          colnames(toPlot)[which(colnames(toPlot)==topic)] <- "log2FC"

#         ## toPlot <- gene.score %>% select(all_of(topic)) %>% mutate(Gene=rownames(.)) %>% merge(.,toPlot.all.test,by="Gene", all.x=T) %>%
#         ##     merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#         ##     merge(.,gene.set.type.df,by="Gene") %>%
#         ##     mutate(Gene = gsub("Enhancer-at-CAD-SNP-","",Gene))
#         ## colnames(toPlot)[which(colnames(toPlot)==topic)] <- "log2FC"

#         toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(adjusted.p.value)), "", "*"))
#         toPlot.top <- toPlot %>% arrange(desc(log2FC)) %>% slice(1:25)
#         toPlot.bottom <- toPlot %>% arrange(log2FC) %>% slice(1:25)
#         toPlot.extreme <- rbind(toPlot.top, toPlot.bottom) %>%
#             mutate(color=ifelse(type %in% c("ABC","CAD focus"), "red",
#                          ifelse(type=="non-expressed", "grey",
#                          ifelse(type=="other", "blue", "black")))) %>%
#              mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(adjusted.p.value)), "", "*"))
#      #       mutate(Gene = ifelse(type=="ABC", paste0("[ ", Gene, " ]"), Gene))
#         p.ptb.list <- toPlot.extreme %>% arrange(desc(log2FC)) %>%  #mutate(Gene = paste0("<span style = 'color: ", color, ";'>", Gene, "</span>")) %>%
#             ggplot(aes(x=reorder(Gene, log2FC), y=log2FC, fill=significant)) + geom_col() + theme_minimal() +
#             coord_flip() + xlab("Most Extreme Gene (Perturbation)") + ylab("log2 Fold Change") +
#             scale_fill_manual(values=c("grey", "#38b4f7")) +
#             geom_text(aes(label = significant)) +
#             theme(legend.position = "none")#, axis.text.y = element_text(colour = toPlot.extreme$color))


#         ## ## enriched gene sets by fgsea
#         ## fgsea.here <- fgsea.df.all %>% subset(topic == t) %>% arrange(padj) 

#         ## Perturbation z-score
#         toPlot.all.test <- all.test %>% subset(test.type=="per.cell.wilcoxon" & Topic==topic)
#         toPlot.fdr <- realPvals.df %>% subset(test.type=="per.cell.wilcoxon" & Topic == topic) %>% select(Gene,fdr)
#                                         # assemble toPlot
#         toPlot <- ptb.zscore.long %>% subset(Topic == topic) %>% merge(.,toPlot.all.test,by=c("Gene","Topic"), all.x=T) %>%
#             merge(.,toPlot.fdr,by="Gene", all.x=T) %>%
#             merge(.,gene.set.type.df,by="Gene", all.x=T) %>% ##here210809
#             ## merge(.,gene.def.pathways, by="Gene", all.x=T) %>% 
#             merge(., ref.table %>% select("Symbol", "TSS.dist.to.SNP", "GWAS.classification"), by.x="Gene", by.y="Symbol", all.x=T) %>%
#             mutate(EC_ctrl_text = ifelse(.$GWAS.classification == "EC_ctrls", "(+)", "")) %>%
#             mutate(GWAS.class.text = ifelse(grepl("CAD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb"),
#                                      ifelse(grepl("IBD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb_IBD"), ""))) %>%
#             mutate(ann.Gene = paste0(Gene, GWAS.class.text, EC_ctrl_text))

#         toPlot <- toPlot %>% mutate(significant=ifelse((adjusted.p.value >= fdr.thr | is.na(adjusted.p.value)), "", "*"))

#         toPlot.top <- toPlot %>% arrange(desc(perturbation.zscore)) %>% slice(1:25)
#         toPlot.bottom <- toPlot %>% arrange(perturbation.zscore) %>% slice(1:25)
#         toPlot.extreme <- rbind(toPlot.top, toPlot.bottom) %>%
#             mutate(color=ifelse(grepl("CAD", type), "red",
#                          ifelse(type=="non-expressed", "gray",
#                          ifelse(type=="EC_ctrls", "blue", "black"))))  %>%
#             mutate(color=ifelse(is.na(type), "black", color))
#         ## colors <- toPlot.extreme$color[order(toPlot.extreme %>% arrange(desc(perturbation.zscore)) %>% pull(color))]
#         toPlot.extreme <- toPlot.extreme %>% arrange(perturbation.zscore)
#         ## add gene distance to CAD
#         toPlot.extreme$ann.Gene <- factor(toPlot.extreme$ann.Gene, levels = toPlot.extreme$ann.Gene)
#         ## color y.axis.label
#         p.ptb.zscore <- toPlot.extreme %>%  #mutate(Gene = paste0("<span style = 'color: ", color, ";'>", Gene, "</span>")) %>%
#             ggplot(aes(x=ann.Gene, y=perturbation.zscore, fill=significant)) + geom_col() + theme_minimal() +
#             coord_flip() + xlab("Most Extreme Gene (Perturbation)") + ylab("Perturbation z-score") +
#             scale_fill_manual(values=c("grey", "#38b4f7")) +
#             geom_text(aes(label = significant)) +
#             theme(legend.position = "none", axis.text.y = element_text(colour = toPlot.extreme$color))


#         ## enriched TF motifs
#         toplot <- all.promoter.ttest.df %>% subset(topic==paste0("topic_",t) & top.gene.mean != 0 & !grepl("X.NA.",motif)) ##here:210816
#         p.promoter.motif <- volcano.plot(toplot, ep.type="promoter", ranking.type="z-score")
#         toplot <- all.enhancer.ttest.df.10en6 %>% subset(topic==paste0("topic_",t) & top.gene.mean != 0 & !grepl("X.NA.",motif))
#         p.enhancer.motif <- volcano.plot(toplot, ep.type="enhancer", ranking.type="z-score")

#         ## old as of 210816
#         ## volcano.plot <- function(toplot) {
#         ##     label <- toplot %>% subset(-log10(p.adjust) > 1) %>% mutate(motif.toshow = gsub("HUMAN.H11MO.", "", motif))
#         ##     t <- gsub("topic_", "", toplot$topic[1])
#         ##     p <- toplot %>% ggplot(aes(x=enrichment.log2fc, y=-log10(p.adjust))) + geom_point(size=0.5) + mytheme +
#         ##         ggtitle(paste0("Top 100 z-score Specific Promoter Motif Enrichment")) + xlab("Motif Enrichment (log2FC)") + ylab("-log10(adjusted p-value)") +
#         ##         geom_text_repel(data=label, box.padding = 0.5,
#         ##                         aes(label=motif.toshow), size=5,
#         ##                         color="black")
#         ##     return(p)
#         ## }
#         ## p.motif <- volcano.plot(toplot)

#         ## Factor expression on UMAP
#         plot.features <- paste0("K",k,"_",colnames(omega))
#         feature.name <- plot.features[grepl(paste0("_",t, "$"), plot.features)] ## make sure the Seurat object has this feature
#         if ( grepl(feature.name, colnames(s@meta.data)) %>% as.numeric() %>% sum() > 0 ) {
#             p.umap <- FeaturePlot(s, reduction = "umap", features=feature.name)
#         } else {
#             p.umap <- DimPlot(s, reduction = "umap")
#         }

#         p <- ggarrange(ggarrange(p4, p5, p.ptb.list, p.ptb.zscore, nrow=1), ggarrange(p.promoter.motif, p.enhancer.motif, p.umap, nrow=1), nrow=2)
#         ## p.left <- ggarrange(ggarrange(p4, p5, nrow=1), ggarrange(p.motif, p.umap, nrow=1), nrow=2)
#         ## p <- ggarrange(p.left, p.ptb.list, nrow=1, width=c(2.5,1))
#         p <- annotate_figure(p, top = text_grob(paste0("K = ", k, ", Factor ", t), face = "bold", size = 16))

#         ## pdf(paste0(figure.path, ".pdf"), width=10, height=12)
#         ## print(p)
#         ## dev.off()

#         png(paste0(figure.path, ".png"), width=1600, height=1200)
#         print(p)
#         dev.off()
#     }


#     ## ## convert the original pdf to png
#     ## im.convert(pdf.path, output = paste0(pdf.path, ".png"), extra.opts="-density 100") ## takes > 10 minutes

# }



# ## java options for allocating more memory to write Excel sheet
# options(java.parameters = "-Xmx16000m") ## 16 GB

# ## make an Excel file that has (topic x top gene x GeneCard summaries) and (topic x top perturbations x GeneCard summaries)
# ## load gene summary files
# theta.zscore.long <- theta.zscore %>% as.data.frame %>% mutate(Gene = rownames(.)) %>% melt(value.name = "score", id.vars="Gene", variable.name = "Factor")
# ## theta.zscore.annotation <- merge(theta.zscore.long %>% group_by(Factor) %>% arrange(desc(score)) %>% slice(1:100), gene.summary)
# ann_omega_long <- gene.score %>% as.data.frame %>% mutate(Gene=rownames(.)) %>% melt(value.name = "log2FC", id.vars = "Gene", variable.name = "Factor")
# ## ann.omega.long <- sqldf("select ann_omega_long.*, summaries.* from ann_omega_long left join summaries on instr(ann_omega_long.Gene, summaries.Symbol)") ## takes a while
# ann.omega.long <- merge(ann_omega_long, summaries, by.x="Gene", by.y="Symbol", all.x=T) %>% merge(., all.test %>% subset(test.type=="per.cell.wilcoxon"), by.x=c("Gene","Factor"), by.y=c("Gene","Topic"), all.x=T)
# ann.omega.long.top <- ann.omega.long %>% group_by(Factor) %>% arrange(desc(log2FC)) %>% slice(1:50) %>% as.data.frame
# ann.omega.long.bottom <- ann.omega.long %>% group_by(Factor) %>% arrange(log2FC) %>% slice(1:50) %>% as.data.frame
# ann.omega.long.sig <- ann.omega.long %>% subset(adjusted.p.value < 0.1) %>% as.data.frame
# ann.omega.long.output <- rbind(ann.omega.long.top, ann.omega.long.bottom, ann.omega.long.sig) %>% unique %>% arrange(Factor, desc(log2FC)) %>% relocate(c(adjusted.p.value,p.value,top_class,classes), .after="log2FC")

# ## load Gavin's top genes in factors with GeneCards information
# theta.zscore.long.output <- read_xlsx(paste0(opt$datadir, "210730_cNMF_topic_model_anal.xlsx"), sheet="top100_annotated")

# ## combine topic definition and perturbation information
# omega.theta.zscore.topic.analysis <- rbind(ann.omega.long.output %>% select(Factor, classes, top_class) %>% mutate(datasource = "omega"),
#                                            theta.zscore.long.output %>% select(Topic, Classes, Top_Class) %>% `colnames<-`(c("Factor", "classes", "top_class")) %>% mutate(datasource = "theta.zscore")) %>% 
#     group_by(Factor, top_class) %>%
#     mutate(top_class_count = n()) %>% ungroup() %>%
#     filter(!is.na(top_class)) %>% ## remove rows with NA in column `top_class`
#     separate_rows(classes, sep=";", convert=T) %>% ## separate classes by ";" and expand the rows
#     filter(!is.na(classes)) %>% subset(classes != "") %>%
#     group_by(Factor, classes) %>%
#     mutate(class_count = n()) %>% ungroup() %>% unique() %>%
#     arrange(Factor, desc(top_class_count), desc(class_count)) 

# ## separate top_class and class into two dfs
# classes.output <- omega.theta.zscore.topic.analysis %>% select(Factor, classes, class_count) %>% as.data.frame %>% unique %>% arrange(Factor,desc(class_count))
# top_class.output <- omega.theta.zscore.topic.analysis %>% select(Factor, top_class, top_class_count) %>% as.data.frame %>% unique %>% arrange(Factor,desc(top_class_count))
# ## write to xlsx
# topic.ptb.summary.xlsx.path <- paste0(OUTDIRSAMPLE, "Significant.or.Top50Ptb_per.cell.wilcoxon.adj.0.1_Summary_", SUBSCRIPT, ".xlsx") 
# write_xlsx(list(Annotations = ann.omega.long.output,
#                 Factor_pathway_summary = omega.theta.zscore.topic.analysis,
#                 Classes_summary = classes.output,
#                 Top_class_summary = top_class.output),
#            path=topic.ptb.summary.xlsx.path, col_names=T)

# ## heatmap for all classes
# classes.heatmap <- classes.output
# classes.heatmap$class_count <- as.numeric(classes.heatmap$class_count)
# classes.heatmap <- classes.heatmap %>% spread(key = classes, value = class_count, fill = as.numeric(0)) %>% `rownames<-`(gsub("topic", "factor",.$Factor)) %>% select(-Factor) %>% as.matrix

# ## plot heatmap
# plotHeatmap <- function(mtx){
#     heatmap.2(
#     mtx, 
#     Rowv=T, 
#     Colv=T,
#     trace='none',
#     key=T,
#     col=palette,
#     labCol=colnames(mtx),
#     margins=c(15,5), 
#     cex.main=0.1, 
#     cexCol=2.5/(nrow(mtx)^(1/3)), cexRow=1.7/(ncol(mtx)^(1/3)),
#     main=paste0(SAMPLE, ", K=", k, ", Factor Pathway Enrichment")
# )
# }

# pdf(paste0(FIGDIRTOP, "factor.pathway.enrichment.pdf"), width= 12, height = 9)
# mtx <- classes.heatmap
# plotHeatmap(mtx)
# mtx <- classes.heatmap %>% apply(2, function(x) x/sum(x)) 
# plotHeatmap(mtx)
# dev.off()




# ptb.array <- c("GOSR2", "TP53", "CDKN1A", "EDN1", "NOS3", "FGD6", "ELN")
# ptb.array <- c("TP53", "ELN", "PHB", "LRPPRC", "MESDC1")

# ## manual QC
# pdf(paste0(FIGDIRSAMPLE, "/manual.QC.pdf"))
# toPlot <- cell.per.ptb <- ann.omega %>% subset(!grepl("^neg|^safe",Gene)) %>% group_by(Gene) %>% summarize(cell.count = n()) # number of cell per perturbation
# p <- toPlot %>% ggplot(aes(x=cell.count)) + stat_ecdf() + mytheme + ggtitle("Number of Cells per Perturbation") + xlab("Number of Cells") + ylab("Fraction of Perturbed Genes")
# print(p)
# p <- toPlot %>% ggplot(aes(x=cell.count)) + geom_histogram() + mytheme + ggtitle("Number of Cells per Perturbation") + xlab("Number of Cells") + ylab("Number of Perturbed Genes")
# print(p)

# toPlot <- guide.per.ptb <- ann.omega %>% subset(!grepl("^neg|^safe",Gene)) %>% select(Gene,Guide) %>% unique() %>% group_by(Gene) %>% summarize(guide.count = n())
# p <- toPlot %>% ggplot(aes(x=guide.count)) + stat_ecdf() + mytheme + ggtitle("Number of Guides per Perturbation") + xlab("Number of Guides") + ylab("Fraction of Perturbed Genes")
# print(p)
# p <- toPlot %>% ggplot(aes(x=guide.count)) + geom_histogram() + mytheme + ggtitle("Number of Guides per Perturbation") + xlab("Number of Guides") + ylab("Number of Perturbed Genes")
# print(p)

# toPlot <- ann.omega %>% group_by(Guide) %>% summarize(count = n())
# p <- toPlot %>% ggplot(aes(x=count)) + stat_ecdf() + mytheme + ggtitle("Number of Cells per Guide") + xlab("Number of Cells") + ylab("Fraction of Perturbed Guides")
# print(p)
# p <- toPlot %>% ggplot(aes(x=count)) + stat_ecdf() + mytheme + ggtitle("Number of Cells per Guide") + xlim(0,20) + xlab("Number of Cells") + ylab("Fraction of Perturbed Guides")
# print(p)
# dev.off()

## cell.QC.data <- merge(cell.per.ptb, guide.per.ptb, by="Gene")




# ##########################################################################
# ## dotplot of log2FC for perturbation and control
# ## CDF of log2FC for perturbation and control
# ## per gene and per cell
# # create a directory for this type of files (one perturbation per file)
# FIGDIR.HERE=paste0(FIGDIRTOP,"log2FC_dotplot_CDF_each.perturbation/")
# if(!dir.exists(FIGDIR.HERE)) dir.create(FIGDIR.HERE, recursive=T)
 # 
import numpy as np
import pandas as pd
# import scipy.sparse as sp
# import scanpy as sc
import anndata as ad
from cnmf import cNMF
import argparse
import re

## argparse
parser = argparse.ArgumentParser()
parser.add_argument('--path_to_topics', type=str, help='path to the topic (cNMF directory) to project data on', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes_acrossK')
parser.add_argument('--topic_sampleName', type=str, help='sample name for topics to project on, use the same sample name as used for the cNMF directory', default='WeissmanK562gwps')
# parser.add_argument('--tpm_counts_path', type=str, help='path to tpm input cell x gene matrix', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/analysis/top2000VariableGenes_acrossK/2kG.library_overdispersedGenes/cnmf_tmp/2kG.library_overdispersedGenes.tpm.h5ad') #/scratch/groups/engreitz/Users/kangh/cNMF_pipeline/220505_snakemake_moreK_findK/all_genes/K60/worker0/2kG.library/cnmf_tmp/2kG.library.norm_counts.h5ad')
parser.add_argument('--outdir', dest = 'outdir', type=str, help = 'path to output directory', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K10/threshold_0_2/IGVF_format/')
parser.add_argument('--k', dest = 'k', type=int, help = 'number of components', default=10)
parser.add_argument('--density_threshold', dest = 'density_threshold', type=float, help = 'component spectra clustering threshold (2 for no filtering; 0.2 recommended)', default=0.2)
parser.add_argument('--barcode_dir', dest = 'barcode_dir', type=str, default='/oak/stanford/groups/engreitz/Users/kangh/collab_data/IGVF/mouse_ENCODE_heart/auxiliary_data/snrna/heart_Parse_10x_integrated_metadata.csv', help='Path to the cell barcode metadata file; requires columns CBC and Gene')

args = parser.parse_args()


# ## sdev for IGVF_b01_LeftCortex, all_genes, K=60
# args.outdir = '/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes/IGVF_b01_LeftCortex/K60/threshold_0_2/IGVF_format/'
# args.path_to_topics = '/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes_acrossK'
# args.topic_sampleName = 'IGVF_b01_LeftCortex'
# args.k = 60
# args.density_threshold = 0.2
# args.barcode_dir = '/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_igvf_b01_LeftCortex_data/IGVF_b01_LeftCortex.barcodes.txt'


sample = args.topic_sampleName
# output_sample = args.output_sampleName
# tpm_counts_path = args.tpm_counts_path
OUTDIR = args.outdir
selected_K = args.k
density_threshold = args.density_threshold
output_directory = args.path_to_topics
run_name = args.topic_sampleName
barcode_dir = args.barcode_dir




cnmf_obj = cNMF(output_dir=output_directory, name=run_name)
usage_norm, gep_scores, gep_tpm, topgenes = cnmf_obj.load_results(K=selected_K, density_threshold=density_threshold)


## load cell barcode information
# barcode_dir = "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/data/K562_gwps_raw_singlecell_01_metadata.txt"
if barcode_dir.endswith('csv'):
    barcodes_df = pd.read_csv(barcode_dir, index_col="CBC")
else:
    barcodes_df = pd.read_csv(barcode_dir, sep="\t", index_col="CBC")

## organize program name
programNames = [run_name + "_K" + str(selected_K) + "_" + str(i) for i in usage_norm.columns]
programNames_df = pd.DataFrame({"ProgramNames": programNames}, index=programNames)

usage_norm.columns = programNames

barcodes_df = barcodes_df.loc[usage_norm.index,:] ## sort the barcodes df to match with usage_norm
## create AnnData
adata = ad.AnnData(
    X = usage_norm,
    obs = barcodes_df,
    var = programNames_df
)

## save results
fileName = OUTDIR + sample + ".k_" + str(selected_K) + ".dt_" + re.sub("[.]", "_", str(density_threshold)) + ".cellxgene.h5ad"
adata.write_h5ad(fileName)
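
The script above writes the normalized program usage matrix (cells x programs) as an AnnData object. Below is a minimal sketch of reading that output back for inspection; the file path is a hypothetical example assembled from the naming convention in the code above.

import anndata as ad

## hypothetical output path, following OUTDIR + sample + ".k_" + K + ".dt_" + threshold + ".cellxgene.h5ad" above
fileName = "analysis/top2000VariableGenes/WeissmanK562gwps/K10/threshold_0_2/IGVF_format/WeissmanK562gwps.k_10.dt_0_2.cellxgene.h5ad"

adata = ad.read_h5ad(fileName)
print(adata)                      ## cells x programs matrix of normalized usages
print(adata.var["ProgramNames"])  ## program IDs, e.g. WeissmanK562gwps_K10_1
print(adata.obs.head())           ## per-cell metadata carried over from the barcode file
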
library(conflicted)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("rename", "dplyr")

suppressPackageStartupMessages({
    library(optparse)
    library(dplyr)
    library(tidyr)
    library(reshape2)
    library(ggplot2)
    library(ggpubr) ## ggarrange
    library(gplots) ## heatmap.2
    library(ggrepel)
    library(readxl)
    library(xlsx) ## might not need this package
    library(writexl)
    library(org.Hs.eg.db)
})


option.list <- list(
    make_option("--sampleName", type="character", default="2kG.library", help="Name of the sample"),
    make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/K25/threshold_0_2/", help="Output directory"),
    make_option("--scratch.outdir", type="character", default="", help="Scratch space for temporary files"),
    make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    ## make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    ## make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--perturbSeq", type="logical", default=TRUE, help="Whether this is a Perturb-seq experiment")
)
opt <- parse_args(OptionParser(option_list=option.list))



## ## K562 gwps 2k overdispersed genes
## ## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/K90/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/"
## opt$K.val <- 90
## opt$sampleName <- "WeissmanK562gwps"
## opt$perturbSeq <- TRUE
## opt$scratch.outdir <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/K90/analysis/comprehensive_program_summary/"
## opt$barcodeDir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/data/K562_gwps_raw_singlecell_01_metadata.txt"

## OUTDIR <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220316_regulator_topic_definition_table/outputs/"
## FIGDIR <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220316_regulator_topic_definition_table/figures/"
## SCRATCHOUTDIR <- "/scratch/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220316_regulator_topic_definition_table/outputs/"
OUTDIR <- opt$outdir
SCRATCHOUTDIR <- opt$scratch.outdir
check.dir <- c(OUTDIR, SCRATCHOUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

mytheme <- theme_classic() + theme(axis.text = element_text(size = 12), axis.title = element_text(size = 16), plot.title = element_text(hjust = 0.5, face = "bold"))
palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)



## parameters
## OUTDIRSAMPLE <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/"
OUTDIRSAMPLE <- opt$outdir
k <- opt$K.val 
SAMPLE <- opt$sampleName
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
## SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)

## ## load ref table (for Perturbation distance to GWAS loci annotation in Perturb_plus column)
## ref.table <- read.delim(file="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/ref.table.txt", header=T, check.names=F, stringsAsFactors=F)

## ## load test results
## all.test.combined.df <- read.delim(paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220312_compare_statistical_test/outputs/all.test.combined.df.txt"), stringsAsFactors=F)
## all.test.MAST.df <- all.test.combined.df %>% subset(test.type == "batch.correction")

## MAST.df.4n.input <- read.delim(paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/2kG.library4n3.99x_MAST.txt"), stringsAsFactors=F, check.names=F)
## MAST.df.4n <- MAST.df.4n.input %>% ## remove multiTarget entries
##     subset(!grepl("multiTarget", perturbation)) %>%
##     group_by(zlm.model.name) %>%
##     mutate(fdr.across.ptb = p.adjust(`Pr(>Chisq)`, method="fdr")) %>%
##     ungroup() %>%
##     subset(zlm.model.name == "batch.correction") %>%
##     select(-zlm.model.name) %>%
##     as.data.frame
## colnames(MAST.df.4n) <- c("Topic", "p.value", "log2FC", "log2FC.ci.hi", "log2fc.ci.lo", "fdr", "Perturbation", "fdr.across.ptb")

## Load Regulator Data
if(opt$perturbSeq) {
    MAST.file.name <- paste0(OUTDIR, "/", SAMPLE, "_MAST_DEtopics.txt")
    message(paste0("loading ", MAST.file.name))
    MAST.df <- read.delim(MAST.file.name, stringsAsFactors=F, check.names=F) %>% rename("Perturbation" = "perturbation") %>%
        rename("log2FC" = "coef", "log2FC.ci.hi" = ci.hi, "log2FC.ci.lo" = ci.lo, "p.value" = "Pr(>Chisq)")
    if(grepl("topic", MAST.df$primerid) %>% sum > 0) MAST.df <- MAST.df %>% mutate(ProgramID = paste0("K", k, "_", gsub("topic_", "", primerid))) %>% as.data.frame    
}



## ## 2n MAST
## file.name <- paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes//Perturb_2kG_dup4/acrossK/aggregated.outputs.findK.perturb-seq.RData")
## load(file.name)
## MAST.df.2n <- MAST.df %>%
##     filter(K == 60) %>%
##     select(-K) %>%
##     select(-zlm.model.name)
## colnames(MAST.df.2n) <- c("Topic", "p.value", "log2FC", "log2FC.ci.hi", "log2fc.ci.lo", "fdr", "Perturbation", "fdr.across.ptb")



## load gene annotations (Refseq + Uniprot)
gene.summaries.path <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/combined.gene.summaries.txt"
gene.summary <- read.delim(gene.summaries.path, stringsAsFactors=F)


## load topic model results
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
}



## modify theta.zscore if Gene is in ENSGID
db <- ifelse(grepl("mouse", SAMPLE), "org.Mm.eg.db", "org.Hs.eg.db")
gene.ary <- theta.zscore %>% rownames
if(grepl("^ENSG", gene.ary) %>% as.numeric %>% sum == nrow(theta.zscore)) {
    GeneSymbol.ary <- mapIds(get(db), keys=gene.ary, keytype = "ENSEMBL", column = "SYMBOL")
    GeneSymbol.ary[is.na(GeneSymbol.ary)] <- row.names(theta.zscore)[is.na(GeneSymbol.ary)]
    rownames(theta.zscore) <- GeneSymbol.ary
}

## omega.4n <- omega
## theta.zscore.4n <- theta.zscore
## theta.raw.4n <- theta.raw
meta_data <- read.delim(opt$barcodeDir, stringsAsFactors=F) 

## ## batch topics
## batch.topics.4n <- read.delim(file=paste0(OUTDIRSAMPLE, "/batch.topics.txt"), stringsAsFactors=F) %>% as.matrix %>% as.character

ann.omega <- merge(meta_data, omega, by.x="CBC", by.y=0, all=T)


##########################################################################################
## create table
create_topic_definition_table <- function(theta.zscore, t) {
    out <- theta.zscore[,t] %>%
        as.data.frame %>%
        `colnames<-`(c("zscore")) %>%
        mutate(Perturbation = rownames(theta.zscore), .before="zscore") %>%
        merge(gene.summary, by.x="Perturbation", by.y="Gene", all.x=T) %>%
        arrange(desc(zscore)) %>%
        mutate(Rank = 1:n(), .before="Perturbation") %>%
        mutate(ProgramID = paste0("K", k, "_", t), .before="zscore") %>%
        arrange(Rank) %>%
        mutate(My_summary = "", .after = "zscore") %>%
        select(Rank, ProgramID, Perturbation, zscore, My_summary, FullName, Summary)
}

create_topic_regulator_table <- function(all.test, program.here, fdr.thr = 0.1) {
    out <- all.test %>%
        subset(ProgramID == program.here &
               fdr.across.ptb < fdr.thr) %>%
        select(Perturbation, fdr.across.ptb, log2FC, log2FC.ci.hi, log2FC.ci.lo, fdr, p.value) %>%
        merge(gene.summary, by.x="Perturbation", by.y="Gene", all.x=T) %>%
        arrange(fdr.across.ptb, desc(log2FC)) %>%
        mutate(Rank = 1:n(), .before="Perturbation") %>%
        mutate(My_summary = "", .after="Perturbation") %>%
        mutate(ProgramID = program.here, .after="Rank") %>%
        ## merge(., ref.table %>% select("Symbol", "TSS.dist.to.SNP", "GWAS.classification"), by.x="Perturbation", by.y="Symbol", all.x=T) %>%
        ## mutate(EC_ctrl_text = ifelse(.$GWAS.classification == "EC_ctrls", "(+)", "")) %>%
        ## mutate(GWAS.class.text = ifelse(grepl("CAD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb"),
        ##                          ifelse(grepl("IBD", GWAS.classification), paste0("_", floor(TSS.dist.to.SNP/1000),"kb_IBD"), ""))) %>%
        ## mutate(Perturb_plus = paste0(Perturbation, GWAS.class.text, EC_ctrl_text)) %>%
    select(Rank, ProgramID, Perturbation, fdr.across.ptb, log2FC, My_summary, FullName, Summary, log2FC.ci.hi, log2FC.ci.lo, fdr, p.value) %>% ## removed Perturb_plus
        arrange(Rank)
}



create_summary_table <- function(ann.omega, theta.zscore, all.test, meta_data) {
    df.list <- vector("list", k)
    for (t in 1:k) {
        program.here <- paste0("K", k, "_", t)

        ## topic defining genes
        ann.theta.zscore <- theta.zscore %>% create_topic_definition_table(t)
        ann.top.theta.zscore <- ann.theta.zscore %>% subset(Rank <= 100) ## select the top 100 topic defining genes to output

        ## regulators
        if(opt$perturbSeq) regulator.MAST.df <- all.test %>% create_topic_regulator_table(program.here, 0.3)

        ## write table to scratch dir
        file.name <- paste0(SCRATCHOUTDIR, program.here, "_table.csv")
        sink(file=file.name) ## open the document
        ## cat("Author,PERTURBATIONS SUMMARIES\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\n\n\nAuthor,TOPIC SUMMARIES\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\nAuthor,TESTABLE HYPOTHESIS IDEAS:\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\nAuthor,OTHER THOUGHTS:\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\nTOPIC DEFINING GENES (TOP 100),\n") ## headers
        cat("Author,PERTURBATIONS SUMMARIES,,,,,,,,,,,\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\"\nAuthor,TOPIC SUMMARIES\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\"\nAuthor,TESTABLE HYPOTHESIS IDEAS:\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\"\nAuthor,OTHER THOUGHTS:\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\"\n\nTOPIC DEFINING GENES (TOP 100),\n") ## headers
        write.csv(ann.top.theta.zscore, row.names=F) ## topic defining genes
        cat("\n\"\n\"\nPERTURBATIONS REGULATING TOPIC AT FDR < 0.3 (most significant on top),,,,,,,,,,,\n") ## headers

        if(opt$perturbSeq) write.csv(regulator.MAST.df, row.names=F) ## regulators
        sink() ## close the document

        ## read the assembled table to save to list
        df.list[[t]] <- read.delim(file.name, stringsAsFactors=F, check.names=F, sep=",")
    }


    ## output to xlsx
    names(df.list) <- paste0("Program ", 1:k)
    write_xlsx(df.list, paste0(OUTDIR, "/", SAMPLE, "_k_", k, ".dt_", DENSITY.THRESHOLD, "_ComprehensiveProgramSummary.xlsx"))

    return(df.list)
}


if(!opt$perturbSeq) MAST.df <- data.frame()
df <- create_summary_table(ann.omega, theta.zscore, MAST.df, meta_data)
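
The summary workbook written above contains one sheet per program ("Program 1" through "Program k"), each holding free-text header rows followed by the top 100 program-defining genes and the regulators significant at FDR < 0.3. Below is a minimal sketch of reading one sheet back; the file path is a hypothetical example assembled from the naming convention in create_summary_table() above.

import pandas as pd

## hypothetical path, following OUTDIR + "/" + SAMPLE + "_k_" + k + ".dt_" + DENSITY.THRESHOLD + "_ComprehensiveProgramSummary.xlsx"
xlsx_path = "analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/WeissmanK562gwps_k_90.dt_0_2_ComprehensiveProgramSummary.xlsx"

## each program has its own sheet; skim the first rows to see the header block before the gene table
program1 = pd.read_excel(xlsx_path, sheet_name="Program 1")
print(program1.head(20))
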
suppressPackageStartupMessages(library(conflicted))
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("first", "dplyr")

suppressPackageStartupMessages({
    library(optparse)
    library(dplyr)
    library(tidyr)
    library(reshape2)
    ## library(ggplot2)
    ## library(cowplot)
    ## library(ggpubr) ## ggarrange
    ## library(gplots) ## heatmap.2
    ## library(scales) ## geom_tile gradient rescale
    ## library(ggrepel)
    library(stringr)
    library(stringi)
    library(svglite)
    library(Seurat)
    library(SeuratObject)
    library(xlsx)
})


##########################################################################################
## Constants and Directories
option.list <- list(
    make_option("--sampleName", type="character", default="2kG.library", help="Name of the sample"),
    make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/K25/threshold_0_2/", help="Output directory"),
    make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--raw.mtx.RDS.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.RDS", help="input matrix to cNMF pipeline"),
    make_option("--perturbSeq", type="logical", default=TRUE, help="Whether this is a Perturb-seq experiment")
)
opt <- parse_args(OptionParser(option_list=option.list))


## ## K562 gwps 2k overdispersed genes
## ## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/K90/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/"
## opt$K.val <- 90
## opt$sampleName <- "WeissmanK562gwps"
## opt$perturbSeq <- TRUE

## ## ENCODE mouse heart
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes/mouse_ENCODE_heart/K10/threshold_0_2/"
## opt$sampleName <- "mouse_ENCODE_heart"
## opt$K.val <- 10
## opt$perturbSeq <- FALSE


OUTDIR <- opt$outdir
SAMPLE <- opt$sampleName
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
check.dir <- c(OUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))
fdr.thr <- 0.05

db <- ifelse(grepl("mouse", SAMPLE), "org.Mm.eg.db", "org.Hs.eg.db")
suppressPackageStartupMessages(library(db, character.only = TRUE)) ## load the appropriate annotation database

## Load Data
## Batch Correlation table
load(paste0(OUTDIR, "/batch.correlation.RDS"))
if("topic" %in% colnames(max.batch.correlation.df)){
    max.batch.correlation.df <- max.batch.correlation.df %>%
        mutate(ProgramID = paste0("K", k, "_", gsub("topic_", "", topic))) %>%
        as.data.frame
}
if ("ProgramID" %in% colnames(max.batch.correlation.df)) {
    max.batch.correlation.df <- max.batch.correlation.df %>%
        mutate(ProgramID = paste0("K", k, "_", gsub("topic_", "", ProgramID))) %>%
        as.data.frame
}


##################################################
## load cNMF results
cNMF.result.file <- paste0(OUTDIR,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
if(file.exists(cNMF.result.file)) {
    message(paste0("loading cNMF result file: \n", cNMF.result.file))
    load(cNMF.result.file)
} else {
	print(paste0(cNMF.result.file, " does not exist"))
}


## helper function to map between ENSGID and SYMBOL
map.ENSGID.SYMBOL <- function(df) {
    ## need column `Gene` to be present in df
    ## detect gene data type (e.g. ENSGID, Entrez Symbol)
    if(!("Gene" %in% colnames(df))) df$Gene = df$ENSGID
    gene.type <- ifelse(nrow(df) == sum(as.numeric(grepl("^ENS", df$Gene))),
                        "ENSGID",
                        "Gene")
    if(gene.type == "ENSGID") {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "ENSEMBL", column = "SYMBOL")
        df <- df %>% mutate(ENSGID = Gene, Gene = mapped.genes)
    } else {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "SYMBOL", column = "ENSEMBL")
        df <- df %>% mutate(ENSGID = mapped.genes)
    }
    df <- df %>% mutate(Gene = ifelse(is.na(Gene), ENSGID, Gene))
    return(df)
}

if(sum(c("ENSGID", "Gene") %in% colnames(median.spectra.zscore.df)) < 2) {
    median.spectra.zscore.df <- median.spectra.zscore.df %>% map.ENSGID.SYMBOL
}

## get list of topic defining genes by z-score coefficients
if(!("theta.zscore.rank.df" %in% ls())) {
    theta.zscore.rank.list <- vector("list", ncol(theta.zscore))## initialize storage list
    for(i in 1:ncol(theta.zscore)) {
        topic <- paste0("topic_", colnames(theta.zscore)[i])
        theta.zscore.rank.list[[i]] <- theta.zscore %>%
            as.data.frame %>%
            select(all_of(i)) %>%##here
            `colnames<-`("topic.zscore") %>%
            mutate(Gene = rownames(.)) %>%
            arrange(desc(topic.zscore)) %>% ## rank genes by z-score coefficient
            mutate(zscore.specificity.rank = 1:n()) %>% ## add rank column
            mutate(Topic = topic) ## add topic column
    }
    theta.zscore.rank.df <- do.call(rbind, theta.zscore.rank.list) %>%  ## combine list to df
        `colnames<-`(c("topic.zscore", "Gene", "zscore.specificity.rank", "ProgramID")) %>%
        mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
        as.data.frame %>% map.ENSGID.SYMBOL
}


if(!("theta.tpm.rank.df" %in% ls())) {
    theta.tpm.rank.list <- vector("list", ncol(theta.raw))## initialize storage list
    for(i in 1:ncol(theta.zscore)) {
        topic <- paste0("topic_", colnames(theta.raw)[i])
        theta.tpm.rank.list[[i]] <- theta %>%
            as.data.frame %>%
            select(all_of(i)) %>%
            `colnames<-`("topic.zscore") %>%
            mutate(Gene = rownames(.)) %>%
            arrange(desc(topic.zscore), .before="topic.zscore") %>%
            mutate(zscore.specificity.rank = 1:n()) %>% ## add rank column
            mutate(Topic = topic) ## add topic column
    }
    theta.tpm.rank.df <- do.call(rbind, theta.tpm.rank.list) %>%  ## combine list to df
        `colnames<-`(c("program.tpm.coef", "Gene", "tpm.coef.rank", "ProgramID")) %>%
        mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
        as.data.frame %>% map.ENSGID.SYMBOL %>% mutate(Gene = ifelse(is.na(Gene), ENSGID, Gene))
}

## get list of topic genes by median spectra weight
if(!("median.spectra.rank.df" %in% ls())) {
    median.spectra.rank.list <- vector("list", ncol(median.spectra))## initialize storage list
    for(i in 1:ncol(median.spectra)) {
        topic <- paste0("topic_", colnames(median.spectra)[i])
        median.spectra.rank.list[[i]] <- median.spectra %>%
            as.data.frame %>%
            select(all_of(i)) %>%
            `colnames<-`("median.spectra") %>%
            mutate(Gene = rownames(.)) %>%
            arrange(desc(median.spectra)) %>% ## rank genes by median spectra weight
            mutate(median.spectra.rank = 1:n()) %>% ## add rank column
            mutate(Topic = topic) ## add topic column
    }
    median.spectra.rank.df <- do.call(rbind, median.spectra.rank.list) %>%  ## combine list to df
        `colnames<-`(c("median.spectra", "Gene", "median.spectra.rank", "ProgramID")) %>%
        mutate(ProgramID = gsub("topic_", paste0("K", k, "_"), ProgramID)) %>%
        as.data.frame %>% map.ENSGID.SYMBOL %>% mutate(Gene = ifelse(is.na(Gene), ENSGID, Gene))
    ## median.spectra.zscore.df <- median.spectra.zscore.df %>% mutate(Gene = ENSGID) ## quick fix, need to add "Gene" column to this dataframe in analysis script
}

## Load Regulator Data
if(opt$perturbSeq) {
    MAST.file.name <- paste0(OUTDIR, "/", SAMPLE, "_MAST_DEtopics.txt")
    message(paste0("loading ", MAST.file.name))
    MAST.df <- read.delim(MAST.file.name, stringsAsFactors=F, check.names=F)
    if(any(grepl("topic", MAST.df$primerid))) MAST.df <- MAST.df %>% mutate(ProgramID = paste0("K", k, "_", gsub("topic_", "", primerid))) %>% as.data.frame
} else {
    MAST.df <- data.frame(ProgramID = character(), fdr.across.ptb = numeric(), coef = numeric()) ## empty placeholder so the perturbation summaries below still run
}

## Load Promoter and Enhancer TF Motif Enrichment Data
add.ProgramID <- function(df) {
    if("topic" %in% colnames(df)) {
        df <- df %>% mutate(ProgramID = paste0('K', k, '_', gsub('topic_', '', topic))) %>% as.data.frame
    }
    return(df) ## pass through unchanged when no "topic" column is present
}
num.top.genes <- 300
## all.ttest.df.path <- paste0(OUTDIRSAMPLE,"/", ep.type, ".topic.top.", num.top.genes, ".zscore.gene_motif.count.ttest.enrichment_motif.thr.", motif.match.thr.str, "_", SUBSCRIPT.SHORT,".txt")
promoter.ttest.df.path <- paste0(OUTDIR,"/", "promoter", ".topic.top.", num.top.genes, ".zscore.gene_motif.count.ttest.enrichment_motif.thr.", "pval1e-4", "_", SUBSCRIPT.SHORT,".txt")
enhancer.ttest.df.path <- paste0(OUTDIR,"/", "enhancer", ".topic.top.", num.top.genes, ".zscore.gene_motif.count.ttest.enrichment_motif.thr.", "pval1e-6", "_", SUBSCRIPT.SHORT,".txt")
promoter.ttest.df <- read.delim(promoter.ttest.df.path, stringsAsFactors=F) %>% add.ProgramID
enhancer.ttest.df <- read.delim(enhancer.ttest.df.path, stringsAsFactors=F) %>% add.ProgramID


## GO terms
## file.name <- paste0(OUTDIRSAMPLE, "/clusterProfiler_GeneRankingType", ranking.type.here, "_EnrichmentType", GSEA.type,".txt")
file.name <- paste0(OUTDIR, "/clusterProfiler_GeneRankingType", "zscore", "_EnrichmentType", "GOEnrichment",".txt")
theta.zscore.GO.df <- read.delim(file.name, stringsAsFactors=F)
file.name <- paste0(OUTDIR, "/clusterProfiler_GeneRankingType", "median_spectra_zscore", "_EnrichmentType", "GOEnrichment",".txt")
median_spectra_zscore.GO.df <- read.delim(file.name, stringsAsFactors=F)


####################################################################################################
## Gather all columns
## MaxBatchCorrelation
MaxBatchCorrelation <- max.batch.correlation.df %>%
    select(ProgramID, maxPearsonCorrelation) %>%
    `colnames<-`(c("ProgramID", "MaxBatchCorrelation")) %>%
    as.data.frame

## nSigPerturbationsProgramUp
SigPerturbations.df <- MAST.df %>%
    subset(fdr.across.ptb < fdr.thr) %>%
    as.data.frame
nSigPerturbationsProgramUp <- SigPerturbations.df %>%
    subset(coef > log(1.1)) %>%
    group_by(ProgramID) %>%
    summarize(nSigPerturbationsProgramUp = n()) %>%
    as.data.frame
nSigPerturbationsProgramDown <- SigPerturbations.df %>%
    subset(coef < log(0.9)) %>%
    group_by(ProgramID) %>%
    summarize(nSigPerturbationsProgramDown = n()) %>%
    as.data.frame

## nSigMotifsPromoter
nSigMotifsPromoter <- promoter.ttest.df %>%
    subset(one.sided.p.adjust < fdr.thr) %>%
    group_by(ProgramID) %>%
    summarize(nSigMotifsPromoter = n()) %>%
    as.data.frame

## nSigMotifsEnhancer
nSigMotifsEnhancer <- enhancer.ttest.df %>%
    subset(one.sided.p.adjust < fdr.thr) %>%
    group_by(ProgramID) %>%
    summarize(nSigMotifsEnhancer = n()) %>%
    as.data.frame

## ProgramGenesZScoreCoefficientTop10
ProgramGenesZScoreCoefficientTop10 <- theta.zscore.rank.df %>%
    mutate(Gene = ifelse(is.na(Gene), ENSGID, Gene)) %>%
    subset(zscore.specificity.rank <= 10) %>%
    group_by(ProgramID) %>%
    arrange(desc(topic.zscore)) %>%
    summarize(ProgramGenesZScoreCoefficientTop10 = paste0(Gene, collapse=",")) %>%
    as.data.frame

## ProgramGenesTPMCoefficientTop10
ProgramGenesTPMCoefficientTop10 <- theta.tpm.rank.df %>%
    subset(tpm.coef.rank <= 10) %>%
    group_by(ProgramID) %>%
    arrange(desc(program.tpm.coef)) %>%
    summarize(ProgramGenesTPMCoefficientTop10 = paste0(Gene, collapse=",")) %>%
    as.data.frame

## ProgramGenesMedianSpectraTop10
ProgramGenesMedianSpectraTop10 <- median.spectra.rank.df %>%
    subset(median.spectra.rank <= 10) %>%
    mutate(Gene = ifelse(is.na(Gene), ENSGID, Gene)) %>%
    group_by(ProgramID) %>%
    arrange(desc(median.spectra)) %>%
    summarize(ProgramGenesMedianSpectraTop10 = paste0(Gene, collapse = ",")) %>%
    as.data.frame

## ProgramGenesMedianSpectraZScoreTop10
ProgramGenesMedianSpectraZScoreTop10 <- median.spectra.zscore.df %>%
    subset(median.spectra.zscore.rank <= 10) %>%
    group_by(ProgramID) %>%
    arrange(desc(median.spectra.zscore)) %>%
    summarize(ProgramGenesMedianSpectraZScoreTop10 = paste0(Gene, collapse = ",")) %>%
    as.data.frame

## ProgramGenesMotifsPromoter
ProgramGenesMotifsPromoter <- promoter.ttest.df %>%
    subset(one.sided.p.adjust < fdr.thr) %>%
    group_by(ProgramID) %>%
    arrange(one.sided.p.adjust) %>%
    summarize(ProgramGenesMotifsPromoter = paste0(motif, collapse = ",")) %>%
    as.data.frame

## ProgramGenesMotifsEnhancer
ProgramGenesMotifsEnhancer <- enhancer.ttest.df %>%
    subset(one.sided.p.adjust < fdr.thr) %>%
    group_by(ProgramID) %>%
    arrange(one.sided.p.adjust) %>%
    summarize(ProgramGenesMotifsEnhancer = paste0(motif, collapse = ",")) %>%
    as.data.frame

## ProgramGenesZScoreCoefficientGOTermsTop10
ProgramGenesZScoreCoefficientGOTermsTop10 <- theta.zscore.GO.df %>%
    group_by(ProgramID) %>%
    arrange(fdr.across.ont) %>%
    slice(1:10) %>%
    mutate(GOTerm = paste0(ONTOLOGY, ":", ID, ":", Description)) %>%
    summarize(ProgramGenesZScoreCoefficientGOTermsTop10 = paste0(GOTerm, collapse = ",")) %>%
    as.data.frame

## ProgramGenesMedianSpectraZScoreGOTermsTop10
ProgramGenesMedianSpectraZScoreGOTermsTop10 <- median_spectra_zscore.GO.df %>%
    group_by(ProgramID) %>%
    arrange(fdr.across.ont) %>%
    slice(1:10) %>%
    mutate(GOTerm = paste0(ONTOLOGY, ":", ID, ":", Description)) %>%
    summarize(ProgramGenesMedianSpectraZScoreGOTermsTop10 = paste0(GOTerm, collapse = ",")) %>%
    as.data.frame



## ProgramGenesTF*


## Combine all columns
ProgramSummary.list <- list(
    MaxBatchCorrelation = MaxBatchCorrelation,
    nSigPerturbationsProgramUp = nSigPerturbationsProgramUp,
    nSigPerturbationsProgramDown = nSigPerturbationsProgramDown,
    nSigMotifsPromoter = nSigMotifsPromoter,
    nSigMotifsEnhancer = nSigMotifsEnhancer,
    ProgramGenesZScoreCoefficientTop10 = ProgramGenesZScoreCoefficientTop10,
    ProgramGenesTPMCoefficientTop10 = ProgramGenesTPMCoefficientTop10,
    ProgramGenesMedianSpectraTop10 = ProgramGenesMedianSpectraTop10,
    ProgramGenesMedianSpectraZScoreTop10 = ProgramGenesMedianSpectraZScoreTop10,
    ProgramGenesMotifsPromoter = ProgramGenesMotifsPromoter,
    ProgramGenesMotifsEnhancer = ProgramGenesMotifsEnhancer,
    ProgramGenesZScoreCoefficientGOTermsTop10 = ProgramGenesZScoreCoefficientGOTermsTop10,
    ProgramGenesMedianSpectraZScoreGOTermsTop10 = ProgramGenesMedianSpectraZScoreGOTermsTop10
)

ProgramSummary.df <- Reduce(function(x, y, ...) full_join(x, y, by = "ProgramID", ...), ProgramSummary.list)

## Write Table to Text File
fileName <- paste0(SAMPLE, "_ProgramSummary_", SUBSCRIPT.SHORT)
write.table(ProgramSummary.df, file=paste0(OUTDIR, "/", fileName, ".txt"), row.names=F, quote=F, sep="\t")
## Write Excel File
write.xlsx(ProgramSummary.df, file=paste0(OUTDIR, "/", fileName, ".xlsx"), sheetName="Gene Selection Table (For experimental design)", row.names=F, showNA=F)
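The final table is assembled by full-joining one small data frame per feature on ProgramID, so each program ends up as a single row, with NA wherever a feature had no significant entries. A toy sketch of that Reduce(full_join) merge follows; the values and program IDs are invented.

## toy sketch of the Reduce(full_join) merge used above; values are invented
library(dplyr)

MaxBatchCorrelation <- data.frame(ProgramID = c("K10_1", "K10_2"), MaxBatchCorrelation = c(0.12, 0.45))
nSigMotifsPromoter  <- data.frame(ProgramID = "K10_1", nSigMotifsPromoter = 3)  ## K10_2 had no significant motifs

ProgramSummary.list <- list(MaxBatchCorrelation, nSigMotifsPromoter)
ProgramSummary.df <- Reduce(function(x, y) full_join(x, y, by = "ProgramID"), ProgramSummary.list)
ProgramSummary.df
##   ProgramID MaxBatchCorrelation nSigMotifsPromoter
## 1     K10_1                0.12                  3
## 2     K10_2                0.45                 NA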
import os
import pandas as pd
import numpy as np
from scipy.io import mmread
import scipy.sparse as sp
# import matplotlib.pyplot as plt
#from IPython.display import Image
import scanpy as sc
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--inputPath', dest = 'inputPath', type=str, default='/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.txt', help = 'path to count matrix file')
# parser.add_argument('--outPath', dest = 'outPath', type=str, help = 'path to output folder')
# parser.add_argument('--run_name',dest ='run_name', type=str, help = 'sample name')
parser.add_argument('--output_h5ad_mtx', dest = 'output_h5ad_mtx', type=str, help = 'path to output folder', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/211014_test_txt_to_h5ad/outputs/test.h5ad')
parser.add_argument('--output_gene_name_txt', dest = 'output_gene_name_txt', type=str, help = 'path to gene name output txt file', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/211014_test_txt_to_h5ad/outputs/test.h5ad.all.genes.txt')
args = parser.parse_args()


# ## 200 gene library 230610, no IL1B
# args.inputPath = '/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/no_IL1B.raw.h5ad'
# args.output_h5ad_mtx = '/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/no_IL1B.h5ad'
# args.output_gene_name_txt = '/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/no_IL1B.h5ad.all.genes.txt'


# ## 200 gene library 230610, plus IL1B
# args.inputPath = '/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/plus_IL1B.raw.h5ad'
# args.output_h5ad_mtx = '/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/plus_IL1B.h5ad'
# args.output_gene_name_txt = '/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/plus_IL1B.h5ad.all.genes.txt'


## load scanpy from txt
adata = sc.read(args.inputPath)

# remove non-protein-coding genes (clone-based names such as AL/AC/AP######.* and LINC* lincRNAs)
df_columns = pd.Series(adata.var_names.values.flatten()).astype(str)
AL = df_columns[df_columns.str.contains('^AL[0-9][0-9][0-9][0-9][0-9][0-9]\\.').tolist()]
AC = df_columns[df_columns.str.contains('^AC[0-9][0-9][0-9][0-9][0-9][0-9]\\.').tolist()]
AP = df_columns[df_columns.str.contains('^AP[0-9][0-9][0-9][0-9][0-9][0-9]\\.').tolist()]
LINC = df_columns[df_columns.str.contains('LINC').tolist()]
allPattern = df_columns[df_columns.str.contains('^[A-Za-z][A-Za-z][0-9][0-9][0-9][0-9][0-9][0-9]\\.').tolist()]  # generalizes the AL/AC/AP patterns above
toremove = pd.concat([allPattern, LINC])  # Series.append is removed in pandas >= 2.0; concatenate the two name sets
tokeep = ~df_columns.isin(toremove)
adata = adata[:, tokeep]
# filter cells
sc.pp.filter_cells(adata, min_genes=200)   # remove cells with fewer than 200 genes detected
sc.pp.filter_cells(adata, min_counts=200)  # weaker threshold than above; mainly populates the n_counts column in adata
sc.pp.filter_genes(adata, min_cells=10)    # remove genes detected in fewer than 10 cells

# save to h5ad file
sc.write(args.output_h5ad_mtx, adata)

# get gene names after filtering
filtered_genes = pd.DataFrame(adata.var_names.values.flatten()).astype(str)

# save gene names
filtered_genes.to_csv(args.output_gene_name_txt, sep="\t", header=False, index=False)
COORD=$1
FASTA=$2
OUTFASTA=$3


# PROJECT=$OAK/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/
# DATADIR=$PROJECT/data/
# FILEDIR=$OAK/Data/hg38
# OUTDIR=${PROJECT}/outputs/
# ABCDIR=$OAK/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/
# TOPDATADIR=$OAK/Users/kangh/2009_endothelial_perturbseq_analysis/data/
# TOPDATADIRABC=$OAK/Users/kangh/2009_endothelial_perturbseq_analysis/data/ABC/
# SCRATCHDIR=${SCRATCH}/210509_topic_motif_enrichment/


# mkdir -p $OUTDIR
# mkdir -p $DATADIR
# mkdir -p $TOPDATADIRABC
# mkdir -p $SCRATCHDIR
# LOG=${PROJECT}/logs/
# mkdir -p $LOG
# QSUB=/home/groups/engreitz/bin/quick-sub
# hg38FASTA=/oak/stanford/groups/engreitz/Data/hg38/Sequence/hg38.fa
# hg19FASTA=/oak/stanford/groups/engreitz/Data/hg19/Sequence/hg19.fa



############################################################
## get fasta for the enhancer regions
# chr, start, end, name, class, activity_base, TargetGene, TargetGeneTSS, TargetGeneExpression, TargetGenePromoterActivityQuantile, TargetGeneIsExpressed, distance, isSelfPromoter, powerlaw_contact, powerlaw_contact_reference, hic_contact, hic_contact_pl_scaled, hic_pseudocount, hic_contact_pl_scaled_adj, ABC.Score.Numerator, ABC.Score, powerlaw.Score.Numerator, powerlaw.Score, CellType
bedtools getfasta -name -fi ${FASTA} -bed <(awk 'OFS="\t" {print $1,$2,$3,$1":"$2"-"$3"|"$4"|"$7}' ${COORD}) -fo ${OUTFASTA}

    # bedtools getfasta -name -fi ${hg19FASTA} -bed <(awk 'OFS="\t" {print $1,$2,$3,$1":"$2"-"$3"|"$4"|"$7}' ${COORD}) -fo ${TOPDATADIRABC}/${sample}_Predictions.AvgHiC.ABC0.015.minus150.fa
library(conflicted)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("first", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("melt", "reshape2")
conflict_prefer("filter", "dplyr")

packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr") #, "grid", "gtable", "gridExtra","ggrepel",#"ramify",
xfun::pkg_attach(packages)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")



##########################################################################################
## Constants and Directories
option.list <- list(
    make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
    make_option("--sampleName", type="character", default="2kG.library", help="sample name"),
    make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/2kG.library.barcodes.tsv", help="barcodes.tsv for all cells"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    ## make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    ## make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--outdirsample", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/", help="path to cNMF analysis results"), ## or for 2n1.99x: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/K60/threshold_0_2/"
    make_option("--scatteroutput", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/K35/threshold_0_2/", help="path to gene breakdown table output"),
    make_option("--total.scatter.gene.group", type="numeric", default=50, help="Total number of groups for running MAST"),

    ## script dir
    make_option("--scriptdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts/", help="location for this script and functions script")

)
opt <- parse_args(OptionParser(option_list=option.list))


## ## sdev debug K562 gwps
## opt$sampleName <- "WeissmanK562gwps"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/data/K562_gwps_raw_singlecell_01_metadata.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K35/threshold_0_2/"
## opt$scatteroutput <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/WeissmanK562gwps/MAST/K35/threshold_0_2/"
## opt$total.scatter.gene.group <- 494
## opt$scriptdir <- "/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts"
## opt$K.val <- 35


k <- opt$K.val 
SAMPLE <- opt$sampleName
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
## SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
OUTDIRSAMPLE <- opt$outdirsample
SCATTEROUTDIR <- opt$scatteroutput
NUMGROUPS <- opt$total.scatter.gene.group



OUTDIR <- OUTDIRSAMPLE
check.dir <- c(OUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

## load data
MAST.df.list <- vector("list", NUMGROUPS)
for (i in 1:NUMGROUPS) {
    MAST.df.list[[i]] <- read.delim(paste0(SCATTEROUTDIR, "/", SAMPLE, "_MAST_DEtopics_Group", i, ".txt"), stringsAsFactors=F, check.names=F)
}
MAST.df <- do.call(rbind, MAST.df.list) %>%
    group_by(zlm.model.name) %>%
    mutate(fdr.across.ptb = p.adjust(`Pr(>Chisq)`, method='fdr')) %>%
    as.data.frame

write.table(MAST.df, file=paste0(OUTDIRSAMPLE, "/", SAMPLE, "_MAST_DEtopics.txt"), quote=F, row.names=F, sep="\t")
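After rbinding the per-group MAST outputs, the script recomputes the FDR across all perturbations separately within each model (zlm.model.name). A tiny sketch of that adjustment, with invented model names and p-values:

## toy sketch of the across-perturbation FDR adjustment applied above; values are invented
library(dplyr)
toy <- data.frame(zlm.model.name = c("batch.corrected", "batch.corrected", "no.batch"),
                  `Pr(>Chisq)` = c(0.001, 0.04, 0.2),
                  check.names = FALSE)
toy %>%
    group_by(zlm.model.name) %>%
    mutate(fdr.across.ptb = p.adjust(`Pr(>Chisq)`, method = "fdr")) %>%
    as.data.frame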
library(conflicted)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("melt", "reshape2")
conflict_prefer("filter", "dplyr")

packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr", "grid", "gtable", "gridExtra","ggrepel",#"ramify",
              "ggpubr","gridExtra", "parallel", "future",
              "org.Hs.eg.db","limma","conflicted", #"fgsea", 
              "cluster","textshape","readxl", 
              "ggdist", "gghalves", "Seurat", "writexl", "SingleCellExperiment", "MAST") #              "GGally","RNOmni","usedist","GSEA","clusterProfiler","IsoplotR","wesanderson",
xfun::pkg_attach(packages)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")


##########################################################################################
## Constants and Directories
option.list <- list(
    make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
    make_option("--sampleName", type="character", default="2kG.library", help="sample name"),
    make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/2kG.library.barcodes.tsv", help="barcodes.tsv for all cells"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    ## make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    ## make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--outdirsample", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/", help="path to cNMF analysis results"), ## or for 2n1.99x: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/K60/threshold_0_2/"
    make_option("--scatteroutput", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/", help="path to gene breakdown table output"),
    make_option("--numCtrl", type="numeric", default=5000, help="number of control cells to use for MAST")

    ## ## script dir
    ## make_option("--scriptdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts/", help="location for this script and functions script")

)
opt <- parse_args(OptionParser(option_list=option.list))


## ## sdev debug K562 gwps
## opt$sampleName <- "WeissmanK562gwps"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/data/K562_gwps_raw_singlecell_01_metadata.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K35/threshold_0_2/"
## opt$scatteroutput <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/K35/threshold_0_2/"
## opt$scatter.gene.group <- 496
## opt$scriptdir <- "/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts"
## opt$K.val <- 35


k <- opt$K.val 
SAMPLE <- opt$sampleName
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
## SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
## OUTDIRSAMPLE <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K31/threshold_0_2/"
OUTDIRSAMPLE <- opt$outdirsample
SCATTEROUTDIR <- opt$scatteroutput
## SCATTERINDEX <- opt$scatter.gene.group ## unused here: this script defines no --scatter.gene.group option

INPUTDIR <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/inputs/"
OUTDIR <- OUTDIRSAMPLE
check.dir <- c(INPUTDIR, OUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))



##########################################################################################
## load data

## load topic model results
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
}

barcode.names <- read.delim(opt$barcode.names, stringsAsFactors=F)
## separate out control cells
omega.ctrl.index <- barcode.names %>% mutate(rowindex = 1:n()) %>% filter(Gene == "negative-control") %>% pull(rowindex)
omega.ptb.index <- barcode.names %>% mutate(rowindex = 1:n()) %>% filter(Gene != "negative-control") %>% pull(rowindex)
omega.ctrl <- omega[omega.ctrl.index,]
omega.ptb <- omega[omega.ptb.index,]
barcode.names.ctrl <- barcode.names[omega.ctrl.index,]
barcode.names.ptb <- barcode.names[omega.ptb.index,]
## randomly subset the control cells to at most opt$numCtrl cells (default 5000)
ctrl.subset.index <- sample(1:length(omega.ctrl.index), min(length(omega.ctrl.index), opt$numCtrl), replace=FALSE)
omega.ctrl.subset <- omega.ctrl[ctrl.subset.index,]
barcode.names.ctrl.subset <- barcode.names.ctrl[ctrl.subset.index,]

omega.new <- rbind(omega.ptb, omega.ctrl.subset)
barcode.names.subset <- rbind(barcode.names.ptb, barcode.names.ctrl.subset)

omega.tpm <- omega.new %>% as.matrix %>% apply(2, function(x) x / sum(x) * 1000000) ## convert to TPM
log2.omega <- (omega.tpm + 1) %>% log2 ## log2(TPM + 1)

## save the log2(TPM + 1) matrix and matching barcodes (loaded downstream with load())
save(log2.omega, barcode.names.subset, file=paste0(opt$scatteroutput, "/", SAMPLE, "_MAST_log2TPM_barcodes.RDS"))
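The MAST input is built by normalizing each program's cell weights to a TPM-like scale (each column sums to one million) and then log2-transforming. A toy sketch of that transformation on a 2-cell, 2-program matrix (values invented):

## toy sketch of the per-program TPM-style normalization and log2 transform used above
omega.toy <- matrix(c(1, 3, 2, 4), nrow = 2,
                    dimnames = list(c("cell1", "cell2"), c("topic_1", "topic_2")))
omega.tpm <- apply(omega.toy, 2, function(x) x / sum(x) * 1e6) ## each program column sums to 1e6
log2.omega <- log2(omega.tpm + 1)                              ## log2(TPM + 1), the MAST input
log2.omega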
library(conflicted)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("melt", "reshape2")
conflict_prefer("filter", "dplyr")

packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr", "grid", "gtable", "gridExtra","ggrepel",#"ramify",
              "ggpubr","gridExtra", "parallel", "future",
              "org.Hs.eg.db","limma","conflicted", #"fgsea", 
              "cluster","textshape","readxl", 
              "ggdist", "gghalves", "Seurat", "writexl", "SingleCellExperiment", "MAST") #              "GGally","RNOmni","usedist","GSEA","clusterProfiler","IsoplotR","wesanderson",
xfun::pkg_attach(packages)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")


##########################################################################################
## Constants and Directories
option.list <- list(
    make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
    make_option("--sampleName", type="character", default="2kG.library", help="sample name"),
    make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/2kG.library.barcodes.tsv", help="barcodes.tsv for all cells"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--outdirsample", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/", help="path to cNMF analysis results"), ## or for 2n1.99x: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/K60/threshold_0_2/"
    make_option("--scatteroutput", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/K80/threshold_0_2/", help="path to gene breakdown table output"),
    make_option("--gene.group.list", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/GeneNames_Group43.txt"),
    make_option("--scatter.gene.group", type="numeric", default=50, help="Gene group index"),

    ## script dir
    make_option("--scriptdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts/", help="location for this script and functions script")

)
opt <- parse_args(OptionParser(option_list=option.list))


## ## sdev debug K562 gwps
## opt$sampleName <- "WeissmanK562gwps"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/data/K562_gwps_raw_singlecell_01_metadata.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K35/threshold_0_2/"
## opt$scatteroutput <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/WeissmanK562gwps/MAST/K80/threshold_0_2/"
## opt$gene.group.list <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/GeneNames_Group43.txt"
## opt$scatter.gene.group <- 43
## opt$scriptdir <- "/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts"
## opt$K.val <- 80

## ## no IL1B
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/data/no_IL1B.barcodes.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/analysis/top2000VariableGenes/no_IL1B/K15/threshold_0_2/"
## opt$scatteroutput <- "/scratch/groups/engreitz/Users/kangh/cNMF_pipeline/tutorials/2306_V2G2P_prep/top2000VariableGenes/no_IL1B/MAST/K15/threshold_0_2/"
## opt$gene.group.list <- "/scratch/groups/engreitz/Users/kangh/cNMF_pipeline/tutorials/2306_V2G2P_prep/top2000VariableGenes/no_IL1B/MAST/GeneNames_Group12.txt"
## opt$scatter.gene.group <- 12
## opt$sampleName <- "no_IL1B"
## opt$K.val <- 15
## opt$density.thr <- 0.2
## opt$scriptdir <- "/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts"



k <- opt$K.val 
SAMPLE <- opt$sampleName
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
## OUTDIRSAMPLE <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K31/threshold_0_2/"
OUTDIRSAMPLE <- opt$outdirsample
SCATTEROUTDIR <- opt$scatteroutput
SCATTERINDEX <- opt$scatter.gene.group

INPUTDIR <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/inputs/"
OUTDIR <- OUTDIRSAMPLE
check.dir <- c(INPUTDIR, OUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

## mytheme <- theme_classic() + theme(axis.text = element_text(size = 12), axis.title = element_text(size = 16), plot.title = element_text(hjust = 0.5, face = "bold"))
## palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)
## p.adjust.thr <- 0.1



##########################################################################################
## load data
## ## load known CAD genes
## CAD.genes <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/data/known_CAD_gene_set.txt", header=F, stringsAsFactors=F) %>% as.matrix %>% as.character

## ## load topic model results
## cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
## print(cNMF.result.file)
## if(file.exists(cNMF.result.file)) {
##     print("loading cNMF result file")
##     load(cNMF.result.file)
## }

## file.name <- paste0(OUTDIRSAMPLE,"/cNMFAnalysis.",SUBSCRIPT,".RData")
## print(file.name) 
## if(file.exists((file.name))) { 
##     print(paste0("loading ",file.name))
##     load(file.name) 
## }

## ## add UMI / cell information
## umiPerCell.df <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/outputs/UMIPerCell.txt", stringsAsFactors=F, row.names=1) %>% mutate(long.CBC = rownames(.)) %>%
##     separate(col="long.CBC", into=c("Gene.full.name", "Guide", "CBC"), sep=":", remove=F) %>%
##     separate(col="CBC", into=c("CBC", "sample"), sep="-scRNAseq_2kG_", remove=F)

## sample.to.10X.lane <- data.frame(sample_num = 1:20,
##                                  sample = umiPerCell.df$sample %>% unique %>% sort)
## umiPerCell.df <- merge(umiPerCell.df, sample.to.10X.lane, by="sample") %>%
##     mutate(CBC_10x = paste0(CBC, "-", sample_num))  

## ## add guide / cell information
## guidePerCell.df <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/inputs/_UPDATED_ALL_SAMPLES_dup4_NON_NA_CALLS_FOR_EA_CBC_FULL_INFO.txt", stringsAsFactors=F) 

## ## add genes detected / cell information
## geneDetectedPerCell.df <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/outputs/geneDetectedPerCell.txt", stringsAsFactors=F) %>%
##     mutate(long.CBC = rownames(.)) %>%
##     separate(col="long.CBC", into=c("Gene.full.name", "Guide", "CBC"), sep=":", remove=F) %>%
##     separate(col="CBC", into=c("CBC", "sample"), sep="-scRNAseq_2kG_", remove=F) %>%
##     merge(sample.to.10X.lane, by="sample") %>%
##     mutate(CBC_10x = paste0(CBC, "-", sample_num))  


## load gene list
gene.ary <- read.delim(opt$gene.group.list, header=F, stringsAsFactors=F) %>% unlist %>% as.character

## load log2(TPM + 1)
load(paste0(opt$scatteroutput, "/", SAMPLE, "_MAST_log2TPM_barcodes.RDS"))

print("organize meta data")
## get metadata from log2.X.full rownames
# if(SAMPLE %in% c("2kG.library", "Perturb_2kG_dup4")) {
if( grepl("2kG.library|Perturb_2kG_dup4", SAMPLE) ) {
    barcode.names <- read.table(opt$barcode.names, header=F, stringsAsFactors=F) %>% `colnames<-`("long.CBC")
    ## rownames(omega) <- barcode.names %>% pull(long.CBC) %>% gsub("CSNK2B-and-CSNK2B", "CSNK2B",.)
    meta_data <- barcode.names %>%
        rownames %>%
        as.data.frame %>%
        `colnames<-`("long.CBC") %>%
        separate(col="long.CBC", into=c("Gene.full.name", "Guide", "CBC"), sep=":", remove=F) %>%
        ## separate(col="CBC", into=c("CBC", "sample"), sep="-", remove=F) %>%
        separate(col="CBC", into=c("CBC", "sample"), sep="-scRNAseq_2kG_", remove=F) %>%
        mutate(Gene = gsub("-TSS2", "", Gene.full.name),
               CBC = gsub("RHOA-and-", "", CBC)) %>%
        filter(Gene %in% c(gene.ary, "negative-control", "safe-targeting")) %>%
        as.data.frame
    sample.to.10X.lane <- data.frame(sample_num = 1:20,
                                     sample = meta_data$sample %>% unique %>% sort)

    ## meta_data <- merge(meta_data, sample.to.10X.lane, by="sample") %>%
    ##     mutate(CBC_10x = paste0(CBC, "-", sample_num)) %>%
    ##     merge(guidePerCell.df %>% select(CBC_10x, guides_per_cbc, max_umi_ct), by="CBC_10x") %>%
    ##     merge(umiPerCell.df, by="long.CBC") %>%
    ##     merge(geneDetectedPerCell.df, by.x="long.CBC", by.y=0)
} else {
    ## barcode.names <- read.table(opt$barcode.names, header=T, stringsAsFactors=F) ## %>% `colnames<-`("long.CBC")
    ## print("finished loading barcode names")
    ## print(paste0("omega dimensions: ", dim(omega)))
    ## print(paste0("barcode names dimensions: ", dim(barcode.names)))
    meta_data <- barcode.names.subset %>% filter(Gene %in% c(gene.ary, "negative-control"))
    meta_data <- meta_data %>% mutate(sample = factor(sample))
}    



## load model fitting fomulas
test.cmd.df <- read.delim(paste0(INPUTDIR, "MAST_model_formulas.txt"), stringsAsFactors=F)

## ## 220222 scratch
## if ( !( "ann.omega" %in% ls()) ) {
##     if( grepl("2kG.library|Perturb_2kG_dup4", SAMPLE) ) {
##         ann.omega <- omega %>%
##             as.data.frame %>%
##             mutate(long.CBC = rownames(.)) %>%
##             separate(col="long.CBC", into=c("Gene.full.name", "Guide", "CBC"), sep=":", remove=F) %>%
##             mutate(Gene = gsub("-TSS2$", "", Gene.full.name))
##     } else {
##         ann.omega <- omega %>% as.data.frame %>%
##             mutate(CBC = rownames(.)) %>%
##             merge(barcode.names, by="CBC", all.x=T)
##         ## meta_data <- barcode.names
##     }
## } else {                        
##     ann.omega <- ann.omega %>%
##         mutate(Gene = gsub("-TSS2$", "", Gene.full.name)) ## remove TSS2 annotation
## }
## test

## omega.tpm <- omega %>% as.matrix %>% apply(2, function(x) x / sum(x) * 1000000) ## convert to TPM
## ## end of test
## ## omega.tpm <- ann.omega[,1:k] %>% as.matrix %>% apply(2, function(x) x / sum(x) * 1000000) ## convert to TPM
## log2.omega <- (omega.tpm + 1) %>% log2 ## log2(TPM + 1)
## ## ann.omega.original <- ann.omega

## ## load log2(TPM + 1)
## log2.omega <- readRDS(paste0(opt$scatteroutput, "/", SAMPLE, "_MAST_log2TPM.RDS"))

if( grepl("2kG.library|Perturb_2kG_dup4", SAMPLE) ) {
    df <- merge(log2.omega %>% as.data.frame %>% mutate(long.CBC = rownames(.)), meta_data, by="long.CBC")
} else {
    df <- merge(log2.omega %>% as.data.frame %>% mutate(CBC = rownames(.)), meta_data, by="CBC", all.y=T)
}


## gene.here <- "MESDC1"
gene.list <- gene.ary
num.ptb <- length(gene.list)
cat(paste0("number of genes: ", num.ptb, "\n"))

## MAST.list <- vector("list", num.ptb)
MAST.list <- mclapply(1:num.ptb, function(i) {
    gene.here <- gene.list[i]
    out <- tryCatch(
    { suppressWarnings({ 
        totest.df <- df %>%
            mutate(Gene = gsub("safe-targeting","negative-control", Gene)) %>% ## combine safe targeting and negative control guide to call them "negative-control"
            subset(Gene %in% c("negative-control", gene.here)) ## take a small slice of data (one perturbation one control)

        scaRaw <- FromMatrix(totest.df %>% select(-any_of(colnames(meta_data))) %>% t)
        colData(scaRaw)$perturb_status <- totest.df %>% pull(Gene)
        colData(scaRaw)$lane <- totest.df %>% 
                          pull(sample)
        cond <- factor(colData(scaRaw)$perturb_status)
        cond <- relevel(cond, "negative-control")
        colData(scaRaw)$condition <- cond

        format.zlm.result <- function(zlmCond) {
            condition.str <- paste0('condition', gene.here)
            summaryCond <- summary(zlmCond, doLRT=condition.str, parallel = T)

            summaryDt <- summaryCond$datatable
            fcHurdle <- merge(summaryDt[contrast==condition.str & component=='H',.(primerid, `Pr(>Chisq)`)], #hurdle P values
                              summaryDt[contrast==condition.str & component=='logFC', .(primerid, coef, ci.hi, ci.lo)], by='primerid') %>% #logFC coefficients
                mutate(fdr:=p.adjust(`Pr(>Chisq)`, 'fdr'),
                       perturbation = gene.here)
            return(fcHurdle)
        }

        fcHurdle <- do.call(rbind, lapply(1:nrow(test.cmd.df), function(i) {
            zlmCond <- eval(parse(text = paste0("zlmCond <- zlm(", test.cmd.df$zlm.model.command[i], ", scaRaw)")))
            fcHurdle.here  <- format.zlm.result(zlmCond) %>% mutate(zlm.model.name = test.cmd.df$zlm.model.name[i])
            return(fcHurdle.here)
        }))
        cat(gene.here)

        return(fcHurdle)
    })},
    warning = function(cond) {
        return(fcHurdle)
    },
    error = function(cond) {
        return(data.frame(primerid = NA,
                          `Pr(>Chisq)`= NA,
                          coef = NA,
                          ci.hi = NA,
                          ci.lo = NA,
                          fdr = NA,
                          perturbation = gene.here, check.names=F,
                          zlm.model.name=NA)
               )
    }
    )
    ## MAST.list[[i]] <- fcHurdle %>% mutate(perturbation = gene.here)
}, mc.cores = max(1, floor(availableCores() - 1))) ## 64G, K=35, 3 perturbations took 30 minutes


MAST.df <- do.call(rbind, MAST.list) %>%
    ## group_by(zlm.model.name) %>%
    ## mutate(fdr.across.ptb = p.adjust(`Pr(>Chisq)`, method='fdr')) %>%
    as.data.frame

write.table(MAST.df, file=paste0(SCATTEROUTDIR, "/", SAMPLE, "_MAST_DEtopics_Group", opt$scatter.gene.group, ".txt"), quote=F, row.names=F, sep="\t")
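In the loop above, each row of MAST_model_formulas.txt supplies a zlm.model.command string that is spliced into the zlm() call via eval(parse()). A minimal sketch of the two-column layout this assumes; the formula strings below are illustrative, not the pipeline's actual models.

## illustrative sketch of the MAST_model_formulas.txt convention assumed by the loop above
test.cmd.df <- data.frame(zlm.model.name    = c("condition.only", "condition.plus.lane"),
                          zlm.model.command = c("~condition", "~condition + lane"),
                          stringsAsFactors  = FALSE)
## row i is expanded into the call string passed to eval(parse()):
i <- 2
cat(paste0("zlm(", test.cmd.df$zlm.model.command[i], ", scaRaw)"), "\n")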
library(conflicted)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("melt", "reshape2")
conflict_prefer("filter", "dplyr")

packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr") #, "grid", "gtable", "gridExtra","ggrepel",#"ramify",
              ## "ggpubr","gridExtra", "parallel", "future",
              ## "org.Hs.eg.db","limma","conflicted", #"fgsea", 
              ## "cluster","textshape","readxl", 
              ## "ggdist", "gghalves", "Seurat", "writexl", "SingleCellExperiment", "MAST") #              "GGally","RNOmni","usedist","GSEA","clusterProfiler","IsoplotR","wesanderson",
xfun::pkg_attach(packages)
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")


##########################################################################################
## Constants and Directories
option.list <- list(
    make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
    make_option("--sampleName", type="character", default="2kG.library", help="sample name"),
    make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/2kG.library.barcodes.tsv", help="barcodes.tsv for all cells"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--outdirsample", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/", help="path to cNMF analysis results"), ## or for 2n1.99x: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/K60/threshold_0_2/"
    make_option("--num.genes.per.MAST.runGroup", type="numeric", default=494, help="Number of MAST parallel processes to create"), 
    make_option("--scatteroutput", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/", help="path to gene breakdown table output"),

    ## script dir
    make_option("--scriptdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts/", help="location for this script and functions script")

)
opt <- parse_args(OptionParser(option_list=option.list))


## ## sdev debug K562 gwps
## opt$sampleName <- "WeissmanK562gwps"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/data/K562_gwps_raw_singlecell_01_metadata.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K35/threshold_0_2/"
## opt$scatteroutput <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/230104_snakemake_WeissmanLabData/top2000VariableGenes/MAST/"
## opt$scriptdir <- "/oak/stanford/groups/engreitz/Users/kangh/cNMF_pipeline/Perturb-seq/workflow/scripts"
## opt$K.val <- 35


k <- opt$K.val 
SAMPLE <- opt$sampleName
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
## OUTDIRSAMPLE <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K31/threshold_0_2/"
## OUTDIRSAMPLE <- opt$outdirsample
SCATTEROUTDIR <- opt$scatteroutput

## INPUTDIR <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/inputs/"
## OUTDIR <- OUTDIRSAMPLE
## check.dir <- c(INPUTDIR, OUTDIR, SCATTEROUTDIR)
check.dir <- c(SCATTEROUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

## mytheme <- theme_classic() + theme(axis.text = element_text(size = 12), axis.title = element_text(size = 16), plot.title = element_text(hjust = 0.5, face = "bold"))
## palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)
## p.adjust.thr <- 0.1



## ##########################################################################################
## ## load data
## ## load known CAD genes
## CAD.genes <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/data/known_CAD_gene_set.txt", header=F, stringsAsFactors=F) %>% as.matrix %>% as.character

## ## load topic model results
## cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
## print(cNMF.result.file)
## if(file.exists(cNMF.result.file)) {
##     print("loading cNMF result file")
##     load(cNMF.result.file)
## }

## Organize metadata
print("organize meta data")
## get metadata from log2.X.full rownames
# if(SAMPLE %in% c("2kG.library", "Perturb_2kG_dup4")) {
if( grepl("2kG.library|Perturb_2kG_dup4", SAMPLE) ) {
    barcode.names <- read.table(opt$barcode.names, header=F, stringsAsFactors=F) %>% `colnames<-`("long.CBC")
    ## rownames(omega) <- barcode.names %>% pull(long.CBC) %>% gsub("CSNK2B-and-CSNK2B", "CSNK2B",.)
    barcode.names <- barcode.names %>% ## parse the long cell barcode string (Gene:Guide:CBC) directly; taking rownames() here would discard it
        separate(col="long.CBC", into=c("Gene.full.name", "Guide", "CBC"), sep=":", remove=F) %>%
        ## separate(col="CBC", into=c("CBC", "sample"), sep="-", remove=F) %>%
        separate(col="CBC", into=c("CBC", "sample"), sep="-scRNAseq_2kG_", remove=F) %>%
        mutate(Gene = gsub("-TSS2", "", Gene.full.name),
               CBC = gsub("RHOA-and-", "", CBC)) %>%
        as.data.frame
    ## sample.to.10X.lane is only needed by the commented-out merge below and requires a `meta_data`
    ## object that is not defined in this script; kept here for reference.
    ## sample.to.10X.lane <- data.frame(sample_num = 1:20,
    ##                                  sample = meta_data$sample %>% unique %>% sort)

    ## meta_data <- merge(meta_data, sample.to.10X.lane, by="sample") %>%
    ##     mutate(CBC_10x = paste0(CBC, "-", sample_num)) %>%
    ##     merge(guidePerCell.df %>% select(CBC_10x, guides_per_cbc, max_umi_ct), by="CBC_10x") %>%
    ##     merge(umiPerCell.df, by="long.CBC") %>%
    ##     merge(geneDetectedPerCell.df, by.x="long.CBC", by.y=0)
} else {
    barcode.names <- read.table(opt$barcode.names, header=T, stringsAsFactors=F) ## %>% `colnames<-`("long.CBC")
}    

cat("finished loading barcode names\n")
## print(paste0("omega dimensions: ", dim(omega)))
cat(paste0("barcode names dimensions: ", dim(barcode.names)[1], " x ", dim(barcode.names)[2], "\n"))
num_genes <- barcode.names %>% pull(Gene) %>% unique %>% length
cat(paste0("total number of perturbations: ", num_genes, "\n"))
numGenesPerRun <- opt$num.genes.per.MAST.runGroup
num_MAST_runs <- floor(num_genes / numGenesPerRun) + 1
cat(paste0("total number of MAST runs: ", num_MAST_runs, "\n"))
gene.ary <- barcode.names %>% pull(Gene) %>% unique %>% sort
## separate into numGenesPerRun perturbations per group for MAST
for (i in 1:(num_MAST_runs-1)) {
    gene.ary.i <- gene.ary[((i-1)*numGenesPerRun+1) : (numGenesPerRun*i)]
    write.table(gene.ary.i, paste0(SCATTEROUTDIR, "/GeneNames_Group", i, ".txt"), sep="\n", quote=F, row.names=F, col.names=F)
}   
i <- i + 1 # the last one
gene.ary.i <- gene.ary[((i-1)*numGenesPerRun+1) : length(gene.ary)]
write.table(gene.ary.i, paste0(SCATTEROUTDIR, "/GeneNames_Group", i, ".txt"), sep="\n", quote=F, row.names=F, col.names=F)
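## Illustrative check of the chunking above (toy numbers, not part of the pipeline): with a
## hypothetical 1,000 perturbations and numGenesPerRun = 494, the scatter writes three files,
## GeneNames_Group1.txt through GeneNames_Group3.txt, with 494, 494, and 12 genes respectively.
example.group.sizes <- sapply(split(1:1000, ceiling((1:1000) / 494)), length)
print(example.group.sizes) ## 494 494 12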
Script lines 6-336:
library(conflicted)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr", "grid", "gtable", "gridExtra","ggrepel","ramify",
              "ggpubr","gridExtra",
              "org.Hs.eg.db","limma","fgsea",
              "cluster","textshape","readxl", 
              "ggdist", "gghalves", "writexl") #              "GGally","RNOmni","usedist","GSEA","clusterProfiler","IsoplotR","wesanderson",
## packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
##               "CountClust", "Hmisc", "tidyr", "grid", "gtable", "gridExtra","ggrepel","ramify",
##               "GGally","RNOmni","usedist","ggpubr","gridExtra","GSEA",
##               "org.Hs.eg.db","limma","clusterProfiler","fgsea", "conflicted",
##               "cluster","textshape","readxl", "IsoplotR", "wesanderson", 
##               "ggdist", "gghalves", "Seurat", "writexl")
                                        # library(Seurat)
xfun::pkg_attach(packages)
conflict_prefer("slice", "dplyr")


## source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")
source("workflow/scripts/motif_enrichment_functions.R")

option.list <- list(
    make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220307_prioritized_topic_motif_enrichment/figures/", help="Figure directory"),
    make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220307_prioritized_topic_motif_enrichment/outputs/", help="Output directory"),
                                        # make_option("--olddatadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"),
    make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/", help="Input 10x data directory"),
                                        # make_option("--topic.model.result.dir", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210707_snakemake_maxParallel/all_genes_acrossK/2kG.library/", help="Topic model results directory"),
    make_option("--sampleName", type="character", default="FT010_fresh_4min", help="Name of Samples to be processed, separated by commas"),
                                        # make_option("--sep", type="logical", default=F, help="Whether to separate replicates or samples"),
    make_option("--K.list", type="character", default="2,3,4,5,6,7,8,9,10,11,12,13,14,15,17,19,21,23,25", help="K values available for analysis"),
    make_option("--K.val", type="numeric", default=20, help="K value to analyze"),
    make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--ABCdir",type="character", default="/oak/stanford/groups/engreitz/Projects/ABC/200220_CAD/ABC_out/TeloHAEC_Ctrl/Neighborhoods/", help="Path to ABC enhancer directory"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
                                        # make_option("--raw.mtx.dir",type="character",default="stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/data/no_IL1B_filtered.normalized.ptb.by.gene.mtx.filtered.txt", help="input matrix to cNMF pipeline"),
                                        # make_option("--raw.mtx.RDS.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.RDS", help="input matrix to cNMF pipeline"), # the first lane: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.expandedMultiTargetGuide.RDS"
                                        # make_option("--subsample.type", type="character", default="", help="Type of cells to keep. Currently only support ctrl"),
                                        # make_option("--barcode.names", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/barcodes.tsv", help="barcodes.tsv for all cells"),
    make_option("--reference.table", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/210702_2kglib_adding_more_brief_ca0713.xlsx"),

    ## fisher motif enrichment
    ## make_option("--outputTable", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.k_14.df_0_2.txt", help="Output directory"),
    ## make_option("--outputTableBinary", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.table.binary.k_14.df_0_2.txt", help="Output directory"),
    ## make_option("--outputEnrichment", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210607_snakemake_output/outputs/no_IL1B/topic.top.100.zscore.gene.motif.fisher.enrichment.k_14.df_0_2.txt", help="Output directory"),
    make_option("--motif.promoter.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv", help="All promoter's motif matches"),
    make_option("--motif.enhancer.background", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv", help="All enhancer's motif matches specific to {no,plus}_IL1B"),
    make_option("--enhancer.fimo.threshold", type="character", default="1.0E-4", help="Enhancer fimo motif match threshold"),
    make_option("--ep.type", type="character", default="enhancer", help="motif enrichment for enhancer or promoter, specify 'enhancer' or 'promoter'"),

                                        #summary plot parameters
    make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
    make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
    make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute"),
    make_option("--motif.match.thr.str", type="character", default="pval0.0001", help="threshold for subsetting motif matches"),

    ## Organism flag
    make_option("--organism", type="character", default="human", help="Organism type, accept org.Hs.eg.db. Only support human and mouse.")


)
opt <- parse_args(OptionParser(option_list=option.list))

## ## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
## opt$K.val <- 60
## opt$sampleName <- "2kG.library"


# ## ## all genes directories (for sdev)
# opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/"
# opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
# opt$K.val <- 60
# opt$sampleName <- "2kG.library"



## ## debug ctrl
## opt$topic.model.result.dir <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210810_snakemake_ctrls/all_genes_acrossK/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation"
## opt$sampleName <- "2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210810_snakemake_ctrls/figures/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210810_snakemake_ctrls/analysis/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation/all_genes/"
## opt$barcode.names <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210806_curate_ctrl_mtx/outputs/2kG.library.no.DE.gene.with.FDR.less.than.0.1.perturbation.barcodes.tsv"
## opt$K.val <- 60

## ## K562 gwps sdev
## opt$sampleName <- "WeissmanK562gwps"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/"
## opt$K.val <- 100
## opt$ep.type <- "enhancer"
## opt$motif.match.thr.str <- "qval0.1"
## opt$motif.enhancer.background <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/fimo/fimo_out/fimo.formatted.tsv"
## opt$motif.promoter.background <- "/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv"

## ## ENCODE mouse heart
## opt$sampleName <- "mouse_ENCODE_heart"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes"
## opt$K.val <- 5
## opt$ep.type <- "enhancer"
## opt$motif.match.thr.str <- "pval1e-4"
## opt$motif.enhancer.background <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230116_snakemake_mouse_ENCODE_heart/analysis/top2000VariableGenes/mouse_ENCODE_heart/fimo/fimo_out/fimo.txt"
## opt$motif.promoter.background <- "/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv"

## ## no_IL1B 200 gene library
## opt$sampleName <- "no_IL1B"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/tutorials/2306_V2G2P_prep/analysis/all_genes"
## opt$K.val <- 14
## opt$ep.type <- "promoter"
## opt$motif.match.thr.str <- "qval0.1"
## opt$motif.enhancer.background <- "/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.tsv"
## opt$motif.promoter.background <- "/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv"

## ## IGVF b01_LeftCortex sdev
## opt$sampleName <- "IGVF_b01_LeftCortex"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes"
## opt$K.val <- 20
## opt$ep.type <- "enhancer"
## opt$organism <- "mouse"
## opt$motif.match.thr.str <- "pval1e-6"
## opt$motif.enhancer.background <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes/IGVF_b01_LeftCortex/fimo/fimo_out/fimo.txt"
## opt$motif.promoter.background <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes/IGVF_b01_LeftCortex/fimo/fimo_out/fimo.txt"


mytheme <- theme_classic() + theme(axis.text = element_text(size = 9), axis.title = element_text(size = 11), plot.title = element_text(hjust = 0.5, face = "bold"))

SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
DATADIR=opt$olddatadir # "/seq/lincRNA/Gavin/200829_200g_anal/scRNAseq/"
OUTDIR=opt$outdir
                                        # TMDIR=opt$topic.model.result.dir
                                        # SEP=opt$sep
                                        # K.list <- strsplit(opt$K.list,",") %>% unlist() %>% as.numeric()
k <- opt$K.val
num.top.genes <- 300 ## number of topic defining genes
ep.type <- opt$ep.type
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIR=opt$figdir
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
FGSEADIR=paste0(OUTDIRSAMPLE,"/fgsea/")
FGSEAFIG=paste0(FIGDIRSAMPLE,"/fgsea/")

## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
## SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr
motif.match.thr.str <- opt$motif.match.thr.str


## create dir if not already
check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE,"/"), paste0(FIGDIR,SAMPLE,"/K",k,"/"), paste0(OUTDIR,SAMPLE,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE, FGSEADIR, FGSEAFIG)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

## palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)


######################################################################
## Load topic model results

cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
}


## not used
# ## load hg38 promoter region file
# promoter.region.hg38.original <- read.table(file=paste0("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/RefSeqCurated.170308.bed.CollapsedGeneBounds.TSS300bp.hg38.bed"), header = F, stringsAsFactors = F)  %>%
#     `colnames<-`(c("chr","start","end", "gene","cell.type","strand")) %>%
#     mutate(sequence_name = paste0(chr, ":", start, "-", end, "(", strand, ")"))


# ## keep only the expressed gene's motifs for background
# expressed.genes <- theta %>% rownames()
# promoter.region.hg38 <- promoter.region.hg38.original %>%
#     mutate(expressed = (gene %in% expressed.genes) ) %>%
#     filter(expressed)
###########


## load FIMO matched motifs ## ifelse on promoter vs enhancer
if (ep.type == "promoter") {
    ## load promoter motif matches
    motif.background <- read.delim(file=paste0(ifelse(opt$motif.promoter.background!="", opt$motif.promoter.background, "/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv")), header=F, stringsAsFactors=F) ## %>% filter(!grepl("#", motif_id))  # 30 seconds
    if(ncol(motif.background) > 9) {
        motif.background <- motif.background %>%
            `colnames<-`(c("motif_id", "motif_alt_id", "enhancer_region", "enhancer_type", "gene_region","sequence_name","start","stop","motif.matched.strand","score","p.value","q.value","matched_sequence")) %>%
            filter(!grepl("#|motif_id", motif_id))  # more than 30 seconds, minutes?
        motif.background <- motif.background %>% filter(grepl("promoter", enhancer_type))
    } else {
        motif.background <- motif.background %>%
            `colnames<-`(c("motif_id", "sequence_name", "start", "stop", "motif.matched.strand", "score", "p.value", "q.value", "matched_sequence"))
    }
    ## old
    ## motif.background <- read.delim(file=paste0(ifelse(opt$motif.promoter.background!="", opt$motif.promoter.background, "/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModel/2104_remove_lincRNA/data/fimo_out_all_promoters_thresh1.0E-4/fimo.tsv")), header=T, stringsAsFactors=F) ## %>% filter(!grepl("#", motif_id))  # 30 seconds
    ## ## colnames(motif.background) <- c("motif_id", "motif_alt_id", "enhancer_region", "enhancer_type", "gene_region","sequence_name","start","stop","motif.matched.strand","score","p.value","q.value","matched_sequence")
    ## colnames(motif.background)[colnames(motif.background) == "strand"] <- "motif.matched.strand"
    ## motif.background <- motif.background %>%
    ##     mutate(motif.short = strsplit(motif_id, split="_") %>% sapply("[[", 1) %>% as.character)
    ## end of old

} else {
    ## load enhancer motif matches
    print(opt$motif.enhancer.background)
    motif.background <- read.delim(file=paste0(ifelse(opt$motif.enhancer.background!="", opt$motif.enhancer.background, "/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/cNMF/2104_all_genes/data/fimo_out_ABC_TeloHAEC_Ctrl_thresh1.0E-4/fimo.formatted.tsv")), header=F, stringsAsFactors=F)
    if(ncol(motif.background) > 9) {
        motif.background <- motif.background %>%
            `colnames<-`(c("motif_id", "motif_alt_id", "enhancer_region", "enhancer_type", "gene_region","sequence_name","start","stop","motif.matched.strand","score","p.value","q.value","matched_sequence")) %>%
            filter(!grepl("#|motif_id", motif_id))  # more than 30 seconds, minutes?
        motif.background <- motif.background %>% filter(!grepl("promoter", enhancer_type))
    } else {
        motif.background <- motif.background %>%
            `colnames<-`(c("motif_id", "sequence_name", "start", "stop", "motif.matched.strand", "score", "p.value", "q.value", "matched_sequence"))
    }
}
motif.background <- motif.background %>% mutate(motif.short = strsplit(motif_id, split="_") %>% sapply("[[", 1) %>% as.character)
message("finished loading motif input")


## subset to q.value < 0.1 
if (grepl("qval", motif.match.thr.str)) {
    motif.background <- motif.background %>%
        subset(q.value < 0.1)
} else {
    print(paste0("subset to ", motif.match.thr.str))
    threshold <- gsub("pval", "", motif.match.thr.str) %>% as.numeric
    motif.background <- motif.background %>%
        subset(p.value < threshold)
}


if(ep.type == "enhancer") {
    if(ncol(motif.background) == 9 | sum(as.numeric(grepl("[|]", motif.background$sequence_name))) == nrow(motif.background)) { ## "[|]" matches a literal pipe; an unescaped "|" would match every row
        motif.background <- motif.background %>%
            filter(!grepl("promoter", sequence_name) & !grepl("start", start)) %>%
            separate(col="sequence_name", into=c("enhancer_region", "enhancer_type", "gene_region", "gene_name_sequence_region"), sep="[|]") %>%
            separate(col="gene_name_sequence_region", into=c("sequence_name", "to_remove"), sep="::") %>%
            select(-to_remove)
    }
    colnames(motif.background)[colnames(motif.background) == "strand"] <- "motif.matched.strand"
    motif.background <- motif.background %>%
        mutate(motif.short = strsplit(motif_id, split="_") %>% sapply("[[", 1) %>% as.character)
}
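## Illustrative sketch of the parsing above (hypothetical sequence_name format inferred from the
## separate() calls; the field values and gene name below are made up): enhancer FIMO hits are
## assumed to carry a four-field "|"-delimited name whose last field is "<gene symbol>::<coordinates>".
parse.example <- data.frame(sequence_name = "chr1:1000-2000|intergenic|chr1:900-2100|GENEX::chr1:1000-2000",
                            stringsAsFactors = FALSE)
parse.example %>%
    separate(col="sequence_name", into=c("enhancer_region", "enhancer_type", "gene_region", "gene_name_sequence_region"), sep="[|]") %>%
    separate(col="gene_name_sequence_region", into=c("sequence_name", "to_remove"), sep="::") %>%
    select(-to_remove) ## sequence_name now holds the gene symbol ("GENEX"), which is used for matching below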


expressed.genes <- rownames(theta.zscore)
## todo: convert expressed genes to symbol if they are not in symbol
db <- ifelse(grepl("mouse|org.Mm.eg.db", opt$organism), "org.Mm.eg.db", "org.Hs.eg.db")
gene.type <- ifelse(length(expressed.genes) == sum(as.numeric(grepl("^ENS", expressed.genes))), "ENSGID", "Gene")
if(gene.type == "ENSGID") expressed.genes = mapIds(get(db), keys=expressed.genes, keytype="ENSEMBL", column="SYMBOL")
motif.background <- motif.background %>%
    subset(sequence_name %in% expressed.genes)





####################################################################################################
## get list of topic defining genes
theta.rank.list <- vector("list", ncol(theta.zscore))## initialize storage list
for(i in 1:ncol(theta.zscore)) {
    topic <- paste0("topic_", colnames(theta.zscore)[i])
    theta.rank.list[[i]] <- theta.zscore[,i] %>%
        as.data.frame %>%
        `colnames<-`("topic.zscore") %>%
        mutate(Gene = rownames(.)) %>%
        arrange(desc(topic.zscore)) %>% ## rank genes by topic z-score, most specific first
        mutate(zscore.specificity.rank = 1:n()) %>% ## add rank column
        mutate(Topic = topic) ## add topic column
}
theta.rank.df <- do.call(rbind, theta.rank.list) ## combine list to df
topic.defining.gene.df <- theta.rank.df %>%
    subset(zscore.specificity.rank <= num.top.genes) ## select top 300 genes for each topic
gene.type <- ifelse(nrow(topic.defining.gene.df) == sum(as.numeric(grepl("^ENS", topic.defining.gene.df$Gene))), "ENSGID", "Gene")
if(gene.type=="ENSGID") topic.defining.gene.df <- topic.defining.gene.df %>% mutate(ENSGID = Gene) %>% mutate(Gene = mapIds(get(db), keys=.$ENSGID, keytype="ENSEMBL", column="SYMBOL"))

## topic.defining.gene.df <- topic.defining.gene.df %>% mutate(Gene = toupper(Gene))

topic.motif.match.df <- merge(topic.defining.gene.df, motif.background %>% 
                                                      select(motif_id, motif.short, sequence_name, score, p.value, q.value, motif.matched.strand), by.x="Gene", by.y="sequence_name", all.y=T) ## filtered motif.background to genes expressed in this data set, so keep all


motif.id.type <- "motif.short" ## or "motif_id"
topic.motif.match.df.long <- topic.motif.match.df %>%
    group_by(Gene, get(motif.id.type), Topic) %>%
    summarize(count = n()) %>%
    `colnames<-`(c("gene", motif.id.type, "Topic", "count")) %>%
    as.data.frame

wide <- topic.motif.match.df.long %>%
    spread(key=motif.id.type, value="count", fill=0) %>% ## get each motif's count in each top promoter
    mutate(topic.value=1) %>%
    spread(key=Topic, value=topic.value, fill=0) ## %>% select(-topic_NA) ## get presence/absence of promoter in topic ## 5 seconds
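## Illustrative sketch of the `wide` layout built above (toy genes and motifs, not real data):
## one row per gene; motif columns hold match counts near that gene, and topic columns are 0/1
## indicators for membership in a topic's top-300 gene list. ttest.on.motifs() (sourced from
## motif_enrichment_functions.R) presumably compares each motif's counts between genes inside
## and outside each topic.
wide.example <- data.frame(gene    = c("GENEA", "GENEB", "GENEC"),
                           MOTIF1  = c(2, 0, 1),
                           MOTIF2  = c(0, 3, 0),
                           topic_1 = c(1, 0, 0),
                           topic_2 = c(0, 1, 1))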


print(paste0(ep.type, " motif (", motif.match.thr.str, ") count t-test"))
ttest.df <- ttest.on.motifs(wide)


## significant.motifs.df <- ttest.df %>%
##     group_by(topic) %>%
##     subset(p.adjust < 0.1 &
##            enrichment.log2fc > 0) %>%
##     arrange(p.adjust) %>%
##     summarize(motifs = paste0(motif %>% sort, collapse=",")) %>%
##     as.data.frame


####################################################################################################
## save results
all.ttest.df.path <- paste0(OUTDIRSAMPLE,"/", ep.type, ".topic.top.", num.top.genes, ".zscore.gene_motif.count.ttest.enrichment_motif.thr.", motif.match.thr.str, "_", SUBSCRIPT.SHORT,".txt")
print(paste0("saving to ", all.ttest.df.path))
write.table(ttest.df, all.ttest.df.path, sep="\t", quote=F, row.names=F)
Script lines 8-155:
suppressPackageStartupMessages(library(conflicted))
conflict_prefer("combine", "dplyr")
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("first", "dplyr")
conflict_prefer("rename", "dplyr")

suppressPackageStartupMessages({
    library(optparse)
    library(dplyr)
    library(tidyr)
    library(reshape2)
    ## library(ggplot2)
    ## library(cowplot)
    ## library(ggpubr) ## ggarrange
    ## library(gplots) ## heatmap.2
    ## library(scales) ## geom_tile gradient rescale
    ## library(ggrepel)
    library(stringr)
    library(stringi)
    library(svglite)
    ## library(Seurat)
    ## library(SeuratObject)
    library(xlsx)
    library(yaml)
})


##########################################################################################
## Constants and Directories
option.list <- list(
    make_option("--sampleName", type="character", default="WeissmanK562gwps", help="Name of the sample"),
    make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/", help="Output directory"),
    make_option("--K.val", type="numeric", default=90, help="K value to analyze"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--perturbSeq", type="logical", default=TRUE, help="Whether this is a Perturb-seq experiment"),
    make_option("--level", type="character", default="cell line", help="Sample type (e.g. tissue, cell line, primary cells"),
    make_option("--cell.type", type="character", default="teloHAEC", help="Cell type description (e.g. brain, teloHAEC, K562)")
)
opt <- parse_args(OptionParser(option_list=option.list))


## ## sdev IGVF b01_LeftCortex
## opt$sampleName <- "IGVF_b01_LeftCortex"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes/"
## opt$K.val <- 15
## opt$perturbSeq <- FALSE
## opt$level <- "tissue"
## opt$cell.type <- "brain"


SAMPLE=strsplit(opt$sampleName,",") %>% unlist()
OUTDIR=opt$outdir
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
OUTDIRSAMPLEIGVF = paste0(OUTDIRSAMPLE, "IGVF_format/")

check.dir <- c(OUTDIRSAMPLEIGVF)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))


## subscript for files
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
## SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## cNMF direct output file (GEP)
cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
## cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.k_", k, ".dt_", density.threshold, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
} else {
    print(paste0("file ", cNMF.result.file, " not found"))
}

db <- ifelse(grepl("mouse", SAMPLE), "org.Mm.eg.db", "org.Hs.eg.db")
library(db, character.only=TRUE) ## load the appropriate annotation database; character.only is required when the package name is stored in a variable
topic.gene.names <- rownames(theta.zscore)
topic.gene.name.type <- ifelse(grepl("^ENSG", topic.gene.names) %>% as.numeric %>% sum == length(topic.gene.names), "ENSGID", "SYMBOL")
if(topic.gene.name.type == "ENSGID") {
    ENSGID.gene.names <- topic.gene.names
    SYMBOL.gene.names <- mapIds(get(db), keys=topic.gene.names, keytype = "ENSEMBL", column = "SYMBOL")
    SYMBOL.gene.names[is.na(SYMBOL.gene.names)] <- ENSGID.gene.names[is.na(SYMBOL.gene.names)]
} else {
    SYMBOL.gene.names <- topic.gene.names
    ENSGID.gene.names <- mapIds(get(db), keys=topic.gene.names, keytype = "SYMBOL", column = "ENSEMBL")
    ENSGID.gene.names[is.na(ENSGID.gene.names)] <- SYMBOL.gene.names[is.na(ENSGID.gene.names)]
}

## load variance explained
variance.explained.df <- read.delim(paste0(OUTDIRSAMPLE, "metrics.varianceExplained.df.txt"), stringsAsFactors=F)

## 1. Model YAML file: Capture all the information about the dataset that you used and which method and how you run it along with Topic_ID which points to Topic YAML files and Cell-Topic participation ID for pointing out the Cell-Topic participation h5ad file.
out <- list("Assay" = NULL,
            "Cell-Topic participation ID" = NULL,
            "Experiment ID" = SAMPLE,
            "Name of method" = "cNMF",
            "Number of topics" = k,
            "Technology" = "10x",
            "cNMF spectra threshold" = opt$density.thr,
            "Topic IDs" = paste0(SAMPLE, "_K", k, "_", 1:k),
            "level" = opt$level,
            "cell type" = opt$cell.type)
write_yaml(out, paste0(OUTDIRSAMPLEIGVF, SAMPLE, ".", SUBSCRIPT.SHORT, ".modelYAML.yaml"))

## 2. Topics YAML files: Capture all the information about Topics including Topic_ID, gene_weight, gene_id and gene_name and any other information that suits your data
theta.zscore.long <- theta.zscore %>%
    as.data.frame %>%
    `colnames<-`(paste0(SAMPLE, "_K", k, "_", colnames(.))) %>%
    mutate(Gene = SYMBOL.gene.names,
           ENSGID = ENSGID.gene.names) %>%
    melt(id.vars=c("Gene", "ENSGID"), value.name="Gene weights", variable.name="Topic ID") %>%
    rename("gene_id" = "ENSGID")

for( t in 1:k ) {
    theta.zscore.long.here <- theta.zscore.long %>%
        subset(grepl(paste0("_", t, "$"), `Topic ID`))
    duplicated.index <- duplicated(theta.zscore.long.here$Gene)
    theta.zscore.long.here$Gene[duplicated.index] <- paste0(theta.zscore.long.here$Gene[duplicated.index], "_", theta.zscore.long.here$ENSGID[duplicated.index])
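    ## several ENSGIDs can map to the same symbol; duplicated symbols are suffixed with their
    ## ENSGID so that each entry name in the YAML output stays unique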

    variance.here <- variance.explained.df %>% subset(ProgramID == paste0("K", k, "_", t)) %>% pull(VarianceExplained)
    ## create output list
    ## out <- list("gene_id" = theta.zscore.long.here %>%
    ##                 `rownames<-`(.$Gene) %>%
    ##                 select(gene_id) %>% t %>% as.data.frame,
    ##             "Gene weights" = theta.zscore.long.here %>%
    ##                 `rownames<-`(.$Gene) %>%
    ##                 select(`Gene weights`) %>% t %>% as.data.frame,
    ##             "Topic ID" = paste0(SAMPLE, "_K", k, "_", t))

    out <- list("Gene information" = list("gene_id" = theta.zscore.long.here %>%
                                              `rownames<-`(.$Gene) %>%
                                              select(gene_id) %>% t %>% as.data.frame),
                "Gene weights" = theta.zscore.long.here %>%
                    `rownames<-`(.$Gene) %>%
                    select(`Gene weights`) %>% t %>% as.data.frame,
                "Topic ID" = paste0(SAMPLE, "_K", k, "_", t),
                "Topic Information" = list("variance" = variance.here))

    write_yaml(out, paste0(OUTDIRSAMPLEIGVF, SAMPLE, "_", SUBSCRIPT.SHORT, "_program", t, "_topicYAML.yaml"))
}
Script lines 6-144:
suppressPackageStartupMessages({
    library(optparse)
    library(dplyr)
    library(tidyr)
    library(reshape2)
    library(ggplot2)
    library(cowplot)
    library(ggpubr) ## ggarrange
    library(gplots) ## heatmap.2
    library(scales) ## geom_tile gradient rescale
    library(ggrepel)
    library(stringr)
    library(svglite)
    library(ggseqlogo)
    library(universalmotif)
})



##########################################################################################
## Constants and Directories
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211205_sig_topic_TPM/figures/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/"
option.list <- list(
    make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/", help="Figure directory"),
    make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/", help="Output directory"),
    make_option("--sampleName", type="character", default="2kG.library", help="Name of Samples to be processed, separated by commas"),
    make_option("--K.val", type="numeric", default=60, help="K value to analyze"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--raw.mtx.RDS.dir",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210623_aggregate_samples/outputs/aggregated.2kG.library.mtx.cell_x_gene.RDS", help="input matrix to cNMF pipeline"),
    make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),

    ## GSEA parameters
    make_option("--ranking.type", type="character", default="zscore", help="{zscore, raw} ranking for the top program genes"),
    make_option("--GSEA.type", type="character", default="GOEnrichment", help="{GOEnrichment, ByWeightGSEA, GSEA}")
)
opt <- parse_args(OptionParser(option_list=option.list))

## ## overdispersed genes (for sdev)
## opt$sampleName <- "2kG.library_overdispersedGenes"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/analysis/top2000VariableGenes"
## opt$K.val <- 100

## ## K562 gwps sdev
## opt$sampleName <- "WeissmanK562gwps"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/"
## opt$K.val <- 55
## opt$ranking.type <- "zscore"
## opt$GSEA.type <- "GSEA"

## ## IGVF b01_LeftCortex sdev (commented out so the command-line options are not overridden)
## opt$sampleName <- "IGVF_b01_LeftCortex"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/figures/all_genes/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/all_genes"
## opt$K.val <- 10
## opt$ranking.type <- "median_spectra"
## opt$GSEA.type <- "GSEA"



OUTDIR <- opt$outdir
FIGDIR <- opt$figdir
SAMPLE <- opt$sampleName
k <- opt$K.val
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
FIGDIRSAMPLE=paste0(FIGDIR, "/", SAMPLE, "/K",k,"/")
FIGDIRTOP=paste0(FIGDIRSAMPLE,"/",SAMPLE,"_K",k,"_dt_", DENSITY.THRESHOLD,"_")
OUTDIRSAMPLE=paste0(OUTDIR, "/", SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
# SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)

## GSEA specific parameters
ranking.type.here <- opt$ranking.type
GSEA.type <- opt$GSEA.type


message(FIGDIRTOP)

## adjusted p-value threshold
fdr.thr <- opt$adj.p.value.thr
p.value.thr <- opt$adj.p.value.thr

# create dir if not already
check.dir <- c(OUTDIR, FIGDIR, paste0(FIGDIR,SAMPLE,"/"), paste0(FIGDIR,SAMPLE,"/K",k,"/"), paste0(OUTDIR,SAMPLE,"/"), OUTDIRSAMPLE, FIGDIRSAMPLE)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))

mytheme <- theme_classic() + theme(axis.text = element_text(size = 7),
                                   axis.title = element_text(size = 8),
                                   plot.title = element_text(hjust = 0.5, face = "bold", size=10),
                                   axis.line = element_line(color = "black", size = 0.25),
                                   axis.ticks = element_line(color = "black", size = 0.25),
                                   legend.key.size = unit(10, units="pt"),
                                   legend.text = element_text(size=7),
                                   legend.title = element_text(size=8)
                                   )

file.name <- paste0(OUTDIRSAMPLE, "/clusterProfiler_GeneRankingType", ranking.type.here, "_EnrichmentType", GSEA.type,".txt")
gsea.df <- read.delim(file.name, stringsAsFactors=F)
if(nrow(gsea.df) == 0) {
    toplot <- data.frame()
} else {
    toplot <- gsea.df %>%
        subset(p.adjust < fdr.thr) %>%
        group_by(ProgramID) %>%
        arrange(p.adjust) %>%
        unique %>%
        slice(1:10) %>%
        mutate(TruncatedDescription = str_trunc(paste0(ID, "; ", Description), width=50, side="right"),
               t = gsub("K60_", "", ProgramID) %>% as.numeric) %>%
        arrange(t, p.adjust) %>%
        as.data.frame
}

plot.title <- paste0(ifelse(grepl("GO", GSEA.type), "GO Term Enrichment", "MSigDB Pathway Enrichment"),
                     "\non ",
                     ifelse(grepl("zscore", ranking.type.here), "Program Gene Specificity", "Raw Weight"),
                     "\nby ",
                     ifelse(grepl("ByWeight", GSEA.type), "All Gene Weight", "Top 300 Gene Set"))


pdf(file=paste0(FIGDIRTOP,"top10EnrichedPathways_GeneRankingType", ranking.type.here, "_EnrichmentType", GSEA.type, ".pdf"), width=4, height=4)
for(program in (paste0("K", k, "_", c(1:k)))) {
    t <- strsplit(program, split="_") %>% sapply(`[[`,2)
    toplot.here <- toplot %>%
        subset(ProgramID %in% program) %>%
        arrange(p.adjust)
    labels <- toplot.here$TruncatedDescription %>% unique %>% rev
    toplot.here <- toplot.here %>%
        mutate(TruncatedDescription = factor(TruncatedDescription, levels = labels))
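    ## factor levels are reversed so that, after coord_flip(), the most significant pathway
    ## is drawn at the top of the bar chart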
    p <- toplot.here %>% ggplot(aes(x=TruncatedDescription, y=-log10(p.adjust))) + geom_col(fill="gray") + coord_flip() + mytheme +
        xlab(ifelse(grepl("GO", GSEA.type), "GO Terms", "Pathways")) + ylab("FDR (-log10)") +
        ggtitle(paste0(plot.title, "\nProgram ", t))    
    print(p)
}
dev.off()
Script lines 5-59:
library(conflicted)
conflict_prefer("first", "dplyr")
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("collapse", "dplyr")
conflict_prefer("combine", "dplyr")
packages <- c("optparse","dplyr", "data.table", "reshape2", "ggplot2",
              "tidyr", "textshape","readxl", "AnnotationDbi")
xfun::pkg_attach(packages)
conflict_prefer("select", "dplyr")


option.list <- list(
    make_option("--feature.dir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/data/features/pops_features_raw/"),
    make_option("--output", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211101_normalized_features/outputs/", help="output directory")
)
opt <- parse_args(OptionParser(option_list=option.list))


## ## for sdev
## opt$feature.dir <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/211101_20sample_snakemake/pops/features/pops_features_raw"
## opt$output <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/pops/"
## opt$prefix <- "CAD_aug6_cNMF30"


OUTDIR=opt$output
# PREFIX=opt$prefix



## ## load cNMF features
## features <- read.delim(opt$, stringsAsFactors=F)

## get files in feature directory
feature.file.paths <- dir(opt$feature.dir)

## load all features
num_feature_files = length(feature.file.paths)
all.features.list <- vector("list", num_feature_files)
for (feature.raw.index in 1:num_feature_files) {
    all.features.list[[feature.raw.index]] <- read.delim(paste0(opt$feature.dir, "/", feature.file.paths[[feature.raw.index]]), stringsAsFactors=F)
}

all.features <- Reduce(function(x,y) merge(x,y,by="ENSGID"), all.features.list) ## features without cNMF
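## Illustrative sketch of the Reduce(merge) step above (toy tables with hypothetical feature
## names): each feature file contributes gene-level columns keyed by ENSGID, and successive
## merges build one wide gene-by-feature table.
feature.example.1 <- data.frame(ENSGID = c("ENSG0001", "ENSG0002"), featureA = c(0.1, 0.5))
feature.example.2 <- data.frame(ENSGID = c("ENSG0001", "ENSG0002"), featureB = c(1, 0))
Reduce(function(x, y) merge(x, y, by = "ENSGID"), list(feature.example.1, feature.example.2))
## prints roughly:
##     ENSGID featureA featureB
## 1 ENSG0001      0.1        1
## 2 ENSG0002      0.5        0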

# all.features <- data.frame(1)
## all.features.cNMF <- merge(all.features, features, by="ENSGID") ## all features with cNMF
print(paste0("Saving all.features to ", OUTDIR, " as an RDS file"))
saveRDS(all.features, file=paste0(OUTDIR, "/full_external_features.RDS"))
print(paste0("Saving all.features to ", OUTDIR, " as a text file"))
write.table(all.features, file=paste0(OUTDIR, "/full_external_features.txt"), row.names=F, quote=F, sep="\t")
print("done!")
Script lines 5-49:
library(conflicted)
conflict_prefer("first", "dplyr")
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("collapse", "dplyr")
conflict_prefer("combine", "dplyr")
packages <- c("optparse","dplyr", "data.table", "reshape2", "ggplot2",
              "tidyr", "textshape","readxl", "AnnotationDbi")
xfun::pkg_attach(packages)
conflict_prefer("select", "dplyr")


option.list <- list(
    make_option("--feature.RDS", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/data/features/pops_features_raw/"),
    make_option("--cNMF.features", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211101_normalized_features/outputs/", help="cNMF features in ENSGID to add to all features"),
    make_option("--output", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211101_normalized_features/outputs/", help="output directory"),
    make_option("--prefix", type="character", default="CAD_aug6_cNMF60", help="use a format of MAGMA_{with, without}cNMF to specify which features are included")
)
opt <- parse_args(OptionParser(option_list=option.list))


## ## for sdev
## opt$feature.dir <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/211101_20sample_snakemake/pops/features/pops_features_raw"
## opt$output <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/pops/"
## opt$prefix <- "CAD_aug6_cNMF30"


OUTDIR=opt$output
PREFIX=opt$prefix



## load cNMF features
features <- read.delim(opt$cNMF.features, stringsAsFactors=F)

## load external features RDS
all.features <- readRDS(opt$feature.RDS) ## features without cNMF

## combine cNMF features and external features
all.features.cNMF <- merge(all.features, features, by="ENSGID") ## all features with cNMF
saveRDS(all.features.cNMF, file=paste0(OUTDIR, "/full_features_", PREFIX, ".RDS"))
write.table(all.features.cNMF, file=paste0(OUTDIR, "/full_features_", PREFIX, ".txt"), row.names=F, quote=F, sep="\t")
Script lines 5-256:
library(conflicted)
conflict_prefer("first", "dplyr")
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("collapse", "dplyr")
conflict_prefer("combine", "dplyr")
packages <- c("optparse","dplyr", "data.table", "reshape2", "ggplot2",
              "tidyr", "textshape","readxl", "gplots", "AnnotationDbi",
              "org.Hs.eg.db", "ggrepel", "gplots")
xfun::pkg_attach(packages)
conflict_prefer("select", "dplyr")


option.list <- list(
    make_option("--project", type="character", default = "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/", help="project directory"),
    make_option("--output", type="character", default = "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/", help="output directory"),
    make_option("--scratch.output", type="character", default="", help="output directory for large files"),
    make_option("--coefs_with_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.coefs", help=""),
    make_option("--marginals_with_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.marginals", help=""),
    make_option("--preds_with_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.preds", help=""),
    make_option("--coefs_without_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6.coefs", help=""),
    make_option("--marginals_without_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6.marginals", help=""),
    make_option("--preds_without_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6.preds", help=""),
    make_option("--prefix", type="character", default="CAD_aug6_cNMF60", help="magma file name (before genes.raw)"),
    make_option("--external.features.metadata", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/metadata/metadata_jul17.txt", help="annotations for each external features"),
    make_option("--cNMF.features", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/data/features/pops_features_raw/topic.zscore.ensembl.scaled_k_60.dt_0_2.txt", help="normalized cNMF weights, unit variance and zero mean"),
    make_option("--all.features", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211101_normalized_features/outputs/full_features_with_cNMF.RDS", help=".RDS file with all features input into PoPS"),
    make_option("--recompute", type="logical", default=F, help="T for rerunning the entire script, F for only outputting the missing data")
)
opt <- parse_args(OptionParser(option_list=option.list))

SAMPLE=opt$prefix
OUTDIR=opt$output
SCRATCH.OUTDIR=opt$scratch.output
PREFIX=opt$prefix

check.dir <- c(OUTDIR, SCRATCH.OUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))




## graphing constants
mytheme <- theme_classic() + theme(axis.text = element_text(size = 12), axis.title = element_text(size = 14), plot.title = element_text(hjust = 0.5, face = "bold", size=14))
palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)


## load metadata
print("loading metadata")
meta.data.path <- opt$external.features.metadata
print(meta.data.path)
metadata <- read.delim(meta.data.path, stringsAsFactors=F)


## load all features
print(paste0("loading all features from ", opt$all.features))
all.features.cNMF <- readRDS(opt$all.features)


## load data
print("loading PoPS results")
preds <- read.table(file=opt$preds_with_cNMF,header=T, stringsAsFactors=F, sep="\t")
colnames(preds) <- paste0(colnames(preds), "_with.cNMF")
colnames(preds)[1] <- "ENSGID"
preds.before <- read.table(file=paste0(opt$preds_without_cNMF), header=T, stringsAsFactors=F, sep="\t")
colnames(preds.before) <- paste0(colnames(preds.before), "_without.cNMF")
colnames(preds.before)[1] <- "ENSGID"
preds.combined <- merge(preds, preds.before, by="ENSGID")
marginals <- read.table(file=opt$marginals_with_cNMF,header=T, stringsAsFactors=F, sep="\t")
coefs <- read.table(file=opt$coefs_with_cNMF,header=T, stringsAsFactors=F, sep="\t")
coefs.df <- coefs[4:nrow(coefs),] %>% arrange(desc(beta))
coefs.df$beta <- coefs.df$beta %>% as.numeric


## map ids
x <- org.Hs.egENSEMBL 
mapped_genes <- mappedkeys(x)
xx.entrez.to.ensembl <- as.list(x[mapped_genes]) # EntrezID to Ensembl
xx.ensembl.to.entrez <- as.list(org.Hs.egENSEMBL2EG) # Ensembl to EntrezID


y <- org.Hs.egGENENAME
y_mapped_genes <- mappedkeys(y)
entrez.to.genename <- as.list(y[y_mapped_genes])
genename.to.entrez <- as.list(org.Hs.egGENENAME)


z <- org.Hs.egSYMBOL
z_mapped_genes <- mappedkeys(z)
entrez.to.symbol <- as.list(z[z_mapped_genes])
symbol.to.entrez <- as.list(org.Hs.egSYMBOL)



## function for adding name to df
add_gene_name_to_df <- function(df) {
    out <- df %>% mutate(EntrezID = xx.ensembl.to.entrez[df$ENSGID %>% as.character] %>% sapply("[[",1)) %>%
    mutate(Gene.name = entrez.to.genename[.$EntrezID %>% as.character] %>% sapply("[[",1) %>% as.character,
           Gene = entrez.to.symbol[.$EntrezID %>% as.character] %>% sapply("[[",1) %>% as.character)
    return(out)
}

## add Gene name to preds.df
preds.df <- add_gene_name_to_df(preds)
file.name <- paste0(OUTDIR, "/", PREFIX, ".combined.preds")
if( !file.exists(file.name) | opt$recompute ) {
    preds.combined.df <- add_gene_name_to_df(preds.combined)
    write.table(preds.combined.df %>% apply(2, as.character), file=file.name, row.names=F, quote=F, sep="\t")
} else {
    preds.combined.df <- read.delim(file.name, stringsAsFactors = F)
}



file.name <- paste0(SCRATCH.OUTDIR, "/", PREFIX, "_coefs.marginals.feature.outer.prod.RDS")
## if( !file.exists(file.name) | opt$recompute ) {

    ## load cNMF features
    ## features <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/data/features/pops_features_raw/topic.zscore.ensembl.scaled_k_60.dt_0_2.txt", stringsAsFactors=F)
    features <- read.delim(opt$cNMF.features, stringsAsFactors=F)

    ## sort features and subset coefs, take a product ((gene x features) x (features x beta scalar))
    coefs.cnmf <- coefs %>% subset(grepl("zscore", parameter))
    coefs.cnmf.mtx <- coefs.cnmf %>% `rownames<-`(.$parameter) %>% select(-parameter)
    coefs.cnmf.mtx$beta <- coefs.cnmf.mtx$beta %>% as.numeric
    coefs.cnmf.mtx <- coefs.cnmf.mtx %>% as.matrix

    ## subset features
    features.names.tokeep <- coefs.cnmf.mtx %>% rownames
    features.tokeep <- features %>% `rownames<-`(.$ENSGID) %>% select(all_of(features.names.tokeep)) %>% as.matrix

    ## also test marginals
    features.names.tokeep.df <- data.frame(X=features.names.tokeep)
    marginals.tokeep <- inner_join(features.names.tokeep.df, marginals, by="X") %>% `rownames<-`(.$X) %>% select(beta) %>% as.matrix


    coefs.mtx <- coefs.df %>% `rownames<-`(.$parameter) %>% select(-parameter)
    coefs.mtx$beta <- coefs.mtx$beta %>% as.numeric
    coefs.mtx <- coefs.mtx %>% as.matrix
    all.features.cNMF.names.tokeep <- coefs.mtx %>% rownames
    all.features.cNMF.tokeep <- all.features.cNMF %>% `rownames<-`(.$ENSGID) %>% select(all_of(all.features.cNMF.names.tokeep)) %>% as.matrix
    all.features.cNMF.tokeep[all.features.cNMF.tokeep=="True"] <- 1
    all.features.cNMF.tokeep[all.features.cNMF.tokeep=="False"] <- 0
    storage.mode(all.features.cNMF.tokeep) <- "numeric"
    all.features.cNMF.names.tokeep.df <- data.frame(X=all.features.cNMF.names.tokeep) ## to subset marginals
    all.marginals.cNMF.tokeep <- inner_join(all.features.cNMF.names.tokeep.df, marginals, by="X") %>% `rownames<-`(.$X) %>% select(beta) %>% as.matrix

    saveRDS(all.features.cNMF.names.tokeep, file=paste0(OUTDIR, "/all.features.cNMF.keep.prioritized.mtx.RDS"))

    ## ## check if PoPS_Score = coefs * features
    ## ## multiply features with beta
    ## PoPS_Score.coefs.manual <- features.tokeep %*% coefs.cnmf.mtx %>% `colnames<-`("PoPS_Score.coefs.manual")
    ## PoPS_Score.marginals.manual <- features.tokeep %*% marginals.tokeep %>% `colnames<-`("PoPS_Score.marginals.manual")
    ## PoPS_Score.coefs.all <- all.features.cNMF.tokeep %*% coefs.mtx %>% `colnames<-`("PoPS_Score_all.coefs")
    ## PoPS_Score.marginals.all <- all.features.cNMF.tokeep %*% marginals.tokeep %>% `colnames<-`("PoPS_Score.all.marginals")
    ## ## PoPS_Score.marginals.manual.all <- features %>% `rownames<-`(.$ENSGID) %>% select(-ENSGID) %*% marginals


    ## get outer products of features and marginals
    PoPS_Score.coefs.manual.outer.ENSG <- sweep(features.tokeep, 2, (coefs.cnmf.mtx %>% t), `*`) ### store this matrix
    PoPS_Score.marginals.manual.outer.ENSG <- sweep(features.tokeep, 2, (marginals.tokeep %>% t), `*`) ### save this matrix
    PoPS_Score.coefs.all.outer.ENSG <- sweep(all.features.cNMF.tokeep, 2, (coefs.mtx %>% t), `*`)
    PoPS_Score.marginals.all.outer.ENSG <- sweep(all.features.cNMF.tokeep, 2, (all.marginals.cNMF.tokeep %>% t), `*`)
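
    ## (Illustrative check, assuming the objects above; not required by the pipeline.)
    ## Each sweep() call scales one feature column by its beta, so summing a row of the
    ## outer-product matrix recovers the feature %*% beta product that the commented-out
    ## block above computes, e.g.:
    ## stopifnot(isTRUE(all.equal(unname(rowSums(PoPS_Score.coefs.manual.outer.ENSG)),
    ##                            unname(drop(features.tokeep %*% coefs.cnmf.mtx)))))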


    ## function to convert the outer product ENSG name to Gene name
    add_gene_name <- function(df) {
        out <- df %>% as.data.frame %>% mutate(EntrezID = xx.ensembl.to.entrez[df %>% rownames %>% as.character] %>% sapply("[[",1)) %>%
            mutate(Gene.name = entrez.to.genename[.$EntrezID %>% as.character] %>% sapply("[[",1) %>% as.character,
                   Gene = entrez.to.symbol[.$EntrezID %>% as.character] %>% sapply("[[",1) %>% as.character)
        return(out)
    }

    PoPS_Score.coefs.manual.outer <- add_gene_name(PoPS_Score.coefs.manual.outer.ENSG)
    PoPS_Score.marginals.manual.outer <- add_gene_name(PoPS_Score.marginals.manual.outer.ENSG)
    PoPS_Score.coefs.all.outer <- add_gene_name(PoPS_Score.coefs.all.outer.ENSG)
    PoPS_Score.marginals.all.outer <- add_gene_name(PoPS_Score.marginals.all.outer.ENSG)

    ## save the results
    write.table(PoPS_Score.coefs.manual.outer %>% apply(2, as.character), file=paste0(OUTDIR, "/", PREFIX, "_coefs.feature.outer.prod.txt"), quote=F, sep="\t")
    write.table(PoPS_Score.marginals.manual.outer %>% apply(2, as.character), file=paste0(SCRATCH.OUTDIR, "/", PREFIX, "_marginals.feature.outer.prod.txt"), quote=F, sep="\t")
    write.table(PoPS_Score.coefs.all.outer %>% apply(2, as.character), file=paste0(OUTDIR, "/", PREFIX, "_coefs.all.feature.outer.prod.txt"), quote=F, sep="\t")
    write.table(PoPS_Score.marginals.all.outer %>% apply(2, as.character), file=paste0(SCRATCH.OUTDIR, "/", PREFIX, "_marginals.all.feature.outer.prod.txt"), quote=F, sep="\t")



    ## find top topic that define PoPS for each gene
    ## function for sorting the feature x gene importance value
    sort_feature_x_gene_importance <- function(df) {
        out <- df %>% mutate(ENSGID=rownames(.)) %>% melt(id.vars = c("Gene.name", "EntrezID", "Gene", "ENSGID"), variable.name="topic", value.name="gene.feature_x_beta") %>% group_by(Gene) %>% arrange(desc(gene.feature_x_beta)) %>% as.data.frame
        return(out)
    }

    coefs.defining.top.topic.df <- PoPS_Score.coefs.manual.outer %>% sort_feature_x_gene_importance
    marginals.defining.top.topic.df <- PoPS_Score.marginals.manual.outer %>% sort_feature_x_gene_importance
    all.coefs.defining.top.topic.df <- PoPS_Score.coefs.all.outer %>% sort_feature_x_gene_importance
    all.marginals.defining.top.topic.df <- PoPS_Score.marginals.all.outer %>% sort_feature_x_gene_importance

    ## these are large files, so store them in $GROUP_SCRATCH
    saveRDS(marginals.defining.top.topic.df,
         file=paste0(SCRATCH.OUTDIR, "/", PREFIX, "_marginals.defining.top.topic.RDS"))
    saveRDS(coefs.defining.top.topic.df, 
         file=paste0(SCRATCH.OUTDIR, "/", PREFIX, "_coefs.defining.top.topic.RDS"))
    saveRDS(all.marginals.defining.top.topic.df, 
         file=paste0(SCRATCH.OUTDIR, "/", PREFIX, "_all.marginals.defining.top.topic.RDS"))
    saveRDS(all.coefs.defining.top.topic.df, 
         file=paste0(SCRATCH.OUTDIR, "/", PREFIX, "_all.coefs.defining.top.topic.RDS"))

    save(PoPS_Score.coefs.manual.outer, PoPS_Score.marginals.manual.outer, PoPS_Score.coefs.all.outer, PoPS_Score.marginals.all.outer,
         coefs.defining.top.topic.df, marginals.defining.top.topic.df, all.coefs.defining.top.topic.df, all.marginals.defining.top.topic.df,
         file=file.name)
## } else {
##     load(file.name)
## }


## output a table, one row per gene, with columns for gene symbol, PoPS score without cNMF, PoPS score with cNMF, top features from any source important for that gene, top topic features important for that gene
coefs.defining.top.topic.df.subset <- coefs.defining.top.topic.df %>% subset(grepl("^ENSG",ENSGID)) %>% group_by(ENSGID) %>% arrange(desc(gene.feature_x_beta)) %>% slice(1:10) %>% as.data.frame
all.coefs.defining.top.topic.df.subset <- all.coefs.defining.top.topic.df %>% subset(grepl("^ENSG",ENSGID)) %>% group_by(ENSGID) %>% arrange(desc(gene.feature_x_beta)) %>% slice(1:10) %>% as.data.frame


PoPS_preds.importance.score <- merge(preds.combined.df, all.coefs.defining.top.topic.df.subset %>% select(-Gene.name, -Gene, -EntrezID), by="ENSGID") %>% merge(., metadata, by.x="topic", by.y="X", all.x=T)
colnames(PoPS_preds.importance.score)[which(colnames(PoPS_preds.importance.score)=="topic")] <- "pathway"
write.table(PoPS_preds.importance.score %>% apply(2, as.character), file=paste0(OUTDIR, "/", PREFIX, "_PoPS_preds.importance.score.all.columns.txt"), sep="\t", quote=F, row.names=F)
PoPS_preds.importance.score.key <- PoPS_preds.importance.score %>% select(Gene, Gene.name, PoPS_Score_with.cNMF, PoPS_Score_without.cNMF, pathway, Long_Name, gene.feature_x_beta)
write.table(PoPS_preds.importance.score.key %>% apply(2, as.character), file=paste0(OUTDIR, "/", PREFIX, "_PoPS_preds.importance.score.key.columns.txt"), sep="\t", quote=F, row.names=F)

## cNMF topics only gene.feature_x_beta score
PoPS_preds.importance.score.cNMF <- merge(preds.combined.df, coefs.defining.top.topic.df.subset %>% select(-Gene.name, -Gene, -EntrezID), by="ENSGID") %>% merge(., metadata, by.x="topic", by.y="X", all.x=T)
colnames(PoPS_preds.importance.score.cNMF)[which(colnames(PoPS_preds.importance.score.cNMF)=="topic")] <- "pathway"
write.table(PoPS_preds.importance.score.cNMF %>% apply(2, as.character), file=paste0(OUTDIR, "/", PREFIX, "_PoPS_preds.importance.score.all.columns.cNMF.Topics.only.txt"), sep="\t", quote=F, row.names=F)
PoPS_preds.importance.score.cNMF.key <- PoPS_preds.importance.score.cNMF %>% select(Gene, Gene.name, PoPS_Score_with.cNMF, PoPS_Score_without.cNMF, pathway, Long_Name, gene.feature_x_beta)
write.table(PoPS_preds.importance.score.cNMF.key %>% apply(2, as.character), file=paste0(OUTDIR, "/PoPS_preds.importance.score.key.columns.cNMF.Topics.only.txt"), sep="\t", quote=F, row.names=F)
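
## (Usage sketch, assuming the files written above.) The *key.columns* tables contain one row
## per (gene, feature) pair with the columns selected above, so the top features driving a
## gene's PoPS score can be recovered with, e.g.:
## key <- read.delim(paste0(OUTDIR, "/", PREFIX, "_PoPS_preds.importance.score.key.columns.txt"), stringsAsFactors=F)
## key %>% subset(Gene == "EDN1") %>% arrange(desc(gene.feature_x_beta)) %>% head(10)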




## get top features' top genes
top.features.names <- coefs.df %>% arrange(desc(beta)) %>% slice(1:10) %>% pull(parameter)
top.features.definition <- all.features.cNMF %>% select(all_of(c("ENSGID",top.features.names))) %>% add_gene_name_to_df
slice_top_genes_from_features <- function(df) {
    out <- df %>% melt(value.name="relative.importance", variable.name="features", id.vars=c("ENSGID", "Gene", "Gene.name", "EntrezID")) %>% group_by(features) %>% arrange(desc(relative.importance)) %>% slice(1:10)
    return(out)
}
top.genes.in.top.features <- top.features.definition %>% slice_top_genes_from_features %>% as.data.frame %>% merge(coefs.df, by.x="features", by.y="parameter") %>% arrange(desc(beta))
write.table(top.genes.in.top.features, file=paste0(OUTDIR, "/top.genes.in.top.features.coefs.txt"), row.names=F, quote=F, sep="\t")
import numpy as np
import pandas as pd
import glob
import argparse

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Converts a directory of feature files into efficient NumPy format, written out to multiple chunks, amenable for use downstream with PoPS.')
    parser.add_argument("--gene_annot_path", help="Path to gene annotation table. For the purposes of this script, only require that there is an ENSGID column.")
    parser.add_argument("--feature_dir", help="Directory where raw feature files live. Each feature file must be a tab-separated file with a header for column names and the first column must be the ENSGID. Will process every file in the directory so make sure every file is a feature file and there are no hidden files. Please also make sure the column names are unique across all feature files. The easiest way to ensure this is to prefix every column with the filename.")
    parser.add_argument("--nan_policy", default="raise", help="What to do if a feature file is missing ENSGIDs that are in gene_annot_path. Takes the values \"raise\" (raise an error), \"ignore\" (ignore and write out with nans), \"mean\" (impute the mean of the feature), and \"zero\" (impute 0). Default is \"raise\".")
    parser.add_argument("--save_prefix", help="Prefix to the output path. For each chunk i, 2 files will be written: {save_prefix}_mat.{i}.npy, {save_prefix}_cols.{i}.txt. Furthermore, row data will be written to {save_prefix}_rows.txt")
    parser.add_argument("--max_cols", default=5000, type=int, help="Maximum number of columns per output chunk. Default is 5000.")

    args = parser.parse_args()
    gene_annot_path = args.gene_annot_path
    feature_dir = args.feature_dir
    nan_policy = args.nan_policy
    save_prefix = args.save_prefix
    MAX_COLS = args.max_cols

    assert nan_policy in ["raise", "ignore", "mean", "zero"], "Invalid argument for flag --nan_policy. Accepts \"raise\", \"ignore\", \"mean\", and \"zero\"."

    gene_annot_df = pd.read_csv(gene_annot_path, sep="\t", index_col="ENSGID").iloc[:,0:0]
    row_data = gene_annot_df.index.values
    np.savetxt(save_prefix + ".rows.txt", row_data, fmt="%s")

    #### Sort for canonical ordering
    all_feature_files = sorted([f for f in glob.glob(feature_dir + "/*")])

    all_mat_data = []
    all_col_data = []
    curr_block_index = 0
    for f in all_feature_files:
        print(f) ## added by Helen for debugging
        f_df = pd.read_csv(f, sep="\t", index_col=0).astype(np.float64)
        # import pdb; pdb.set_trace(); ## added by Helen for debugging
        f_df = gene_annot_df.merge(f_df, how="left", left_index=True, right_index=True)
        if len(f_df.index[f_df.index.duplicated(keep='first')]) > 0:
            print("duplicated ENSGID: " + f_df.index[f_df.index.duplicated()])
            f_df = f_df[~f_df.index.duplicated(keep='first')] ## added by Helen to remove duplicated ENSGID ## need to make sure 
        if nan_policy == "raise":
            assert not f_df.isnull().values.any(), "Missing genes in feature matrix."
        elif nan_policy == "ignore":
            pass
        elif nan_policy == "mean":
            f_df = f_df.fillna(f_df.mean())
        elif nan_policy == "zero":
            f_df = f_df.fillna(0)
        mat = f_df.values
        cols = f_df.columns.values
        all_mat_data.append(mat)
        all_col_data += list(cols)
        while len(all_col_data) >= MAX_COLS:
            ### Flush MAX_COLS columns to disk at a time
            # import pdb; pdb.set_trace(); ## added by Helen for debugging
            mat = np.hstack(all_mat_data)
            save_mat = mat[:,:MAX_COLS]
            keep_mat = mat[:,MAX_COLS:]
            save_cols = all_col_data[:MAX_COLS]
            keep_cols = all_col_data[MAX_COLS:]
            ### Save
            np.save(save_prefix + ".mat.{}.npy".format(curr_block_index), save_mat)
            np.savetxt(save_prefix + ".cols.{}.txt".format(curr_block_index), save_cols, fmt="%s")
            ### Update variables
            all_mat_data = [keep_mat]
            all_col_data = keep_cols
            curr_block_index += 1
    ### Flush last block
    if len(all_col_data) > 0:
        # import pdb; pdb.set_trace(); ## added by Helen for debugging
        mat = np.hstack(all_mat_data)
        np.save(save_prefix + ".mat.{}.npy".format(curr_block_index), mat)
        np.savetxt(save_prefix + ".cols.{}.txt".format(curr_block_index), all_col_data, fmt="%s")
library(conflicted)
conflict_prefer("first", "dplyr")
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("collapse", "dplyr")
conflict_prefer("combine", "dplyr")
packages <- c("optparse","dplyr", "data.table", "reshape2", "ggplot2",
              "tidyr", "textshape","readxl", "gplots", "AnnotationDbi",
              "org.Hs.eg.db", "ggrepel", "gplots")
xfun::pkg_attach(packages)
conflict_prefer("select", "dplyr")


option.list <- list(
    make_option("--project", type="character", default = "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211130_test_PoPS.plots/", help="project directory"),
    make_option("--sampleName", type="character", default="2kG.library", help="project name"),
    make_option("--output", type="character", default = "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211130_test_PoPS.plots/outputs/", help="output directory"),
    make_option("--figure", type="character", default = "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211130_test_PoPS.plots/figures/", help="figure directory"),
    make_option("--scratch.output", type="character", default="", help="output directory for large files"),
    make_option("--k.val", type="numeric", default=60, help="the value of K in this run"),
    make_option("--PoPS_outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/", help="PoPS output directory"),
    make_option("--coefs_with_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.coefs", help=""),
    make_option("--marginals_with_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.marginals", help=""),
    make_option("--preds_with_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.preds", help=""),
    make_option("--coefs_without_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6.coefs", help=""),
    make_option("--marginals_without_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6.marginals", help=""),
    make_option("--preds_without_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6.preds", help=""),
    make_option("--prefix", type="character", default="CAD_aug6_cNMF60", help="magma file name (before genes.raw)"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--cNMF.features", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/data/features/pops_features_raw/topic.zscore.ensembl.scaled_k_60.dt_0_2.txt", help="normalized cNMF weights, unit variance and zero mean"),
    make_option("--all.features", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211101_normalized_features/outputs/full_features_with_cNMF.RDS", help=".RDS file with all features input into PoPS"),
    make_option("--external.features.metadata", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/metadata/metadata_jul17.txt", help="annotations for each external features"),
    make_option("--combined.preds", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.combined.preds", help="preds file with results from no_cNMF run and with_cNMF run"),
    make_option("--coefs.defining.top.topic.RDS", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60_coefs.defining.top.topic.RDS", help=""),
    make_option("--preds.importance.score.key.columns", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60_preds.importance.score.key.columns.txt", help="")
)
opt <- parse_args(OptionParser(option_list=option.list))


## debug PoPS.plots.R using the scRNAseq_2kG_11AMDox_1 sample: the hard-coded values below override the command-line options above and should be commented out for pipeline runs
opt$sampleName <- "scRNAseq_2kG_11AMDox_1"
opt$output <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/"
opt$figure <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/figures/all_genes/scRNAseq_2kG_11AMDox_1/K5/"
opt$scratch.output <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/211101_20sample_snakemake/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/"
opt$prefix <- "CAD_aug6_cNMF5"
opt$k.val <- 5
opt$coefs_with_cNMF <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/CAD_aug6_cNMF5.coefs"
opt$preds_with_cNMF <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/CAD_aug6_cNMF5.preds"
opt$marginals_with_cNMF <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/CAD_aug6_cNMF5.marginals"
opt$coefs_without_cNMF <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/pops/CAD_aug6.coefs"
opt$preds_without_cNMF <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/pops/CAD_aug6.preds"
opt$marginals_without_cNMF <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/pops/CAD_aug6.marginals"
opt$cNMF.features <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/topic.zscore.ensembl.scaled_k_5.dt_0_2.txt"
opt$all.features <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/full_features_CAD_aug6_cNMF5.RDS"
opt$external.features.metadata <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/metadata/metadata_jul17.txt"
opt$combined.preds <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/CAD_aug6_cNMF5.combined.preds"
opt$coefs.defining.top.topic.RDS <- "/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/211101_20sample_snakemake/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/CAD_aug6_cNMF5_coefs.defining.top.topic.RDS"
opt$preds.importance.score.key.columns <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211101_20sample_snakemake/analysis/all_genes/scRNAseq_2kG_11AMDox_1/K5/threshold_0_2/pops/CAD_aug6_cNMF5_PoPS_preds.importance.score.key.columns.txt"


SAMPLE=opt$sampleName
OUTDIR=opt$output
SCRATCH.OUTDIR=opt$scratch.output
FIGDIR=opt$figure
PREFIX=opt$prefix
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
DENSITY.THRESHOLD.FILENAME =paste0("dt_", DENSITY.THRESHOLD)
k <- opt$k.val

check.dir <- c(OUTDIR, FIGDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))



## graphing constants
mytheme <- theme_classic() + theme(axis.text = element_text(size = 12), axis.title = element_text(size = 14), plot.title = element_text(hjust = 0.5, face = "bold", size=14))
palette = colorRampPalette(c("#38b4f7", "white", "red"))(n = 100)


## load metadata
meta.data.path <- opt$external.features.metadata
metadata <- read.delim(meta.data.path, stringsAsFactors=F)

## load data
preds <- read.table(file=opt$preds_with_cNMF,header=T, stringsAsFactors=F, sep="\t")
colnames(preds) <- paste0(colnames(preds), "_with.cNMF")
colnames(preds)[1] <- "ENSGID"
preds.before <- read.table(file=paste0(opt$preds_without_cNMF), header=T, stringsAsFactors=F, sep="\t")
colnames(preds.before) <- paste0(colnames(preds.before), "_without.cNMF")
colnames(preds.before)[1] <- "ENSGID"
preds.combined <- merge(preds, preds.before, by="ENSGID")
marginals <- read.table(file=opt$marginals_with_cNMF,header=T, stringsAsFactors=F, sep="\t") %>% merge(metadata, by="X")
coefs <- read.table(file=opt$coefs_with_cNMF,header=T, stringsAsFactors=F, sep="\t")
coefs.df <- coefs[4:nrow(coefs),] %>% merge(metadata, by.x="parameter", by.y="X")
coefs.df$beta <- coefs.df$beta %>% as.numeric
coefs.df <- coefs.df %>% arrange(desc(beta))


## map ids
x <- org.Hs.egENSEMBL 
mapped_genes <- mappedkeys(x)
xx.entrez.to.ensembl <- as.list(x[mapped_genes]) # EntrezID to Ensembl
xx.ensembl.to.entrez <- as.list(org.Hs.egENSEMBL2EG) # Ensembl to EntrezID


y <- org.Hs.egGENENAME
y_mapped_genes <- mappedkeys(y)
entrez.to.genename <- as.list(y[y_mapped_genes])
genename.to.entrez <- as.list(org.Hs.egGENENAME)


z <- org.Hs.egSYMBOL
z_mapped_genes <- mappedkeys(z)
entrez.to.symbol <- as.list(z[z_mapped_genes])
symbol.to.entrez <- as.list(org.Hs.egSYMBOL)


## preds.df
preds.df <- preds %>% mutate(EntrezID = xx.ensembl.to.entrez[preds$ENSGID %>% as.character] %>% sapply("[[",1)) %>%
    mutate(Gene.name = entrez.to.genename[.$EntrezID %>% as.character] %>% sapply("[[",1) %>% as.character,
           Gene = entrez.to.symbol[.$EntrezID %>% as.character] %>% sapply("[[",1) %>% as.character)

## load all preds
preds.combined.df <- read.delim(file=opt$combined.preds, stringsAsFactors=F)

## load cNMF features
features <- read.delim(opt$cNMF.features, stringsAsFactors=F)

## load all features
all.features <- readRDS(opt$all.features)

# ## load gene x feature importance score
# load(file=paste0(OUTDIR, "/coefs.marginals.feature.outer.prod.RDS"))

## load coefs.defining.top.topic.df
coefs.defining.top.topic.df <- readRDS(file=opt$coefs.defining.top.topic.RDS) # paste0(SCRATCH.OUTDIR, "/", PREFIX, "_coefs.defining.top.topic.RDS"))

## load PoPS importance score with key columns table
PoPS_preds.importance.score.key <- read.delim(file=opt$preds.importance.score.key.columns, stringsAsFactors=F) #paste0(OUTDIR, "/PoPS_preds.importance.score.key.columns.txt"), stringsAsFactors = F)


##################################################
## Plots

## plot the list of topics and their PoPS component scores for gene of interest
gene.set <- c("GOSR2", "TLNRD1", "EDN1", "NOS3", "KLF2", "ERG", "CCM2", "KRIT")
pdf(paste0(FIGDIR, "/", PREFIX, "_", DENSITY.THRESHOLD.FILENAME, "_feature_x_gene.component.importance.score.coefs.pdf"), width=4, height=6)
for (gene.here in gene.set) {
    toPlot <- coefs.defining.top.topic.df %>% subset(grepl(paste0("^",gene.here,"$"), Gene)) %>% select(topic, gene.feature_x_beta)
    p <- toPlot %>% ggplot(aes(x=reorder(topic, gene.feature_x_beta), y=gene.feature_x_beta)) + geom_col(fill="#38b4f7") + theme_minimal() +
        coord_flip() + xlab("Feature (Topic)") + ylab("Feature x Gene\nImportance Score") + ggtitle(paste0(gene.here)) + mytheme
    print(p)

    toPlot <- PoPS_preds.importance.score.key %>% subset(grepl(paste0("^",gene.here,"$"), Gene)) %>% select(Long_Name, gene.feature_x_beta) %>% slice(1:15)
    p <- toPlot %>% ggplot(aes(x=reorder(Long_Name, gene.feature_x_beta), y=gene.feature_x_beta)) + geom_col(fill="#38b4f7") + theme_minimal() +
        coord_flip() + xlab("Features") + ylab("Feature x Gene\nImportance Score") + ggtitle(paste0(gene.here)) + mytheme
    print(p) ## print the all-features panel as well; without this the second plot per gene never reaches the PDF
}
dev.off()


pdf(paste0(FIGDIR, "/", PREFIX, "_", DENSITY.THRESHOLD.FILENAME, "_Topic.coef.beta.pdf"), width=4, height=6) ## double check figures
toPlot <- coefs.df %>% subset(grepl("zscore",parameter)) %>% arrange(desc(beta))
p <- toPlot %>% ggplot(aes(x=reorder(parameter, beta), y=beta)) + geom_col(fill="#38b4f7") + theme_minimal() +
    coord_flip() + xlab("Features (topic)") + ylab("Beta Score") + ggtitle(paste0("K = ", k, " Topics")) + mytheme
print(p)
dev.off()
pdf(paste0(FIGDIR, "/all.coef.beta.pdf"))
toPlot <- coefs.df %>% arrange(desc(beta)) %>% slice(1:15)
p <- toPlot %>% ggplot(aes(x=reorder(Long_Name, beta), y=beta)) + geom_col(fill="#38b4f7") + theme_minimal() +
    coord_flip() + xlab("Features") + ylab("Beta Score") + ggtitle(paste0("K = ", k, " Topics")) + mytheme
print(p)
dev.off()
pdf(paste0(FIGDIR, "/all.marginals.beta.pdf"))
toPlot <- marginals %>% arrange(desc(beta)) %>% slice(1:15)
p <- toPlot %>% ggplot(aes(x=reorder(Long_Name, beta), y=beta)) + geom_col(fill="#38b4f7") + theme_minimal() +
    coord_flip() + xlab("Features") + ylab("Beta Score") + ggtitle(paste0("K = ", k, " Topics, Ranked by marginals beta score")) + mytheme
print(p)
dev.off()



## figure
pdf(paste0(FIGDIR, "/", PREFIX, "_", DENSITY.THRESHOLD.FILENAME, "_PoPS_score_list.pdf"), width=4, height=6)
top.PoPS.genes <- preds.df %>% slice(1:20)
toPlot <- data.frame(Gene= top.PoPS.genes %>% pull(Gene),
                     Score=top.PoPS.genes %>% pull(PoPS_Score_with.cNMF))
p <- toPlot %>% ggplot(aes(x=reorder(Gene, Score), y=Score) ) + geom_col(fill="#38b4f7") + theme_minimal()
p <- p + coord_flip() + xlab("Top 20 Genes") + ylab("PoPS Score") + ggtitle(paste0(SAMPLE, ", ", PREFIX)) + mytheme
print(p)
dev.off()


## PoPS before vs after score
pdf(paste0(FIGDIR, "/", PREFIX, "_", DENSITY.THRESHOLD.FILENAME, "_before.vs.after.cNMF.pdf"))
labels <- preds.combined.df %>% subset((PoPS_Score_with.cNMF > (2 * PoPS_Score_without.cNMF)) & PoPS_Score_with.cNMF > 1)
p <- preds.combined.df %>% ggplot(aes(x=PoPS_Score_without.cNMF, y=PoPS_Score_with.cNMF)) + geom_point(size=0.5) + mytheme +
    xlab("PoPS Score (without cNMF features)") + ylab("PoPS Score (with cNMF features)") + geom_abline(slope=1, color="red") +
    geom_text_repel(data=labels, box.padding = 0.5, max.overlaps=30, aes(label=Gene), size=4, color="blue")
print(p)
dev.off()
import pandas as pd
import numpy as np
import re
import scipy.linalg
import scipy.stats
import random
import logging
import argparse

from sklearn.linear_model import LinearRegression, RidgeCV, LassoCV
from sklearn.metrics import make_scorer
from scipy.sparse import load_npz
from numpy.linalg import LinAlgError

### --------------------------------- PROGRAM INPUTS --------------------------------- ###

def get_pops_args(argv=None):
    parser = argparse.ArgumentParser(description='...')
    parser.add_argument("--gene_annot_path", help="...")
    parser.add_argument("--feature_mat_prefix", help="...")
    parser.add_argument("--num_feature_chunks", type=int, help="...")
    parser.add_argument("--magma_prefix", help="...")
    parser.add_argument('--use_magma_covariates', dest='use_magma_covariates', action='store_true')
    parser.add_argument('--ignore_magma_covariates', dest='use_magma_covariates', action='store_false')
    parser.set_defaults(use_magma_covariates=True)
    parser.add_argument('--use_magma_error_cov', dest='use_magma_error_cov', action='store_true')
    parser.add_argument('--ignore_magma_error_cov', dest='use_magma_error_cov', action='store_false')
    parser.set_defaults(use_magma_error_cov=True)
    parser.add_argument("--y_path", help="...")
    parser.add_argument("--y_covariates_path", help="...")
    parser.add_argument("--y_error_cov_path", help="...")
    parser.add_argument("--project_out_covariates_chromosomes", nargs="*", help="...")
    parser.add_argument('--project_out_covariates_remove_hla', dest='project_out_covariates_remove_hla', action='store_true')
    parser.add_argument('--project_out_covariates_keep_hla', dest='project_out_covariates_remove_hla', action='store_false')
    parser.set_defaults(project_out_covariates_remove_hla=True)
    parser.add_argument("--subset_features_path", help="...")
    parser.add_argument("--control_features_path", help="...")
    parser.add_argument("--feature_selection_chromosomes", nargs="*", help="...")
    parser.add_argument("--feature_selection_p_cutoff", type=float, default=0.05, help="...")
    parser.add_argument("--feature_selection_max_num", type=int, help="...")
    parser.add_argument("--feature_selection_fss_num_features", type=int, help="...")
    parser.add_argument('--feature_selection_remove_hla', dest='feature_selection_remove_hla', action='store_true')
    parser.add_argument('--feature_selection_keep_hla', dest='feature_selection_remove_hla', action='store_false')
    parser.set_defaults(feature_selection_remove_hla=True)
    parser.add_argument("--training_chromosomes", nargs="*", help="...")
    parser.add_argument('--training_remove_hla', dest='training_remove_hla', action='store_true')
    parser.add_argument('--training_keep_hla', dest='training_remove_hla', action='store_false')
    parser.set_defaults(training_remove_hla=True)
    parser.add_argument("--method", default="ridge", help="...")
    parser.add_argument("--out_prefix", help="...")
    parser.add_argument('--save_matrix_files', dest='save_matrix_files', action='store_true')
    parser.add_argument('--no_save_matrix_files', dest='save_matrix_files', action='store_false')
    parser.set_defaults(save_matrix_files=False)
    parser.add_argument("--random_seed", type=int, default=42, help="...")
    parser.add_argument('--verbose', dest='verbose', action='store_true')
    parser.add_argument('--no_verbose', dest='verbose', action='store_false')
    parser.set_defaults(verbose=False)
    return parser.parse_args(argv)


### --------------------------------- GENERAL --------------------------------- ###

def natural_key(string_):
    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)]


def get_hla_genes(gene_annot_df):
    sub_gene_annot_df = gene_annot_df[gene_annot_df.CHR == "6"]
    sub_gene_annot_df = sub_gene_annot_df[sub_gene_annot_df.TSS >= 20 * (10 ** 6)]
    sub_gene_annot_df = sub_gene_annot_df[sub_gene_annot_df.TSS <= 40 * (10 ** 6)]
    return sub_gene_annot_df.index.values


### Returns as vector of booleans of length len(Y_ids)
def get_gene_indices_to_use(Y_ids, gene_annot_df, use_chrs, remove_hla):
    all_chr_genes_set = set(gene_annot_df[gene_annot_df.CHR.isin(use_chrs)].index.values)
    if remove_hla == True:
        hla_genes_set = set(get_hla_genes(gene_annot_df))
        use_genes = [True if (g in all_chr_genes_set) and (g not in hla_genes_set) else False for g in Y_ids]
    else:
        use_genes = [True if g in all_chr_genes_set else False for g in Y_ids]
    return np.array(use_genes)


def get_indices_in_target_order(ref_list, target_names):
    ref_to_ind_mapper = {}
    for i, e in enumerate(ref_list):
        ref_to_ind_mapper[e] = i
    return np.array([ref_to_ind_mapper[t] for t in target_names])

### --------------------------------- READING DATA --------------------------------- ###

def read_gene_annot_df(gene_annot_path):
    gene_annot_df = pd.read_csv(gene_annot_path, delim_whitespace=True).set_index("ENSGID")
    gene_annot_df["CHR"] = gene_annot_df["CHR"].astype(str)
    return gene_annot_df


def read_magma(magma_prefix, use_magma_covariates, use_magma_error_cov):
    ### Get Y and Y_ids
    magma_df = pd.read_csv(magma_prefix + ".genes.out", delim_whitespace=True)
    Y = magma_df.ZSTAT.values
    Y_ids = magma_df.GENE.values
    if use_magma_covariates is not None or use_magma_error_cov is not None:
        ### Get covariates and error_cov
        sigmas, gene_metadata = munge_magma_covariance_metadata(magma_prefix + ".genes.raw")
        cov_df = build_control_covariates(gene_metadata)
        ### Process
        assert (cov_df.index.values == Y_ids).all(), "Covariate ids and Y ids don't match."
        covariates = cov_df.values
        error_cov = scipy.linalg.block_diag(*sigmas)
    if use_magma_covariates == False:
        covariates = None
    if use_magma_error_cov == False:
        error_cov = None
    return Y, covariates, error_cov, Y_ids


def munge_magma_covariance_metadata(magma_raw_path):
    sigmas = []
    gene_metadata = []
    with open(magma_raw_path) as f:
        ### Get all lines
        lines = list(f)[2:]
        lines = [np.asarray(line.strip('\n').split(' ')) for line in lines]
        ### Check that chromosomes are sequentially ordered
        all_chroms = np.array([l[1] for l in lines])
        all_seq_breaks = np.where(all_chroms[:-1] != all_chroms[1:])[0]
        assert len(all_seq_breaks) == len(set(all_chroms)) - 1, "Chromosomes are not sequentially ordered."
        ### Get starting chromosome and set up temporary variables
        curr_chrom = lines[0][1]
        curr_ind = 0
        num_genes_in_chr = sum([1 for line in lines if line[1] == curr_chrom])
        curr_sigma = np.zeros((num_genes_in_chr, num_genes_in_chr))
        curr_gene_metadata = []
        for line in lines:
            ### If we move to a new chromosome, we reset everything
            if line[1] != curr_chrom:
                ### Symmetrize and save
                sigmas.append(curr_sigma + curr_sigma.T + np.eye(curr_sigma.shape[0]))
                gene_metadata.append(curr_gene_metadata)
                ### Reset
                curr_chrom = line[1]
                curr_ind = 0
                num_genes_in_chr = sum([1 for line in lines if line[1] == curr_chrom])
                curr_sigma = np.zeros((num_genes_in_chr, num_genes_in_chr))
                curr_gene_metadata = []
            ### Add metadata; GENE, NSNPS, NPARAM, MAC
            curr_gene_metadata.append([line[0], float(line[4]), float(line[5]), float(line[7])])
            if len(line) > 9:
                ### Add covariance
                gene_corrs = np.array([float(c) for c in line[9:]])
                curr_sigma[curr_ind, curr_ind - gene_corrs.shape[0]:curr_ind] = gene_corrs
            curr_ind += 1
        ### Save last piece
        sigmas.append(curr_sigma + curr_sigma.T + np.eye(curr_sigma.shape[0]))
        gene_metadata.append(curr_gene_metadata)
    gene_metadata = pd.DataFrame(np.vstack(gene_metadata), columns=["GENE", "NSNPS", "NPARAM", "MAC"])
    gene_metadata.NSNPS = gene_metadata.NSNPS.astype(np.float64)
    gene_metadata.NPARAM = gene_metadata.NPARAM.astype(np.float64)
    gene_metadata.MAC = gene_metadata.MAC.astype(np.float64)
    return sigmas, gene_metadata


def build_control_covariates(metadata):
    genesize = metadata.NPARAM.values
    genedensity = metadata.NPARAM.values/metadata.NSNPS.values
    inverse_mac = 1.0/metadata.MAC.values
    cov = np.stack((genesize, np.log(genesize), genedensity, np.log(genedensity), inverse_mac, np.log(inverse_mac)), axis=1)
    cov_df = pd.DataFrame(cov, columns=["gene_size", "log_gene_size", "gene_density", "log_gene_density", "inverse_mac", "log_inverse_mac"])
    cov_df["GENE"] = metadata.GENE.values
    cov_df = cov_df.loc[:,["GENE", "gene_size", "log_gene_size", "gene_density", "log_gene_density", "inverse_mac", "log_inverse_mac"]]
    cov_df = cov_df.set_index("GENE")
    return cov_df


def read_error_cov_from_y(y_error_cov_path, Y_ids):
    ### Will try to read in as a: scipy sparse .npz, numpy .npy
    error_cov = None
    try:
        error_cov = load_npz(y_error_cov_path)
        error_cov = np.array(error_cov.todense())
    except AttributeError as ev:
        error_cov = np.load(y_error_cov_path)
    if error_cov is None:
        raise IOError("Error reading from {}. Make sure data is in scipy .npz or numpy .npy format.".format(y_error_cov_path))
    assert error_cov.shape[0] == error_cov.shape[1], "Error covariance is not square."
    assert error_cov.shape[0] == len(Y_ids), "Error covariance does not match dimensions of Y."
    return error_cov


def read_from_y(y_path, y_covariates_path, y_error_cov_path):
    ### Get Y and Y_ids
    y_df = pd.read_csv(y_path, sep="\t")
    Y = y_df.Score.values
    Y_ids = y_df.ENSGID.values
    ### Read in covariates and error_cov
    covariates = None
    error_cov = None
    if y_covariates_path is not None:
        covariates = pd.read_csv(y_covariates_path, sep="\t", index_col="ENSGID").astype(np.float64)
        covariates = covariates.loc[Y_ids].values
    if y_error_cov_path is not None:
        error_cov = read_error_cov_from_y(y_error_cov_path, Y_ids)
    return Y, covariates, error_cov, Y_ids


### --------------------------------- PROCESSING DATA --------------------------------- ###

def block_Linv(A, block_labels):
    block_labels = np.array(block_labels)
    Linv = np.zeros(A.shape)
    for l in set(block_labels):
        subset_ind = (block_labels == l)
        sub_A = A[np.ix_(subset_ind, subset_ind)]
        Linv[np.ix_(subset_ind, subset_ind)] = np.linalg.inv(np.linalg.cholesky(sub_A))
    return Linv
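
# (Illustrative check, not used by the pipeline.) block_Linv() returns the inverse Cholesky
# factor of each block, so for every block A_b = L_b L_b^T we have Linv_b A_b Linv_b^T = I:
# rng = np.random.default_rng(0)
# B = rng.normal(size=(4, 4)); A = B @ B.T + 4 * np.eye(4)
# Linv = block_Linv(A, ["1"] * 4)
# assert np.allclose(Linv @ A @ Linv.T, np.eye(4))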


def block_AB(A, block_labels, B):
    block_labels = np.array(block_labels)
    new_B = np.zeros(B.shape)
    for l in set(block_labels):
        subset_ind = (block_labels == l)
        new_B[subset_ind] = A[np.ix_(subset_ind, subset_ind)].dot(B[subset_ind])
    return new_B


def block_BA(A, block_labels, B):
    block_labels = np.array(block_labels)
    new_B = np.zeros(B.shape)
    for l in set(block_labels):
        subset_ind = (block_labels == l)
        new_B[:,subset_ind] = B[:,subset_ind].dot(A[np.ix_(subset_ind, subset_ind)])
    return new_B


def regularize_error_cov(error_cov, Y, Y_ids, gene_annot_df):
    Y_chr = gene_annot_df.loc[Y_ids].CHR.values
    min_lambda = 0
    for c in set(Y_chr):
        subset_ind = Y_chr == c
        W = np.linalg.eigvalsh(error_cov[np.ix_(subset_ind, subset_ind)])
        min_lambda = min(min_lambda, min(W))
    ridge = abs(min(min_lambda, 0))+.05+.9*max(0, np.var(Y)-1)
    return error_cov + np.eye(error_cov.shape[0]) * ridge
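
# (Note, added for clarity.) The ridge above is at least |most negative per-chromosome
# eigenvalue| + 0.05, so the returned matrix is strictly positive definite and the
# per-chromosome Cholesky factorizations in block_Linv() succeed; the extra
# 0.9 * max(0, var(Y) - 1) term inflates the ridge further when Y is over-dispersed
# relative to unit variance.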


def project_out_covariates(Y, covariates, error_cov, Y_ids, gene_annot_df, project_out_covariates_Y_gene_inds):
    ### If covariates doesn't contain intercept, add intercept
    if not np.isclose(covariates.var(axis=0), 0).any():
        covariates = np.hstack((covariates, np.ones((covariates.shape[0], 1))))
    X_train, y_train = covariates[project_out_covariates_Y_gene_inds], Y[project_out_covariates_Y_gene_inds]
    if error_cov is not None:
        sub_error_cov = error_cov[np.ix_(project_out_covariates_Y_gene_inds, project_out_covariates_Y_gene_inds)]
        sub_error_cov_labels = gene_annot_df.loc[Y_ids[project_out_covariates_Y_gene_inds]].CHR.values
        Linv = block_Linv(sub_error_cov, sub_error_cov_labels)
        X_train, y_train = block_AB(Linv, sub_error_cov_labels, X_train), block_AB(Linv, sub_error_cov_labels, y_train)
    reg = LinearRegression(fit_intercept=False).fit(X_train, y_train)
    Y_proj = Y - reg.predict(covariates)
    return Y_proj


def project_out_V(M, V):
    gram_inv = np.linalg.inv(V.T.dot(V))
    moment = V.T.dot(M)
    betas = gram_inv.dot(moment)
    M_res = M - V.dot(betas)
    return M_res
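
# (Note, added for clarity.) project_out_V() residualizes M on the columns of V by ordinary
# least squares: it returns M - V (V^T V)^{-1} V^T M, the component of each column of M
# orthogonal to the column space of V. When V is a single intercept column this reduces to
# column-wise mean-centering of M.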

### --------------------------------- FEATURE SELECTION --------------------------------- ###

def batch_marginal_ols(Y, X):
    ### Save current error settings and set divide to ignore
    old_settings = np.seterr(divide='ignore')
    ### Does not include intercept; we assume that's been projected out already
    sum_sq_X = np.sum(np.square(X), axis=0)
    ### If near-constant to 0 then set to nan. Make a safe copy so we don't get divide by 0 errors.
    near_const_0 = np.isclose(sum_sq_X, 0)
    sum_sq_X_safe = sum_sq_X.copy()
    sum_sq_X_safe[near_const_0] = 1
    betas = Y.dot(X) / sum_sq_X_safe
    mse = np.mean(np.square(Y.reshape(-1,1) - X * betas), axis=0)
    se = np.sqrt(mse / sum_sq_X_safe)
    z = betas / se
    chi2 = np.square(z)
    pvals = scipy.stats.chi2.sf(chi2, 1)
    r2 = 1 - (mse / np.var(Y))
    ### Set everything that's near-constant to 0 to be nan
    betas[near_const_0] = np.nan
    se[near_const_0] = np.nan
    pvals[near_const_0] = np.nan
    r2[near_const_0] = np.nan
    ### Reset error settings to old
    np.seterr(**old_settings)
    return betas, se, pvals, r2
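
# (Illustrative check, not part of the pipeline.) batch_marginal_ols() fits a separate
# no-intercept regression of Y on each column x_j of X, so beta_j = (Y . x_j) / (x_j . x_j):
# rng = np.random.default_rng(0)
# Yt, Xt = rng.normal(size=200), rng.normal(size=(200, 5))
# b, se, p, r2 = batch_marginal_ols(Yt, Xt)
# assert np.allclose(b, Yt @ Xt / np.sum(Xt ** 2, axis=0))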


### Accepts covariates, error_cov = None
def compute_marginal_assoc(feature_mat_prefix, num_feature_chunks, Y, Y_ids, covariates, error_cov, gene_annot_df, feature_selection_Y_gene_inds):
    ### Get Y data
    feature_selection_genes = Y_ids[feature_selection_Y_gene_inds]
    sub_Y = Y[feature_selection_Y_gene_inds]
    ### Add intercept if no near-constant feature
    if covariates is not None and not np.isclose(covariates.var(axis=0), 0).any():
        covariates = np.hstack((covariates, np.ones((covariates.shape[0], 1))))
    elif covariates is None:
        ### If no covariates then make intercept as only covariate
        covariates = np.ones((Y.shape[0], 1)) 
    sub_covariates = covariates[feature_selection_Y_gene_inds]
    if error_cov is not None:
        sub_error_cov = error_cov[np.ix_(feature_selection_Y_gene_inds, feature_selection_Y_gene_inds)]
        sub_error_cov_labels = gene_annot_df.loc[feature_selection_genes].CHR.values
        Linv = block_Linv(sub_error_cov, sub_error_cov_labels)
        sub_Y = block_AB(Linv, sub_error_cov_labels, sub_Y)
        sub_covariates = block_AB(Linv, sub_error_cov_labels, sub_covariates)
    ### Project covariates out of sub_Y
    sub_Y = project_out_V(sub_Y.reshape(-1,1), sub_covariates).flatten()
    ### Get X training indices
    rows = np.loadtxt(feature_mat_prefix + ".rows.txt", dtype=str).flatten()
    X_train_inds = get_indices_in_target_order(rows, feature_selection_genes)
    ### Loop through and get marginal association data
    marginal_assoc_data = []
    all_cols = []
    for i in range(num_feature_chunks):
        mat = np.load(feature_mat_prefix + ".mat.{}.npy".format(i))
        mat = mat[X_train_inds]
        cols = np.loadtxt(feature_mat_prefix + ".cols.{}.txt".format(i), dtype=str).flatten()
        ### Apply error covariance transformation if available
        if error_cov is not None:
            mat = block_AB(Linv, sub_error_cov_labels, mat)
        ### Project out covariates
        mat = project_out_V(mat, sub_covariates)
        ### Compute marginal associations
        marginal_assoc_data.append(np.vstack(batch_marginal_ols(sub_Y, mat)).T)
        all_cols.append(cols)
    marginal_assoc_data = np.vstack(marginal_assoc_data)
    all_cols = np.hstack(all_cols)
    marginal_assoc_df = pd.DataFrame(marginal_assoc_data, columns=["beta", "se", "pval", "r2"], index=all_cols)
    return marginal_assoc_df


### Note that subset_features overrides control_features.
### That is: we do not include control features that are not contained in subset features
### Also, control features do not count toward feature_selection_max_num
def select_features_from_marginal_assoc_df(marginal_assoc_df,
                                           subset_features_path,
                                           control_features_path,
                                           feature_selection_p_cutoff,
                                           feature_selection_max_num):
    ### Subset to subset_features
    if subset_features_path is not None:
        subset_features = np.loadtxt(subset_features_path, dtype=str).flatten()
        marginal_assoc_df = marginal_assoc_df.loc[subset_features]
    ### Get control_features contained in currently subsetted features, and set those aside
    if control_features_path is not None:
        control_features = np.loadtxt(control_features_path, dtype=str).flatten()
        control_df = marginal_assoc_df[marginal_assoc_df.index.isin(control_features)]
        marginal_assoc_df = marginal_assoc_df[~marginal_assoc_df.index.isin(control_features)]
    ### Subset to features that pass p-value cutoff
    if feature_selection_p_cutoff is not None:
        marginal_assoc_df = marginal_assoc_df[marginal_assoc_df.pval < feature_selection_p_cutoff]
    ### Enforce maximum number of features
    if feature_selection_max_num is not None:
        marginal_assoc_df = marginal_assoc_df.sort_values("pval").iloc[:feature_selection_max_num]
    ### Get selected features
    selected_features = list(marginal_assoc_df.index.values)
    ### Combine with control features
    if control_features_path is not None:
        selected_features = selected_features + list(control_df.index.values)
    return selected_features


def load_feature_matrix(feature_mat_prefix, num_feature_chunks, selected_features):
    if selected_features is not None:
        selected_features_set = set(selected_features)
    rows = np.loadtxt(feature_mat_prefix + ".rows.txt", dtype=str).flatten()
    all_mats = []
    all_cols = []
    for i in range(num_feature_chunks):
        mat = np.load(feature_mat_prefix + ".mat.{}.npy".format(i))
        cols = np.loadtxt(feature_mat_prefix + ".cols.{}.txt".format(i), dtype=str).flatten()
        if selected_features is not None:
            keep_inds = [True if c in selected_features_set else False for c in cols]
            mat = mat[:,keep_inds]
            cols = cols[keep_inds]
        all_mats.append(mat)
        all_cols.append(cols)
    mat = np.hstack(all_mats)
    cols = np.hstack(all_cols)
    return mat, cols, rows


def add_feature_to_covariates(covariates, covariates_ids, feature_mat_prefix, num_feature_chunks, feature_name):
    ### Get X indices
    rows = np.loadtxt(feature_mat_prefix + ".rows.txt", dtype=str).flatten()
    X_inds = get_indices_in_target_order(rows, covariates_ids)
    for i in range(num_feature_chunks):
        cols = np.loadtxt(feature_mat_prefix + ".cols.{}.txt".format(i), dtype=str).flatten()
        if feature_name in cols:
            mat = np.load(feature_mat_prefix + ".mat.{}.npy".format(i))[X_inds]
            f = mat[:,np.where(cols == feature_name)[0]]
            break
    covariates = np.hstack((covariates, f))
    return covariates


def forward_stepwise_selection(feature_mat_prefix, num_feature_chunks, Y, Y_ids, covariates, error_cov, gene_annot_df, feature_selection_Y_gene_inds, num_features_to_select):
    if covariates is None:
        covariates = np.ones((Y.shape[0], 1))
    selected_features = []
    for i in range(num_features_to_select):
        logging.info("FORWARD STEPWISE SELECTION: {} features selected".format(len(selected_features)))
        marginal_assoc_df = compute_marginal_assoc(feature_mat_prefix, num_feature_chunks, Y, Y_ids, covariates, error_cov, gene_annot_df, feature_selection_Y_gene_inds)
        top_feature = marginal_assoc_df[~marginal_assoc_df.index.isin(selected_features)].sort_values("pval").index.values[0]
        selected_features.append(top_feature)
        covariates = add_feature_to_covariates(covariates, Y_ids, feature_mat_prefix, num_feature_chunks, top_feature)
    return selected_features


### --------------------------------- MODEL FITTING --------------------------------- ###

def build_training(mat, cols, rows, Y, Y_ids, error_cov, gene_annot_df, training_Y_gene_inds, project_out_intercept=True):
    ### Get training Y
    training_genes = Y_ids[training_Y_gene_inds]
    sub_Y = Y[training_Y_gene_inds]
    intercept = np.ones((sub_Y.shape[0], 1)) ### Make intercept
    ### Get training X
    X_train_inds = get_indices_in_target_order(rows, training_genes)
    X = mat[X_train_inds]
    assert (rows[X_train_inds] == training_genes).all(), "Something went wrong. This shouldn't happen."
    ### Apply error covariance
    if error_cov is not None:
        sub_error_cov = error_cov[np.ix_(training_Y_gene_inds, training_Y_gene_inds)]
        sub_error_cov_labels = gene_annot_df.loc[training_genes].CHR.values
        Linv = block_Linv(sub_error_cov, sub_error_cov_labels)
        sub_Y = block_AB(Linv, sub_error_cov_labels, sub_Y)
        X = block_AB(Linv, sub_error_cov_labels, X)
        intercept = block_AB(Linv, sub_error_cov_labels, intercept)
    if project_out_intercept == True:
        ### Project out intercept
        sub_Y = project_out_V(sub_Y.reshape(-1,1), intercept).flatten()
        X = project_out_V(X, intercept)
    return X, sub_Y


# def corr_score(Y, Y_pred):
#     score = scipy.stats.pearsonr(Y, Y_pred)[0]
#     return score


def initialize_regressor(method, random_state):
    # scorer = make_scorer(corr_score)
    if method == "ridge":
        alphas = np.logspace(-2, 10, num=25)
        # reg = RidgeCV(fit_intercept=False, alphas=alphas, scoring=scorer)
        # logging.info("Model = RidgeCV with 25 alphas, generalized leave-one-out cross-validation, held-out Pearson correlation as scoring metric.")
        reg = RidgeCV(fit_intercept=False, alphas=alphas)
        logging.info("Model = RidgeCV with 25 alphas, generalized leave-one-out cross-validation, NMSE as scoring metric.")
    elif method == 'lasso':
        alphas = np.logspace(-2, 10, num=25)
        reg = LassoCV(fit_intercept=False, alphas=alphas, random_state=random_state, selection="random")
        logging.info("Model = LassoCV with 25 alphas, 5-fold cross-validation, mean-squared error as scoring metric.")
    elif method == 'linreg':
        ### Note that this solves using pseudo-inverse if # features > # samples, corresponding to minimum norm OLS
        reg = LinearRegression(fit_intercept=False)
        logging.info("Model = LinearRegression. Note that this solves using the pseudo-inverse if # features > # samples, corresponding to minimum norm OLS.")
    return reg


### A custom function to replace sklearn RidgeCV solver if needed. Solves using gesvd instead of gesdd
def _svd_decompose_design_matrix_custom(self, X, y, sqrt_sw):
    # X already centered
    X_mean = np.zeros(X.shape[1], dtype=X.dtype)
    if self.fit_intercept:
        # to emulate fit_intercept=True situation, add a column
        # containing the square roots of the sample weights
        # by centering, the other columns are orthogonal to that one
        intercept_column = sqrt_sw[:, None]
        X = np.hstack((X, intercept_column))
    U, singvals, _ = scipy.linalg.svd(X, full_matrices=0, lapack_driver="gesvd")
    singvals_sq = singvals ** 2
    UT_y = np.dot(U.T, y)
    return X_mean, singvals_sq, U, UT_y


### Original function in _RidgeGCV
def _svd_decompose_design_matrix_original(self, X, y, sqrt_sw):
    # X already centered
    X_mean = np.zeros(X.shape[1], dtype=X.dtype)
    if self.fit_intercept:
        # to emulate fit_intercept=True situation, add a column
        # containing the square roots of the sample weights
        # by centering, the other columns are orthogonal to that one
        intercept_column = sqrt_sw[:, None]
        X = np.hstack((X, intercept_column))
    U, singvals, _ = scipy.linalg.svd(X, full_matrices=0)
    singvals_sq = singvals ** 2
    UT_y = np.dot(U.T, y)
    return X_mean, singvals_sq, U, UT_y


### A custom function to replace sklearn LinearRegression fit if needed. Solves using gelss
def _linear_regression_fit_custom(self, X, y, sample_weight=None):
    ### Importing all the base functions needed to run the monkey-patched solver
    from sklearn.linear_model._base import _check_sample_weight, _rescale_data, Parallel, delayed, optimize, sp, sparse, sparse_lsqr, linalg
    n_jobs_ = self.n_jobs
    accept_sparse = False if self.positive else ['csr', 'csc', 'coo']
    X, y = self._validate_data(X, y, accept_sparse=accept_sparse,
                               y_numeric=True, multi_output=True)
    if sample_weight is not None:
        sample_weight = _check_sample_weight(sample_weight, X,
                                             dtype=X.dtype)
    X, y, X_offset, y_offset, X_scale = self._preprocess_data(
        X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
        copy=self.copy_X, sample_weight=sample_weight,
        return_mean=True)
    if sample_weight is not None:
        # Sample weight can be implemented via a simple rescaling.
        X, y = _rescale_data(X, y, sample_weight)
    if self.positive:
        if y.ndim < 2:
            self.coef_, self._residues = optimize.nnls(X, y)
        else:
            # scipy.optimize.nnls cannot handle y with shape (M, K)
            outs = Parallel(n_jobs=n_jobs_)(
                    delayed(optimize.nnls)(X, y[:, j])
                    for j in range(y.shape[1]))
            self.coef_, self._residues = map(np.vstack, zip(*outs))
    elif sp.issparse(X):
        X_offset_scale = X_offset / X_scale
        def matvec(b):
            return X.dot(b) - b.dot(X_offset_scale)
        def rmatvec(b):
            return X.T.dot(b) - X_offset_scale * np.sum(b)
        X_centered = sparse.linalg.LinearOperator(shape=X.shape,
                                                  matvec=matvec,
                                                  rmatvec=rmatvec)
        if y.ndim < 2:
            out = sparse_lsqr(X_centered, y)
            self.coef_ = out[0]
            self._residues = out[3]
        else:
            # sparse_lstsq cannot handle y with shape (M, K)
            outs = Parallel(n_jobs=n_jobs_)(
                delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
                for j in range(y.shape[1]))
            self.coef_ = np.vstack([out[0] for out in outs])
            self._residues = np.vstack([out[3] for out in outs])
    else:
        self.coef_, self._residues, self.rank_, self.singular_ = \
            linalg.lstsq(X, y, lapack_driver="gelss")
        self.coef_ = self.coef_.T
    if y.ndim == 1:
        self.coef_ = np.ravel(self.coef_)
    self._set_intercept(X_offset, y_offset, X_scale)
    return self


### Original function in LinearRegression
def _linear_regression_fit_original(self, X, y, sample_weight=None):
    ### Importing all the base functions needed to run the monkey-patched solver
    from sklearn.linear_model._base import _check_sample_weight, _rescale_data, Parallel, delayed, optimize, sp, sparse, sparse_lsqr, linalg
    n_jobs_ = self.n_jobs
    accept_sparse = False if self.positive else ['csr', 'csc', 'coo']
    X, y = self._validate_data(X, y, accept_sparse=accept_sparse,
                               y_numeric=True, multi_output=True)
    if sample_weight is not None:
        sample_weight = _check_sample_weight(sample_weight, X,
                                             dtype=X.dtype)
    X, y, X_offset, y_offset, X_scale = self._preprocess_data(
        X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
        copy=self.copy_X, sample_weight=sample_weight,
        return_mean=True)
    if sample_weight is not None:
        # Sample weight can be implemented via a simple rescaling.
        X, y = _rescale_data(X, y, sample_weight)
    if self.positive:
        if y.ndim < 2:
            self.coef_, self._residues = optimize.nnls(X, y)
        else:
            # scipy.optimize.nnls cannot handle y with shape (M, K)
            outs = Parallel(n_jobs=n_jobs_)(
                    delayed(optimize.nnls)(X, y[:, j])
                    for j in range(y.shape[1]))
            self.coef_, self._residues = map(np.vstack, zip(*outs))
    elif sp.issparse(X):
        X_offset_scale = X_offset / X_scale
        def matvec(b):
            return X.dot(b) - b.dot(X_offset_scale)
        def rmatvec(b):
            return X.T.dot(b) - X_offset_scale * np.sum(b)
        X_centered = sparse.linalg.LinearOperator(shape=X.shape,
                                                  matvec=matvec,
                                                  rmatvec=rmatvec)
        if y.ndim < 2:
            out = sparse_lsqr(X_centered, y)
            self.coef_ = out[0]
            self._residues = out[3]
        else:
            # sparse_lstsq cannot handle y with shape (M, K)
            outs = Parallel(n_jobs=n_jobs_)(
                delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
                for j in range(y.shape[1]))
            self.coef_ = np.vstack([out[0] for out in outs])
            self._residues = np.vstack([out[3] for out in outs])
    else:
        self.coef_, self._residues, self.rank_, self.singular_ = \
            linalg.lstsq(X, y)
        self.coef_ = self.coef_.T
    if y.ndim == 1:
        self.coef_ = np.ravel(self.coef_)
    self._set_intercept(X_offset, y_offset, X_scale)
    return self


def compute_coefficients(X_train, Y_train, cols, method, random_state):
    if method not in ["ridge", "lasso", "linreg"]:
        raise ValueError("Invalid argument for \"method\". Must be one of \"ridge\", \"lasso\", or \"linreg\".")
    reg = initialize_regressor(method, random_state)
    logging.info("Fitting model.")
    try:
        reg.fit(X_train, Y_train)
    except LinAlgError as err:
        if method == "ridge":
            logging.warning(("First ridge regression failed with LinAlgError. Will re-run once more. "
                             "This is due to a rare but documented issue with LAPACK. "
                             "To attempt to circumvent this issue, we monkey-patch sklearn's _RidgeGCV to call scipy.linalg.svd with lapack_driver=\"gesvd\" instead of \"gesdd\". "
                             "This seems to solve the problem but behavior is not guaranteed. "
                             "For more details, see "
                             "https://mathematica.stackexchange.com/questions/143894/sporadic-numerical-convergence-failure-of-singularvaluedecomposition-message-s"))
            logging.info("Re-running ridge regression with monkey-patched solver.")
            ### Import module and monkey patch
            import sklearn.linear_model._ridge as sklm
            sklm._RidgeGCV._svd_decompose_design_matrix = _svd_decompose_design_matrix_custom
            ### Re-initialize regressor
            reg = initialize_regressor(method, random_state)
            ### Re-fit
            reg.fit(X_train, Y_train)
            logging.info("Restoring original solver to _RidgeGCV class.")
            sklm._RidgeGCV._svd_decompose_design_matrix = _svd_decompose_design_matrix_original
        elif method == "linreg":
            logging.warning(("First linear regression failed with LinAlgError. Will re-run once more. "
                             "This is due to a rare but documented issue with LAPACK. "
                             "To attempt to circumvent this issue, we monkey-patch sklearn's LinearRegression class to call scipy.linalg.lstsq with lapack_driver=\"gelss\". "
                             "This seems to solve the problem but behavior is not guaranteed. "
                             "For more details, see "
                             "https://mathematica.stackexchange.com/questions/143894/sporadic-numerical-convergence-failure-of-singularvaluedecomposition-message-s"))
            logging.info("Re-running linear regression with monkey-patched solver.")
            ### Import module and monkey patch
            import sklearn.linear_model._base as sklm
            sklm.LinearRegression.fit = _linear_regression_fit_custom
            ### Re-initialize regressor
            reg = initialize_regressor(method, random_state)
            ### Re-fit
            reg.fit(X_train, Y_train)
            logging.info("Restoring original solver to LinearRegression class.")
            sklm.LinearRegression.fit = _linear_regression_fit_original
        else:
            raise err
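    ### Assemble the coefficients table: metadata rows first (METHOD, and for the CV models
    ### the selected alpha and its cross-validation score), followed by one beta per feature
    ### column. This table is written out as the .coefs file in main().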
    if method == "ridge":
        coefs_df = pd.DataFrame([["METHOD", "RidgeCV"],
                                 ["SELECTED_CV_ALPHA", reg.alpha_],
                                 ["BEST_CV_SCORE", reg.best_score_]])
        coefs_df = pd.concat([coefs_df, pd.DataFrame([cols, reg.coef_]).T])
        coefs_df.columns = ["parameter", "beta"]
        coefs_df = coefs_df.set_index("parameter")
    elif method == "lasso":
        best_score = reg.mse_path_[np.where(reg.alphas_ == reg.alpha_)[0][0]].mean()
        coefs_df = pd.DataFrame([["METHOD", "LassoCV"],
                                 ["SELECTED_CV_ALPHA", reg.alpha_],
                                 ["BEST_CV_SCORE", best_score]])
        coefs_df = pd.concat([coefs_df, pd.DataFrame([cols, reg.coef_]).T])
        coefs_df.columns = ["parameter", "beta"]
        coefs_df = coefs_df.set_index("parameter")
    elif method == "linreg":
        coefs_df = pd.DataFrame([["METHOD", "LinearRegression"]])
        coefs_df = pd.concat([coefs_df, pd.DataFrame([cols, reg.coef_]).T])
        coefs_df.columns = ["parameter", "beta"]
        coefs_df = coefs_df.set_index("parameter")
    return coefs_df


def pops_predict(mat, rows, cols, coefs_df):
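    ### The PoPS score for each gene is the dot product of its selected feature values with
    ### the fitted betas; indexing coefs_df by the feature columns in cols skips the
    ### metadata rows added by compute_coefficients().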
    pred = mat.dot(coefs_df.loc[cols].beta.values)
    preds_df = pd.DataFrame([rows, pred]).T
    preds_df.columns = ["ENSGID", "PoPS_Score"]
    return preds_df

### --------------------------------- MAIN --------------------------------- ###

def main(config_dict):
    ### --------------------------------- Basic settings --------------------------------- ###
    ### Set logging settings
    if config_dict["verbose"]:
        logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
        logging.info("Verbose output enabled.")
    else:
        logging.basicConfig(format="%(levelname)s: %(message)s")
    ### Set random seeds
    np.random.seed(config_dict["random_seed"])
    random.seed(config_dict["random_seed"])

    ### Display configs
    logging.info("Config dict = {}".format(str(config_dict)))

    ### --------------------------------- Reading/processing data --------------------------------- ###
    gene_annot_df = read_gene_annot_df(config_dict["gene_annot_path"])
    ### If chromosome arguments are None, replace their values in config_dict with all chromosomes
    all_chromosomes = sorted(gene_annot_df.CHR.unique(), key=natural_key)
    if config_dict["project_out_covariates_chromosomes"] is None:
        config_dict["project_out_covariates_chromosomes"] = all_chromosomes
        logging.info("--project_out_covariates_chromosomes is None, defaulting to all chromosomes")
    if config_dict["feature_selection_chromosomes"] is None:
        config_dict["feature_selection_chromosomes"] = all_chromosomes
        logging.info("--feature_selection_chromosomes is None, defaulting to all chromosomes")
    if config_dict["training_chromosomes"] is None:
        config_dict["training_chromosomes"] = all_chromosomes
        logging.info("--training_chromosomes is None, defaulting to all chromosomes")
    ### Make sure all chromosome arguments are fully contained in gene_annot_df's chromosome list
    assert set(config_dict["project_out_covariates_chromosomes"]).issubset(all_chromosomes), "Invalid --project_out_covariates_chromosomes argument."
    assert set(config_dict["feature_selection_chromosomes"]).issubset(all_chromosomes), "Invalid --feature_selection_chromosomes argument."
    assert set(config_dict["training_chromosomes"]).issubset(all_chromosomes), "Invalid --training_chromosomes argument."
    ### Read in scores
    if config_dict["magma_prefix"] is not None:
        logging.info("MAGMA scores provided, loading MAGMA.")
        Y, covariates, error_cov, Y_ids = read_magma(config_dict["magma_prefix"],
                                                     config_dict["use_magma_covariates"],
                                                     config_dict["use_magma_error_cov"])
        if config_dict["use_magma_covariates"] == True:
            logging.info("Using MAGMA covariates.")
        else:
            logging.info("Ignoring MAGMA covariates.")
        if config_dict["use_magma_error_cov"] == True:
            logging.info("Using MAGMA error covariance.")
        else:
            logging.info("Ignoring MAGMA error covariance.")
        ### Regularize MAGMA error covariance if using
        if error_cov is not None:
            logging.info("Regularizing MAGMA error covariance.")
            error_cov = regularize_error_cov(error_cov, Y, Y_ids, gene_annot_df)
    elif config_dict["y_path"] is not None:
        logging.info("Reading scores from {}.".format(config_dict["y_path"]))
        if config_dict["y_covariates_path"] is not None:
            logging.info("Reading covariates from {}.".format(config_dict["y_covariates_path"]))
        if config_dict["y_error_cov_path"] is not None:
            logging.info("Reading error covariance from {}.".format(config_dict["y_error_cov_path"]))
        ### Note that we do not regularize covariance matrix provided in y_error_cov_path. It will be used as is.
        Y, covariates, error_cov, Y_ids = read_from_y(config_dict["y_path"],
                                                      config_dict["y_covariates_path"],
                                                      config_dict["y_error_cov_path"])
    else:
        raise ValueError("At least one of --magma_prefix or --y_path must be provided (--magma_prefix overrides --y_path).")
    ### Get projection, feature selection, and training genes
    project_out_covariates_Y_gene_inds = get_gene_indices_to_use(Y_ids,
                                                                 gene_annot_df,
                                                                 config_dict["project_out_covariates_chromosomes"],
                                                                 config_dict["project_out_covariates_remove_hla"])
    feature_selection_Y_gene_inds = get_gene_indices_to_use(Y_ids,
                                                            gene_annot_df,
                                                            config_dict["feature_selection_chromosomes"],
                                                            config_dict["feature_selection_remove_hla"])
    training_Y_gene_inds = get_gene_indices_to_use(Y_ids,
                                                   gene_annot_df,
                                                   config_dict["training_chromosomes"],
                                                   config_dict["training_remove_hla"])
    ### Project out covariates if using
    if covariates is not None:
        logging.info("Projecting {} covariates out of target scores using genes on chromosome {}. HLA region {}."
                     .format(covariates.shape[1],
                             ", ".join(sorted(gene_annot_df.loc[Y_ids[project_out_covariates_Y_gene_inds]].CHR.unique(), key=natural_key)),
                             "removed" if config_dict["project_out_covariates_remove_hla"] else "included"))
        Y_proj = project_out_covariates(Y,
                                        covariates,
                                        error_cov,
                                        Y_ids,
                                        gene_annot_df,
                                        project_out_covariates_Y_gene_inds)
    else:
        Y_proj = Y


    ### --------------------------------- Feature selection --------------------------------- ###
    ### Compute marginal association data frame
    logging.info("Computing marginal association table using genes on chromosome {}. HLA region {}."
                 .format(", ".join(sorted(gene_annot_df.loc[Y_ids[feature_selection_Y_gene_inds]].CHR.unique(), key=natural_key)),
                         "removed" if config_dict["feature_selection_remove_hla"] else "included"))
    marginal_assoc_df = compute_marginal_assoc(config_dict["feature_mat_prefix"],
                                               config_dict["num_feature_chunks"],
                                               Y_proj,
                                               Y_ids,
                                               None,
                                               error_cov,
                                               gene_annot_df,
                                               feature_selection_Y_gene_inds)
    ### Either do FSS or filter marginal_assoc_df
    if config_dict["feature_selection_fss_num_features"] is not None:
        logging.info("--feature_selection_fss_num_features set to {}, so performing forward stepwise selection (overriding all other feature selection settings).".format(config_dict["feature_selection_fss_num_features"]))
        selected_features = forward_stepwise_selection(config_dict["feature_mat_prefix"],
                                                       config_dict["num_feature_chunks"],
                                                       Y_proj,
                                                       Y_ids,
                                                       None,
                                                       error_cov,
                                                       gene_annot_df,
                                                       feature_selection_Y_gene_inds,
                                                       config_dict["feature_selection_fss_num_features"])
        marginal_assoc_df["selected"] = marginal_assoc_df.index.isin(selected_features)
        ### Annotate with selection rank
        marginal_assoc_df["selection_rank"] = np.nan
        for i in range(len(selected_features)):
            marginal_assoc_df.loc[selected_features[i], "selection_rank"] = i + 1
        logging.info("Forward stepwise selection complete, {} features in model.".format(len(selected_features)))
    else:
        ### Filter features based on settings
        selected_features = select_features_from_marginal_assoc_df(marginal_assoc_df,
                                                                   config_dict["subset_features_path"],
                                                                   config_dict["control_features_path"],
                                                                   config_dict["feature_selection_p_cutoff"],
                                                                   config_dict["feature_selection_max_num"])
        ### Annotate marginal_assoc_df with selected True/False
        marginal_assoc_df["selected"] = marginal_assoc_df.index.isin(selected_features)
        ### Explicitly set features with nan p-values to not-selected
        marginal_assoc_df["selected"] = marginal_assoc_df["selected"] & ~pd.isnull(marginal_assoc_df.pval)
        ### Redefine selected_features
        selected_features = marginal_assoc_df[marginal_assoc_df.selected].index.values
        ### Complex logging statement
        select_feat_logtxt_pieces = []
        if config_dict["subset_features_path"] is not None:
            select_feat_logtxt_pieces.append("subsetting to features at {}".format(config_dict["subset_features_path"]))
        if config_dict["feature_selection_p_cutoff"] is not None:
            if config_dict["feature_selection_max_num"] is not None:
                select_feat_logtxt_pieces.append("filtering to top {} features with p-value < {}"
                                                 .format(config_dict["feature_selection_max_num"],
                                                         config_dict["feature_selection_p_cutoff"]))
            else:
                select_feat_logtxt_pieces.append("filtering to features with p-value < {}"
                                                 .format(config_dict["feature_selection_p_cutoff"]))
        elif config_dict["feature_selection_max_num"] is not None:
            select_feat_logtxt_pieces.append("filtering to top {} features by p-value"
                                             .format(config_dict["feature_selection_max_num"]))
        if config_dict["control_features_path"] is not None:
            select_feat_logtxt_pieces.append("unioning with non-constant control features")
        ### Combine complex logging statement
        if len(select_feat_logtxt_pieces) == 0:
            select_feat_logtxt = ("{} features reamin in model.".format(len(selected_features)))
        if len(select_feat_logtxt_pieces) == 1:
            select_feat_logtxt = ("After {}, {} features remain in model."
                                  .format(select_feat_logtxt_pieces[0], len(selected_features)))
        elif len(select_feat_logtxt_pieces) == 2:
            select_feat_logtxt = ("After {} and {}, {} features remain in model."
                                  .format(select_feat_logtxt_pieces[0], select_feat_logtxt_pieces[1], len(selected_features)))
        elif len(select_feat_logtxt_pieces) == 3:
            select_feat_logtxt = ("After {}, {}, and {}, {} features remain in model."
                                  .format(select_feat_logtxt_pieces[0], select_feat_logtxt_pieces[1], select_feat_logtxt_pieces[2], len(selected_features)))
        logging.info(select_feat_logtxt)


    ### --------------------------------- Training --------------------------------- ###
    ### Load data
    ### Won't necessarily load in order of selected_features. Loads in order of matrix columns.
    ### Note: doesn't raise error if trying to select feature that isn't in columns
    mat, cols, rows = load_feature_matrix(config_dict["feature_mat_prefix"], config_dict["num_feature_chunks"], selected_features)
    logging.info("Building training X and Y using genes on chromosome {}. HLA region {}."
                 .format(", ".join(sorted(gene_annot_df.loc[Y_ids[training_Y_gene_inds]].CHR.unique(), key=natural_key)),
                         "removed" if config_dict["training_remove_hla"] else "included"))
    ### Build training X and Y
    ### Should be properly subsetted and have error_cov applied. We also explicitly project out intercept
    X_train, Y_train = build_training(mat, cols, rows,
                                      Y_proj, Y_ids, error_cov,
                                      gene_annot_df, training_Y_gene_inds,
                                      project_out_intercept=True)
    logging.info("X dimensions = {}. Y dimensions = {}".format(X_train.shape, Y_train.shape))
    ### Compute coefficients
    ### Output should contain at least one row for every column and additional rows for any metadata like method, regularization chosen by CV, etc.
    coefs_df = compute_coefficients(X_train, Y_train, cols, config_dict["method"], config_dict["random_seed"])
    ### Prediction
    logging.info("Computing PoPS scores.")
    preds_df = pops_predict(mat, rows, cols, coefs_df)
    ### Annotate Y, Y_proj, and gene used in feature selection + gene used in training
    preds_df = preds_df.merge(pd.DataFrame(np.array([Y_ids, Y]).T, columns=["ENSGID", "Y"]),
                              how="left",
                              on="ENSGID")
    if covariates is not None:
        preds_df = preds_df.merge(pd.DataFrame(np.array([Y_ids, Y_proj]).T, columns=["ENSGID", "Y_proj"]),
                                  how="left",
                                  on="ENSGID")
        preds_df["project_out_covariates_gene"] = preds_df.ENSGID.isin(Y_ids[project_out_covariates_Y_gene_inds])
    preds_df["feature_selection_gene"] = preds_df.ENSGID.isin(Y_ids[feature_selection_Y_gene_inds])
    preds_df["training_gene"] = preds_df.ENSGID.isin(Y_ids[training_Y_gene_inds])


    ### --------------------------------- Save --------------------------------- ###
    logging.info("Writing output files.")
    preds_df.to_csv(config_dict["out_prefix"] + ".preds", sep="\t", index=False)
    coefs_df.to_csv(config_dict["out_prefix"] + ".coefs", sep="\t")
    marginal_assoc_df.to_csv(config_dict["out_prefix"] + ".marginals", sep="\t")
    if config_dict["save_matrix_files"] == True:
        logging.info("Saving matrix files as well.")
        pd.DataFrame(np.hstack((Y_train.reshape(-1,1), X_train)),
                     index=Y_ids[training_Y_gene_inds],
                     columns=["Y_train"] + list(cols)).to_csv(config_dict["out_prefix"] + ".traindata", sep="\t")
        pd.DataFrame(mat,
                     index=rows,
                     columns=cols).to_csv(config_dict["out_prefix"] + ".matdata", sep="\t")


### Main
if __name__ == '__main__':
    args = get_pops_args()
    config_dict = vars(args)
    main(config_dict)
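For reference, below is a minimal sketch (not part of the pipeline) of a config_dict covering the keys that main() reads above. In the actual script these values come from get_pops_args(), defined earlier; all paths are placeholders.

### Hypothetical example only: a config_dict with the keys that main() reads.
### Paths are placeholders, not files shipped with this repository.
example_config = {
    "verbose": True,
    "random_seed": 42,
    "gene_annot_path": "path/to/gene_annot.txt",
    ### Scores: either a MAGMA prefix ...
    "magma_prefix": "path/to/magma_output_prefix",
    "use_magma_covariates": True,
    "use_magma_error_cov": True,
    ### ... or a precomputed Y table (ignored here because magma_prefix is set)
    "y_path": None,
    "y_covariates_path": None,
    "y_error_cov_path": None,
    ### Gene subsets (None defaults to all chromosomes)
    "project_out_covariates_chromosomes": None,
    "project_out_covariates_remove_hla": True,
    "feature_selection_chromosomes": None,
    "feature_selection_remove_hla": True,
    "training_chromosomes": None,
    "training_remove_hla": True,
    ### Feature matrix and feature selection settings
    "feature_mat_prefix": "path/to/feature_matrix_prefix",
    "num_feature_chunks": 10,
    "feature_selection_fss_num_features": None,
    "subset_features_path": None,
    "control_features_path": None,
    "feature_selection_p_cutoff": 0.05,
    "feature_selection_max_num": None,
    ### Model and outputs
    "method": "ridge",
    "out_prefix": "path/to/output_prefix",
    "save_matrix_files": False,
}
### main(example_config)  ## uncomment to run the full scoring pipeline on real inputs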
library(conflicted)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")
packages <- c("optparse","dplyr","data.table","reshape2","ggplot2","ggpubr","conflicted",
              "cluster","textshape","readxl","writexl","tidyr","org.Hs.eg.db","stats",
              "gplots", "stringi" # heatmap.2
              ) 
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("replace_na", "dplyr")
conflict_prefer("filter", "dplyr")

## optparse
option.list <- list(
    make_option("--input.GWAS.table", type="character", default="/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/overlap/MAP/MAP_GWAS_gene_incl_ubq_genes.txt", help="Input CAD GWAS Table"),
    make_option("--cNMF.table", type="character", default="", help="Table with cNMF program result"),
    make_option("--outdir", type="character", default="./outputs/MAP/", help="Output directory"),
    make_option("--figdir", type="character", default="./figures/", help="Output directory"),
    make_option("--sampleName", type="character", default="2kG.library", help="Name of the sample"),
    make_option("--celltype", type="character", default="EC", help="Cell type in GWAS table"),
    make_option("--K.val", type="numeric", default=60, help="The value of K"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--outdirsample", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/", help="path to cNMF analysis results"), ## or for 2n1.99x: "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211116_snakemake_dup4_cells/analysis/all_genes/Perturb_2kG_dup4/K60/threshold_0_2/"
    make_option("--num.tests", type="numeric", default=2, help="number of statistical test to do from the top of statistical.test.df.txt"),
    make_option("--trait.name", type="character", default="MAP", help="name of the trait"),
    make_option("--coding.variant.df", type="character", default="/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/SBP/SBPvariant.list.1.coordinate.txt", help="Data frame with coding variant information for GWAS trait"),
    make_option("--regulator.analysis.type", type="character", default="GWASWide", help="path to statistical test recipe"),
    make_option("--perturbSeq", type="logical", default=FALSE, help="Whether this is a Perturb-seq experiment"),
    make_option("--TPM.table", type="character", default="", help="Path to TPM table for the correct cell type")
)
opt <- parse_args(OptionParser(option_list=option.list))

## ## sdev for K562 gwps 2k overdispersed genes K=80
## opt$input.GWAS.table <- "/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/overlap/RBC/RBC_GWAS_gene_incl_ubq_genes.txt"
## opt$cNMF.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K80/threshold_0_2/prepare_compute_enrichment.txt"
## opt$sampleName <- "WeissmanK562gwps"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/K80/program_prioritization_GenomeWide/RBC/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K80/threshold_0_2/program_prioritization_GenomeWide/RBC/"
## opt$trait.name <- "RBC"
## opt$celltype <- "K562"
## opt$K.val <- 80
## opt$coding.variant.df <- "/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/RBC/RBCvariant.list.1.coordinate.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K80/threshold_0_2/"
## opt$perturbSeq <- TRUE
## opt$regulator.analysis.type <- "GenomeWide"



## ## sdev for K562 gwps 2k overdispersed genes K=90 RBC
## opt$input.GWAS.table <- "/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/overlap/RBC/RBC_GWAS_gene_incl_ubq_genes.txt"
## opt$cNMF.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/prepare_compute_enrichment.txt"
## opt$sampleName <- "WeissmanK562gwps"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/K90/program_prioritization_GenomeWide/RBC/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/program_prioritization_GenomeWide/RBC/"
## opt$trait.name <- "RBC"
## opt$celltype <- "K562"
## opt$K.val <- 90
## opt$coding.variant.df <- "/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/RBC/RBCvariant.list.1.coordinate.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/"
## opt$perturbSeq <- TRUE
## opt$regulator.analysis.type <- "GenomeWide"
## opt$TPM.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/230512_TPM/outputs/K562.TPM.df.txt"


## ## sdev for K562 gwps 2k overdispersed genes K=90 Plt
## opt$input.GWAS.table <- "/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/overlap/Plt/Plt_GWAS_gene_incl_ubq_genes.txt"
## opt$cNMF.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/prepare_compute_enrichment.txt"
## opt$sampleName <- "WeissmanK562gwps"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/K90/program_prioritization_GenomeWide/Plt/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/program_prioritization_GenomeWide/Plt/"
## opt$trait.name <- "Plt"
## opt$celltype <- "K562"
## opt$K.val <- 90
## opt$coding.variant.df <- "/oak/stanford/groups/engreitz/Users/rosaxma/2111_pipeline_output/UKB/Plt/Pltvariant.list.1.coordinate.txt"
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/"
## opt$perturbSeq <- TRUE
## opt$regulator.analysis.type <- "GenomeWide"
## opt$TPM.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/230512_TPM/outputs/K562.TPM.df.txt"


## ## sdev for EC Perturb-seq K=60
## opt$input.GWAS.table <- "/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/GWAS_tables/CAD_GWAS_gene_incl_ubq_genes_aragam_all_harst.txt"
## opt$cNMF.table <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/prepare_compute_enrichment.txt"
## opt$sampleName <- "2kG.library"
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/2kG.library/K60/program_prioritization_GWASWide/CAD/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/program_prioritization_GWASWide/CAD/"
## opt$trait.name <- "CAD"
## opt$celltype <- "EC"
## opt$K.val <- 60
## opt$coding.variant.df <- ""
## opt$outdirsample <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/"
## opt$perturbSeq <- TRUE
## opt$regulator.analysis.type <- "GWASWide"



## Directories
FIGDIR <- opt$figdir
OUTDIR <- opt$outdir
DATADIR <- "/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/" ## update this path
check.dir <- c(FIGDIR, OUTDIR)
invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x, recursive=T) }))


##########################################################################################
## Load Data
## load cNMF results to get the list of input genes to cNMF (need theta)
k <- opt$K.val
SAMPLE <- opt$sampleName
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
SUBSCRIPT.SHORT=paste0("k_", k, ".dt_", DENSITY.THRESHOLD)
OUTDIRSAMPLE <- opt$outdirsample
## OUTDIRSAMPLE <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/"
celltype <- opt$celltype

cNMF.result.file <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT, ".RData")
print(cNMF.result.file)
if(file.exists(cNMF.result.file)) {
    print("loading cNMF result file")
    load(cNMF.result.file)
}


## load 10X reference
gtf.10X.df <- readRDS(paste0("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/data/refdata-cellranger-arc-GRCh38-2020-A_genes.gtf_df.RDS")) ## load 10X gtf file
Gene.ENSEMBL.10X.df <- gtf.10X.df %>% 
    mutate(ENSGID = gene_id,
           Gene10X = gene_name) %>%
    select(Gene10X, ENSGID) %>%
    unique
## gtf <- importGTF("/home/groups/engreitz/Software/cellranger-arc-1.0.1/refdata-cellranger-arc-GRCh38-2020-A/genes/genes.gtf")

db <- ifelse(grepl("mouse", SAMPLE), "org.Mm.eg.db", "org.Hs.eg.db")
library(db, character.only=TRUE) ## load the appropriate annotation database (org.Hs.eg.db or org.Mm.eg.db)
## helper function to map between ENSGID and SYMBOL
map.ENSGID.SYMBOL <- function(df) {
    ## need column `Gene` to be present in df
    ## detect gene data type (e.g. ENSGID, Entrez Symbol)
    gene.type <- ifelse(nrow(df) == sum(as.numeric(grepl("^ENS", df$Gene))),
                        "ENSGID",
                        "Gene")
    if(gene.type == "ENSGID") {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "ENSEMBL", column = "SYMBOL")
        df <- df %>% mutate(ENSGID = Gene, Gene = mapped.genes)
    } else {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "SYMBOL", column = "ENSEMBL")
        df <- df %>% mutate(ENSGID = mapped.genes)
    }
    df <- df %>%
        mutate(Gene = ifelse(is.na(Gene), "NA", Gene),
               ENSGID = ifelse(is.na(ENSGID), "NA", ENSGID)) 
    df <- df %>% merge(Gene.ENSEMBL.10X.df, by="ENSGID", all.x=T)
    df <- df %>% mutate(Gene = ifelse(Gene == "NA", Gene10X, Gene))
    notMatched.df <- df %>% subset(is.na(Gene10X))
    if(nrow(notMatched.df) > 0) {
        notMatched.index <- notMatched.df %>% rownames
        match.df <- merge(notMatched.df %>% select(-ENSGID, -Gene10X), Gene.ENSEMBL.10X.df, by.x="Gene", by.y="Gene10X") %>% mutate(Gene10X = Gene) %>% select(all_of(df %>% colnames))
        matched.index <- notMatched.df %>% subset(Gene %in% c(match.df$Gene %>% unique)) %>% rownames %>% as.numeric
        df <- rbind(df[-matched.index,], match.df)
    }
    df <- df %>% mutate(OriginalGene = Gene,
                        Gene = Gene10X)
    ## toconvert <- df %>% subset(Gene != Gene10X) %>% select(ENSGID, Gene, Gene10X)
    ## toconvert.index <- toconvert %>% 
    ## df %>% subset(Gene %in% Gene.ENSEMBL.10X.df$Gene10X & is.na(ENSGID))    
    return(df)
}


if(grepl("2kG.library", SAMPLE)) {
    ## Perturb-seq vs 10X names (from Gavin)
    ## ptb10xNames.df <- read_xlsx("../data/Perturbation 10X names.xlsx") %>% as.data.frame
    ptb10xNames.df <- read.delim(paste0(DATADIR, "/data/220627_add_Perturbation 10X names.txt"), stringsAsFactors=F, check.names=F)
    ## write.table(ptb10xNames.df,"../data/220627_add_Perturbation 10X names.220628.txt", quote=F, row.names=F, sep="\t")
    perturbseq.gene.names.to10X <- function(Gene) stri_replace_all_regex(Gene, pattern = ptb10xNames.df$Symbol, replace = ptb10xNames.df$`Name used by CellRanger`, vectorize=F)
    tenX.gene.names.toperturbseq <- function(Gene) stri_replace_all_regex(Gene, pattern = ptb10xNames.df$`Name used by CellRanger`, replace = ptb10xNames.df$Symbol, vectorize=F) 
}


## gtf.10X.df <- readRDS(paste0(DATADIR, "/data/refdata-cellranger-arc-GRCh38-2020-A_genes.gtf_df.RDS")) ## load 10X gtf file
message(paste0("Loading input GWAS table file from ", opt$input.GWAS.table))
GWAS.df <- GWAS.df.original <- read.delim(opt$input.GWAS.table, stringsAsFactors=F) %>% mutate(original.gene = gene) %>%
    select(-ProgramsInWhichGeneIsInTop100ZScoreSpecificGenes, -ProgramsInWhichGeneIsInTop300ZScoreSpecificGenes, -ProgramsInWhichGeneIsInTop500ZScoreSpecificGenes, -PoPS_Score, -PoPS.Rank, -Top5ProgramsThatContributeToPoPSScore) %>%
    mutate(GeneInGWASTable = TRUE,
           Gene = gene) %>%
    map.ENSGID.SYMBOL

if(grepl("2kG.library", SAMPLE)){
    GWAS.df <- GWAS.df %>%
        ## mutate(original.gene = gene) %>%
        mutate(gene = gene %>% perturbseq.gene.names.to10X)

    ## add columns on_2kG_lib and TeloHAEC_ctrl_TPM
    genes.on.2kG.lib <- read_xlsx(paste0(DATADIR, "/data/Table.S1.1 revised 220722.xlsx")) %>%
        mutate(original.gene = `Symbol (for library design)`,
               gene = `Symbol (in CellRanger)`)
    RNAseq_CITV <- read.delim(paste0(DATADIR, "/CAD_SNP_INFO/RNASeq_Telo_Eahy_pm_IL1b_TNF_VEGF.txt"), stringsAsFactors=F)
    TPM.df <- RNAseq_CITV %>% select(Gene_symbol, TeloHAEC.Ctrl_avg) %>% `colnames<-`(c("gene", "TeloHAEC_ctrl_TPM"))


    ## GWAS.df <- GWAS.df %>%
    ##     mutate(on_2kG_lib = gene %in% genes.on.2kG.lib$gene) %>%
    ##     merge(TPM.df, by="gene", all.x=T)
}

perturbed_gene_ary <- barcode.names$Gene %>% unique


## Load data for all genes relation to the topics (not limited to CAD GWAS genes)
message(paste0("Loading input G2P table file from ", opt$cNMF.table))
narrow.df <- read.delim(opt$cNMF.table, header=T, stringsAsFactors=F)## %>% select(-Top5FeaturesThatContributeToPoPSScore)
## narrow.df <- narrow.df[narrow.df$Gene!="NULL", ]

if(grepl("2kG.library", SAMPLE)) {
    narrow.df.TPM <- merge(TPM.df, narrow.df %>% select(-ENSGID) %>% unique, by.x=c("gene"), by.y=c("Gene"), all.y=T)
    narrow.df.TPM$gene[narrow.df.TPM$gene == "MESDC1"] <- "TLNRD1" ## one time fix MESDC1 -> TLNRD1
} else {
    narrow.df.TPM <- narrow.df %>% mutate(gene = Gene)
}

if(!grepl("2kG.library", SAMPLE)) {
    coding_variant.df <- read.delim(opt$coding.variant.df, stringsAsFactors=F)

    GeneWithCodingVariant <- coding_variant.df$CodingVariantGene %>% unique
    GWAS.df <- GWAS.df %>%
        mutate(GeneContainsCodingVariant = (gene %in% GeneWithCodingVariant) | (original.gene %in% GeneWithCodingVariant),
               CodingVariantGene = ifelse(GeneContainsCodingVariant, gene, NA))
} else {
    all.cs.aragam <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/CAD_SNP_INFO/Aragam2021/all.cs.txt", stringsAsFactors=F)
    all.cs.harst <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/CAD_SNP_INFO/Harst2017/all.cs.txt", stringsAsFactors=F)

    addCodingGenes <- function(df, all.cs) {
        codingGenes <- all.cs %>% filter(AnyCoding) %>% select(CredibleSet, CodingVariantGene)
        df <- df %>% merge(codingGenes, all.x=TRUE) %>%
            rowwise() %>%
            mutate(GeneContainsCodingVariant=ifelse(is.na(CodingVariantGene), FALSE, (gene %in% strsplit(CodingVariantGene,";")[[1]]))) %>%
            ungroup() %>%
            as.data.frame()
        return(df)
    }

    GWAS.df <- GWAS.df %>% 
        ## mutate(Resource=ifelse(Resource=="Aragam_2021","Aragam2021","Harst2017"),
        ##        CredibleSet=paste0(Resource,"-",Lead_SNP_rsID)) %>%
        addCodingGenes(rbind(all.cs.aragam,all.cs.harst)) %>%
        mutate(CredibleSet = ifelse(CredibleSet %in% c("Aragam2021-rs768453105", "Harst2017-rs138120077"), "Aragam2021-rs768453105_Harst2017-rs138120077", CredibleSet))
}

## include TPM table
if(!("TPM" %in% colnames(GWAS.df))) {
    TPM.df <- read.delim(opt$TPM.table, stringsAsFactors=F)
    GWAS.tmp.df <- merge(GWAS.df, TPM.df, by.x="gene", by.y="Gene", all.x=T) 
    GWAS.df <- GWAS.tmp.df
}



########################################################################################
## Define key variables for analysis
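## Gene-level flags defined in the mutate() below:
##   Expressed                           : TPM >= 1
##   InPlausibleCellTypeSpecificLocus    : locus not lipid-associated, and the gene has a ranked
##                                         distance to a cell-type peak with a variant, a cell-type-specific
##                                         ABC rank, or a coding variant
##   InPlausibleLocus                    : same idea, using the non-cell-type-specific ranks
##   TopCandidate                        : within the top 2 by distance or ABC rank, or contains a coding variant
##   TopCandidateInCellTypeSpecificLocus : TopCandidate restricted to plausible cell-type-specific loci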
GWAS.df <- GWAS.df %>%
    mutate(
        Expressed = (TPM >= 1),
        InPlausibleCellTypeSpecificLocus = ( !LipidLevelsAssociated & (!is.na(get(paste0("RankOfDistanceTo", celltype, "PeakWithVariant"))) | !is.na(get(paste0("MaxABC.Rank.", celltype, "Only"))) | !is.na(CodingVariantGene)) ), ## Also require that locus is not associated with lipid levels
        InPlausibleCellTypeSpecificLocus_includeLipid = (!is.na(get(paste0("RankOfDistanceTo", celltype, "PeakWithVariant"))) | !is.na(get(paste0("MaxABC.Rank.", celltype, "Only"))) | !is.na(CodingVariantGene)),
        InPlausibleLocus = ( !LipidLevelsAssociated & (!is.na(rank_SNP_to_TSS) | !is.na(MaxABC.Rank) | !is.na(CodingVariantGene)) ),
        InPlausibleLocus_includeLipid = !is.na(rank_SNP_to_TSS) | !is.na(MaxABC.Rank) | !is.na(CodingVariantGene),
        TopCandidate = (get(paste0("RankOfDistanceTo", celltype, "PeakWithVariant")) <= 2 | get(paste0("MaxABC.Rank.", celltype, "Only")) <= 2 | GeneContainsCodingVariant) %>% replace_na(FALSE), ## use the cell-type-specific ABC rank column, matching the lines above
        TopCandidateInCellTypeSpecificLocus = TopCandidate & InPlausibleCellTypeSpecificLocus
        ## !!(paste0("ExpressedTopCandidateIn", celltype, "Locus")) := ((get(paste0("RankOfDistanceTo", celltype, "PeakWithVariant")) <= 2 | get(paste0("MaxABC.Rank.", celltype, "Only")) <= 2 | GeneContainsCodingVariant) %>% replace_na(FALSE)) & get(paste0("InPlausible", celltype, "Locus")) & Expressed,
        ## TopCandidateInCellTypeSpecificLocus_includeLipid = ((get(paste0("RankOfDistanceTo", celltype, "PeakWithVariant")) <= 2 | get(paste0("MaxABC.Rank.", celltype, "Only")) <= 2 | GeneContainsCodingVariant) %>% replace_na(FALSE)) & get(paste0("InPlausible", celltype, "Locus_includeLipid"))
        ## TopCandidate = ((rank_SNP_to_TSS <= 2 | MaxABC.Rank <= 2 | GeneContainsCodingVariant) %>% replace_na(FALSE)) & InPlausibleLocus,
        ## TopCandidate_includeLipid = ((rank_SNP_to_TSS <= 2 | MaxABC.Rank <= 2 | GeneContainsCodingVariant) %>% replace_na(FALSE)) & InPlausibleLocus_includeLipid
    ) %>%
    mutate(perturbed_gene = (gene %in% perturbed_gene_ary)) %>%
    as.data.frame()



##########################################################################################
## quick statistics on GWAS.df
## Count of V2G linked genes
cat("Number of genes with V2G links: ",
    GWAS.df %>% filter(TopCandidateInCellTypeSpecificLocus) %>% pull(gene) %>% unique %>% length,
    "\n")

## Count of credible sets that count as "plausibleCellTypeSpecificLocus":
cat("Number of credible sets that count as 'in a plausible CellTypeSpecific locus': ", 
    GWAS.df %>% filter(InPlausibleCellTypeSpecificLocus) %>% pull(CredibleSet) %>% unique() %>% length(), 
    " out of ",
    GWAS.df %>% pull(CredibleSet) %>% unique() %>% length(),
    "\n")


## Count of credible sets that count as "plausibleLocus":
cat("Number of credible sets that count as 'in a plausible locus': ", 
    GWAS.df %>% filter(InPlausibleLocus) %>% pull(CredibleSet) %>% unique() %>% length(), 
    " out of ",
    GWAS.df %>% pull(CredibleSet) %>% unique() %>% length(),
    "\n")



## merge CAD.GWAS.df into narrow.df.TPM to get all expressed gene's information
remove_na <- function(x) x[!is.na(x)]
overlapping.colnames <- intersect(narrow.df.TPM %>% colnames, GWAS.df %>% colnames)
narrow.df.TPM <- merge(narrow.df.TPM, # %>%
                       GWAS.df %>%
                       select(-one_of((overlapping.colnames[!(overlapping.colnames %in% c("gene"))]))),
                       by=c("gene"), all=T
                       ) %>%
    rowwise() %>%
    mutate(ProgramsLinkedToGene= c(strsplit(ProgramsRegulatedByThisGene,",")[[1]], strsplit(ProgramsInWhichGeneIsInTop300ZScoreSpecificGenes, ",")[[1]]) %>% unique() %>% paste(collapse=',')) %>%
    ## mutate(perturbed_gene = (gene %in% perturbed_gene_ary) | (OriginalGene %in% perturbed_gene_ary)) %>%
    ## ## mutate(ProgramsLinkedToGene= c(strsplit(ProgramsRegulatedByThisGene,",")[[1]], strsplit(ProgramsInWhichGeneIsInTop300ZScoreSpecificGenes, ",")[[1]]) %>% unique() %>% paste(collapse=',')) %>%
    as.data.frame ## because CAD.GWAS.df has duplicated genes, TopicsInWhichGeneIsInTop100ZScoreSpecificGenes sum up to > 100
narrow.df.TPM[is.na(narrow.df.TPM)] <- FALSE

## head(narrow.df.TPM)
## print(colnames(narrow.df.TPM))
## GenesIncNMF <- c(theta.zscore.rank.df$Gene %>% unique)
narrow.df.TPM <- narrow.df.TPM %>%
    ## rowwise %>%
    mutate(AllGenesIncNMFInput = IncNMFAnalysis) %>%
    as.data.frame

## write GWAS table to file
message("write GWAS.df to file")
write.table(GWAS.df, paste0(OUTDIR, opt$trait.name, ".GWAS.df.txt"), sep="\t", quote=F, row.names=F)
write.table(narrow.df.TPM, paste0(OUTDIR, opt$trait.name, ".narrow.df.TPM.txt"), sep="\t", quote=F, row.names=F)

# narrow.df.TPM <- read.delim(paste0(OUTDIR, opt$trait.name, ".narrow.df.TPM.txt"), stringsAsFactors=F)



########################################################################################
## Run statistical Tests
## Create subset dfs based on background conditions

## read table that specifies which gene set is for alternative hypothesis and which set of topics to use
message("reading statistical test list")
statistical.test.list.df <- read.delim(paste0("workflow/scripts/program_prioritization/All_GWAS_traits.statistical.test.list_", opt$regulator.analysis.type, ".txt"), stringsAsFactors=F)
## statistical.test.list.df <- read.delim(paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230119_apply_V2G2P_to_more_GWAS_traits/All_GWAS_traits.statistical.test.list.txt"), stringsAsFactors=F)

message("double check statistical test criteria")
statistical.test.list.df[1:2,] ## double check test criteria

num.tests <- opt$num.tests ## actual number of tests requested by the user
statistical.test.list.df <- statistical.test.list.df[1:num.tests,]

## num.tests <- nrow(statistical.test.list.df) ## specify number of statistical tests to conduct
df.list.to.test <- vector("list", num.tests) ## initiate background data storage list
names(df.list.to.test) <- statistical.test.list.df$test.name

## remove batch topics from test
if(grepl("2kG.library", SAMPLE)) {
    ## load topic summary
    TopicSummary.df <- read.delim(paste0(DATADIR, "/data/210730_cNMF_topic_model_analysis.xlsx - TopicCatalogNEW.tsv"), stringsAsFactors=F)

    ## batch effect topic
    batch.topics <- TopicSummary.df %>%
        filter(ProgramCategoryLabel == "Batch") %>%
        pull(ProgramID) %>%
        unique
} else {
    message("reading batch topics")
    filename = paste0(opt$outdirsample, "/batch.topics.txt")
    empty = file.size(filename) == 0L
    if(empty) {
        message(paste0("There is no batch topics in K = ", k))
        batch.topics = character(0)
    } else {
        batch.topics <- read.delim(filename, stringsAsFactors=F, header=F) %>%
            mutate(ProgramID = paste0("K", k, "_", gsub("topic_", "", V1))) %>%
            pull(ProgramID) %>%
            unique
    }
}

## topics_to_test <- setdiff(c(1:k), batch.topics %>% gsub(paste0("K", k, "_"), "", .))

## do not remove batch topic from the test
topics_to_test = c(1:k)




all.test.results.background.list <- vector("list", num.tests)
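## For each background definition (one row of statistical.test.list.df) and each program/topic,
## build a 2x2 contingency table of {CandidateGene, NotCandidateGene} x {LinkedToTopic, NotLinkedToTopic},
## compute the enrichment ratio, a one-sided Fisher's exact test, and a binomial test against the
## background linkage rate; p-values are then adjusted for multiple testing across programs within
## each background.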
for(i in 1:num.tests) {
    background.name <- names(df.list.to.test)[i] ## get the name of the background filter
    subset.df <- eval(parse(text = statistical.test.list.df$background.subset.command[i])) ## evaluate this test's background-filter expression

    all.test.results.list <- vector("list", k) ## initialize variable
    ## loop over every topic t
    for(t in topics_to_test) {
        topic <- paste0("K", k, "_", t)

        ## new scratch for test that only looks at 'expressed genes in a topic'
        genes.to.test <- statistical.test.list.df$genes.to.test[i]
        topics.to.test <- statistical.test.list.df$topics.to.test[i]

        list.to.test <- subset.df %>%
            rowwise %>%
            mutate(hypothesis.gene = ifelse(eval(parse(text = genes.to.test)), 1, 0),
                   LinkedToTopic.gene = ifelse(topic %in% (get(topics.to.test) %>% strsplit("\\|") %>% unlist %>% as.character), 1, 0)) %>% ## debug
            ## mutate(gene = paste0(gene, "_", Gene, "_", OriginalGene)) %>%
            select(gene, hypothesis.gene, LinkedToTopic.gene) %>%
            unique %>%
            group_by(gene) %>%
            summarize(hypothesis.gene = ifelse(hypothesis.gene %>% sum > 0, "CandidateGene", "NotCandidateGene"),
                      LinkedToTopic.gene = ifelse(LinkedToTopic.gene %>% sum > 0, "LinkedToTopic", "NotLinkedToTopic")) %>%
            as.data.frame


        table.to.test <- table(list.to.test$hypothesis.gene, list.to.test$LinkedToTopic.gene) ## count frequencies
        if(ncol(table.to.test) == 1) {
            columns.needed <- c("LinkedToTopic", "NotLinkedToTopic")
            column.to.add <- columns.needed[!(columns.needed %in% colnames(table.to.test))]
            table.to.add <- matrix(c(0,0), nrow=2, dimnames=list(rownames(table.to.test),c(column.to.add)))
            if(column.to.add == "NotLinkedToTopic") table.to.test <- cbind(table.to.test, table.to.add) else table.to.test <- cbind(table.to.add, table.to.test)
        } else if (nrow(table.to.test) == 1) {
            rows.needed <- c("CandidateGene", "NotCandidateGene")
            row.to.add <- rows.needed[!(rows.needed %in% rownames(table.to.test))]
            table.to.add <- matrix(c(0,0), ncol=2, dimnames=list(c(row.to.add), colnames(table.to.test)))
            if(row.to.add == "NotCandidateGene") table.to.test <- rbind(table.to.test, table.to.add) else table.to.test <- rbind(table.to.add, table.to.test)
        }


        flattened.table.to.store <- table.to.test %>%
            matrix(nrow=1, ncol=4) %>% ## flatten the table
            as.data.frame %>%
            `colnames<-`(c("LinkedToTopic_CandidateGene",
                           "LinkedToTopic_NotCandidateGene",
                           "NotLinkedToTopic_CandidateGene",
                           "NotLinkedToTopic_NotCandidateGene")) %>% ## assign the table values to the corresponding column names
            mutate(enrichment = ( LinkedToTopic_CandidateGene / (LinkedToTopic_CandidateGene + NotLinkedToTopic_CandidateGene) ) / ( LinkedToTopic_NotCandidateGene / ( LinkedToTopic_NotCandidateGene + NotLinkedToTopic_NotCandidateGene) ), ## calculate enrichment
                   enrichment.log2fc = log2(enrichment), ## log2fc of the enrichment
                   Gene_LinkedToTopic_CandidateGene = list.to.test %>% ## list the genes that fall into each of the four categories in the contingency table
                       subset(hypothesis.gene == "CandidateGene" & LinkedToTopic.gene == "LinkedToTopic") %>%
                       pull(gene) %>%
                       paste0(collapse=","),
                   Gene_LinkedToTopic_NotCandidateGene = list.to.test %>%
                       subset(hypothesis.gene == "NotCandidateGene" & LinkedToTopic.gene == "LinkedToTopic") %>%
                       pull(gene) %>%
                       paste0(collapse=","),
                   Gene_NotLinkedToTopic_CandidateGene = list.to.test %>%
                       subset(hypothesis.gene == "CandidateGene" & LinkedToTopic.gene == "NotLinkedToTopic") %>%
                       pull(gene) %>%
                       paste0(collapse=",")#,
                   ## Gene_NotLinkedToTopic_NotCandidateGene = list.to.test %>%
                   ##     subset(hypothesis.gene == "NotCandidateGene" & LinkedToTopic.gene == "NotLinkedToTopic") %>%
                   ##     pull(gene) %>%
                   ##     paste0(collapse=",")
                   )

        ## null distribution probability
        p.binomial.null <- (list.to.test %>% subset(LinkedToTopic.gene == "LinkedToTopic") %>% nrow) / (list.to.test %>% nrow)
        p.binomial.test <- binom.test(list.to.test %>% subset(hypothesis.gene == "CandidateGene" & LinkedToTopic.gene == "LinkedToTopic") %>% nrow, ## number of candidate genes that link to the topic
                                      list.to.test %>% subset(hypothesis.gene == "CandidateGene") %>% nrow, ## number of candidate genes
                                      p = p.binomial.null)$p.value ## binomial test
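        ## i.e. ask whether candidate genes are linked to this program more often than
        ## expected from the background rate p.binomial.null (the fraction of all genes
        ## in this background that are linked to the program)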

        if (nrow(table.to.test) == 2 && ncol(table.to.test) == 2) {
            fisher.p.val <- fisher.test(table.to.test, alternative="greater")$p.value
        } else {
            fisher.p.val <- NA ## contingency table is degenerate, so skip Fisher's exact test
        }
        all.test.results.list[[t]] <- data.frame(Topic = topic,
                                                 fisher.p.value = fisher.p.val,
                                                 binomial.p.value = p.binomial.test) %>%
            cbind(flattened.table.to.store) ## put all results for one topic in one line to store
    }
    all.test.results.background.list[[i]] <- do.call(rbind, all.test.results.list) %>% mutate(fisher.p.adjust = p.adjust(fisher.p.value), binomial.p.adjust = p.adjust(binomial.p.value), background = background.name, .after="binomial.p.value")
}
all.test.results <- do.call(rbind, all.test.results.background.list) %>%
    as.data.frame %>%
    ## blank out gene lists longer than 300 / 150 genes (presumably to keep the output table compact)
    mutate(Gene_NotLinkedToTopic_CandidateGene = ifelse(NotLinkedToTopic_CandidateGene > 300, "", Gene_NotLinkedToTopic_CandidateGene),
           Gene_LinkedToTopic_CandidateGene = ifelse(LinkedToTopic_CandidateGene > 150, "", Gene_LinkedToTopic_CandidateGene))

all.test.results.nonbatch <- all.test.results %>%
    subset(Topic %in% paste0("K", k, "_", setdiff(c(1:k), batch.topics %>% gsub(paste0("K", k, "_"), "", .)))) %>%
    group_by(background) %>%
    mutate(fisher.p.adjust = p.adjust(fisher.p.value, method="fdr"),
           binomial.p.adjust = p.adjust(binomial.p.value, method="fdr"),
           .after="binomial.p.value") %>%
    as.data.frame

## output statistical test results
write.table(all.test.results, file=paste0(OUTDIR, opt$trait.name, ".fisher.exact.test.to.prioritize.topics.txt"), sep="\t", quote=F, row.names=F)
write.table(all.test.results.nonbatch, file=paste0(OUTDIR, opt$trait.name, ".fisher.exact.test.to.prioritize.topics.nonbatch.txt"), sep="\t", quote=F, row.names=F)

head(all.test.results)

## ## reload data
## all.test.results.nonbatch <- read.delim(paste0(OUTDIR, opt$trait.name, ".fisher.exact.test.to.prioritize.topics.nonbatch.txt"), stringsAsFactors=F)



regulator.programGene.test.pairs <- data.frame(regulator = statistical.test.list.df$test.name[seq(2, num.tests, by=2)],
                                               programGene = statistical.test.list.df$test.name[seq(1, num.tests-1, by=2)],
                                               name = statistical.test.list.df[1:num.tests,] %>%
                                                   separate(col="test.name", sep="_geneSet.", into = c("regulator.or.program.gene", "toProcess")) %>%
                                                   separate(col="toProcess", sep="_", into=c("test.handle", "background")) %>% pull(test.handle) %>% unique,
                                               command = statistical.test.list.df$genes.to.test[seq(1, num.tests-1, by=2)],
                                               regulator.selection.command = statistical.test.list.df$background.subset.command[seq(2, num.tests, by=2)])
regulator.programGene.combined.pval.nonbatch.list <- vector("list", nrow(regulator.programGene.test.pairs))

for(i in 1:nrow(regulator.programGene.test.pairs)) {
    regulator.results.nonbatch <- all.test.results.nonbatch %>%
        subset(background == regulator.programGene.test.pairs$regulator[i]) %>%
        `colnames<-`(paste0("Regulator_", colnames(.))) %>%
        dplyr::rename("ProgramID" = "Regulator_Topic")
    programGene.results.nonbatch <- all.test.results.nonbatch %>% subset(background == regulator.programGene.test.pairs$programGene[i]) %>%
        `colnames<-`(paste0("ProgramGene_", colnames(.))) %>%
        dplyr::rename("ProgramID" = "ProgramGene_Topic")

    ## calculate expected number of genes ## adapted from 221115_compute_enrichment.R
    mean.Regulators = regulator.results.nonbatch %>%
        mutate(nRegulators=Regulator_LinkedToTopic_CandidateGene + Regulator_LinkedToTopic_NotCandidateGene) %>%
        pull(nRegulators) %>%
        mean
    ## nPerturbedGeneInCADGWASLoci = GWAS.df %>% filter(perturbed_gene == 1) %>% pull(gene) %>% unique %>% length
    ## nPerturbedGeneCandidate = GWAS.df %>% filter(perturbed_gene == 1 & eval(parse(text = regulator.programGene.test.pairs$command[i] %>% as.character))) %>% pull(gene) %>% unique %>% length
    nPerturbedGene = eval(parse(text = regulator.programGene.test.pairs$regulator.selection.command[i] %>% as.character)) %>% pull(gene) %>% unique %>% length ## old: narrow.df.TPM %>% filter(perturbed_gene == 1) %>% pull(gene) %>% unique %>% length
    nPerturbedGeneCandidate = eval(parse(text = regulator.programGene.test.pairs$regulator.selection.command[i] %>% as.character)) %>% filter(eval(parse(text = regulator.programGene.test.pairs$command[i] %>% as.character))) %>% pull(gene) %>% unique %>% length ## old: narrow.df.TPM %>% filter(perturbed_gene == 1 & eval(parse(text = regulator.programGene.test.pairs$command[i] %>% as.character))) %>% pull(gene) %>% unique %>% length
    expected.Regulators = mean.Regulators * nPerturbedGeneCandidate / nPerturbedGene
    nProgramGeneBackground = narrow.df.TPM %>% filter(IncNMFAnalysis == 1) %>% pull(gene) %>% unique %>% length
    nProgramGeneCandidate = narrow.df.TPM %>% subset(IncNMFAnalysis == 1 & eval(parse(text = regulator.programGene.test.pairs$command[i] %>% as.character))) %>% pull(gene) %>% unique %>% length
    expected.ProgramGenes = 300 * nProgramGeneCandidate / nProgramGeneBackground
    expected.nLinkedGenes = expected.Regulators + expected.ProgramGenes
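    ## expected.nLinkedGenes = expected number of candidate genes linked to a program by chance:
    ## the mean number of regulators per program scaled by the candidate fraction among perturbed
    ## genes, plus the expected number of candidates among the top 300 program genes given the
    ## candidate fraction in the cNMF gene background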

    regulator.programGene.combined.pval.nonbatch.list[[i]] <- merge(regulator.results.nonbatch,
                                                                    programGene.results.nonbatch,
                                                                    by="ProgramID") %>%
        mutate(LinkedGenes_ChiSquareTestStatistic = -2 * (log(Regulator_fisher.p.value) + log(ProgramGene_fisher.p.value)),
               LinkedGenes_ChiSquare.p.value = pchisq(LinkedGenes_ChiSquareTestStatistic, df = 4, lower.tail=F),
               LinkedGenes_fisher.p.value = Regulator_fisher.p.value * ProgramGene_fisher.p.value,
               LinkedGenes_binomial.p.value = Regulator_binomial.p.value * ProgramGene_binomial.p.value,
               LinkedGenes_fisher.p.adjust = p.adjust(LinkedGenes_fisher.p.value, method="fdr"),
               LinkedGenes_binomial.p.adjust = p.adjust(LinkedGenes_binomial.p.value, method="fdr"),
               LinkedGenes_ChiSquare.p.adjust = p.adjust(LinkedGenes_ChiSquare.p.value, method="fdr"),
               LinkedGenes_fisherNegLog10FDR = -log10(LinkedGenes_fisher.p.adjust),
               LinkedGenes_binomialNegLog10FDR = -log10(LinkedGenes_binomial.p.adjust),
               LinkedGenes_ChiSquareNegLog10FDR = -log10(LinkedGenes_ChiSquare.p.adjust),
               ## LinkedGenes_LinkedToTopic_CandidateGene = Regulator_LinkedToTopic_CandidateGene + ProgramGene_LinkedToTopic_CandidateGene, 
               .after = "ProgramID") %>%
        ## adapted from 211115_compute_enrichment.R
        rowwise %>%
        mutate(LinkedGenes_LinkedToTopic_CandidateGene = c(Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% length,
               LinkedGenes_Gene_LinkedToTopic_CandidateGene = c(Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% paste0(collapse=","),
               InGeneSet_Regulator_and_ProgramGene = intersect(Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% length,
               Gene_InGeneSet_Regulator_and_ProgramGene = intersect(Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% paste0(collapse=","),
               Unique_Regulator_InGeneSet_CandidateGene = setdiff(Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% length,
               Unique_ProgramGene_InGeneSet_CandidateGene = setdiff(ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% length,
               Unique_Regulator_Gene_InGeneSet_CandidateGene = setdiff(Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% paste0(collapse=","),
               Unique_ProgramGene_Gene_InGeneSet_CandidateGene = setdiff(ProgramGene_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character, Regulator_Gene_LinkedToTopic_CandidateGene %>% strsplit(split=",") %>% unlist %>% as.matrix %>% as.character) %>% unique %>% paste0(collapse=","),
               .after = "LinkedGenes_binomialNegLog10FDR") %>%
        mutate(.after = "LinkedGenes_LinkedToTopic_CandidateGene",
               LinkedGenes_Expected = expected.nLinkedGenes,
               LinkedGenes_Enrichment = LinkedGenes_LinkedToTopic_CandidateGene / LinkedGenes_Expected) %>%
    mutate(.before = "Regulator_enrichment",
           Regulator_Expected = expected.Regulators) %>%
    mutate(.before = "ProgramGene_enrichment",
           ProgramGene_Expected = expected.ProgramGenes) %>%
        as.data.frame %>%
## end of adapted code ##
        arrange(LinkedGenes_binomial.p.adjust) %>% mutate(test.name = regulator.programGene.test.pairs$name[i], .after = "ProgramID")
}

regulator.programGene.combined.pval.nonbatch <- do.call(rbind, regulator.programGene.combined.pval.nonbatch.list) %>%
    arrange(LinkedGenes_ChiSquare.p.adjust, desc(LinkedGenes_LinkedToTopic_CandidateGene))

write.table(regulator.programGene.combined.pval.nonbatch, file=paste0(OUTDIR, "/", opt$trait.name, ".program_prioritization.txt"), sep="\t", quote=F, row.names=F)

## create a table with V2G2P prioritized genes for Program Genes
fdr.thr <- 0.05
prioritized.program.genes.df <- regulator.programGene.combined.pval.nonbatch %>%
    subset(ProgramGene_fisher.p.adjust < fdr.thr) %>%
    select(ProgramID, ProgramGene_Gene_LinkedToTopic_CandidateGene) %>%
    separate_rows(ProgramGene_Gene_LinkedToTopic_CandidateGene, sep=",", convert=F) %>%
    `colnames<-`(c("ProgramID", "V2G2P Program Gene")) %>%
    as.data.frame




####################################################################################################
## plots ## adapted from TeloHAEC_Perturb-seq_2kG/221115_stimulation_condition_V2G/221115_compute_enrichment.R
source('/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/figures/helper_scripts/plot_helper_functions.R')
## reload data
regulator.programGene.combined.pval.nonbatch <- read.delim(paste0(OUTDIR, "/", opt$trait.name, ".program_prioritization.txt"), stringsAsFactors=F)
## regulator.programGene.combined.pval.nonbatch <- read.delim(paste0(OUTDIR, opt$trait.name, ".combinedRegulatorProgramGeneEnrichmentTest.to.prioritize.topics.nonbatch.txt"), stringsAsFactors=F)
theme.here <- theme(legend.key.size=unit(0.1, unit="in"),
                    legend.text = element_text(size=5),
                    legend.position = "bottom",
                    legend.direction="vertical",
                    legend.margin = margin(1,1,1,1,unit="pt"),
                    panel.spacing = unit(1, units="pt"))
mytheme <- theme_classic() + theme(axis.text = element_text(size = 5),
                                   axis.title = element_text(size = 6),
                                   plot.title = element_text(hjust = 0.5, face = "bold", size=6),
                                   axis.line = element_line(color = "black", size = 0.25),
                                   axis.ticks = element_line(color = "black", size = 0.25))


if(opt$perturbSeq) {
    category_ary <- c("LinkedGenes", "ProgramGene", "Regulator")
} else {
    category_ary <- c("ProgramGene")
}

fdr.thr <- 0.05
for(category in category_ary) {
    expectedNumEnrichedGenes <- regulator.programGene.combined.pval.nonbatch %>% pull(get(paste0(category, "_Expected"))) %>% unique
    ## significance stars from this category's Fisher FDR; for the LinkedGenes category,
    ## "LinkedGenes_ChiSquare.p.adjust" could be used here instead to combine regulator and
    ## program-gene p-values by Fisher's method
    fdr.column <- paste0(category, "_fisher.p.adjust")
    toplot <- regulator.programGene.combined.pval.nonbatch %>%
        mutate(significant = ifelse(get(fdr.column) < fdr.thr,
                             ifelse(get(fdr.column) < (fdr.thr / 10),
                             ifelse(get(fdr.column) < (fdr.thr / 100), "***", "**"),
                             "*"),
                             "")) %>%
        ## add.df.Program.name %>% ## todo (generalize)
        arrange(desc(get(paste0(category, "_LinkedToTopic_CandidateGene"))))
    ## numSignificantPrograms <- nrow(toplot)

    if(grepl("2kG.library", SAMPLE)) {
        toplot <- toplot %>% add.df.Program.name
    } else {
        toplot <- toplot %>% mutate(truncatedLabel = ProgramID)
    }

    if(category == "LinkedGenes") {
        toplot <- toplot %>%
            arrange(desc(LinkedGenes_LinkedToTopic_CandidateGene),
                    desc(InGeneSet_Regulator_and_ProgramGene),
                    desc(Unique_ProgramGene_InGeneSet_CandidateGene),
                    desc(Unique_Regulator_Gene_InGeneSet_CandidateGene))
    }

    label.order <- toplot$truncatedLabel %>% rev
    toplot <- toplot %>%
        mutate(truncatedLabel = factor(truncatedLabel, levels=label.order))

    if(category == "LinkedGenes") {
        maxNumGenes <- toplot$LinkedGenes_LinkedToTopic_CandidateGene %>% max
        toplot.here <- toplot %>%
            select(truncatedLabel, InGeneSet_Regulator_and_ProgramGene, Unique_Regulator_InGeneSet_CandidateGene, Unique_ProgramGene_InGeneSet_CandidateGene, significant) %>%
            melt(id.vars=c("truncatedLabel", "significant"), value.name="numGenes", variable.name="LinkType") %>%
            mutate(LinkTypeText = ifelse(grepl("_and_", LinkType), "Regulator and Co-regulated Genes",
                                  ifelse(grepl("Regulator", LinkType), "Regulators", "Co-regulated Genes")),
                   LinkTypeText = factor(LinkTypeText, levels=c("Co-regulated Genes", "Regulators", "Regulator and Co-regulated Genes") %>% rev))

        p <- toplot.here %>% ggplot(aes(x=truncatedLabel, y=numGenes, fill=LinkTypeText)) + geom_col() + coord_flip() + mytheme +
            geom_text(aes(label=significant, y=maxNumGenes*1.1), nudge_x=-0.5, size=3, color='gray') +
            scale_fill_manual(name = "", values = c("gray30", "#38b4f7", "#0141a8") %>% rev) +
            xlab("Programs") + ylab("# Genes Linked to Program") +
            ggtitle(paste0(opt$trait.name, " GWAS trait")) +
            theme.here +
            geom_hline(yintercept=expectedNumEnrichedGenes, linetype="dashed", color="gray")

    } else {
        toplot.here <- toplot %>%
            select(truncatedLabel, paste0(category, "_LinkedToTopic_CandidateGene"), significant)

        p <- toplot.here %>% ggplot(aes(x=truncatedLabel, y=get(paste0(category, "_LinkedToTopic_CandidateGene")))) + geom_col(fill="gray30") + coord_flip() + mytheme +
            geom_text(aes(label=significant, y=max(get(paste0(category, "_LinkedToTopic_CandidateGene")))*1.1), nudge_x=-0.5, size=3, color='gray') +
            xlab("Programs") + ylab("# Genes Linked to Program") +
            ggtitle(paste0(opt$trait.name, " GWAS trait")) +
            theme.here +
            geom_hline(yintercept=expectedNumEnrichedGenes, linetype="dashed", color="gray")
    }

    ## if(opt$perturbSeq) p <- p + scale_fill_manual(name = "", values = c("gray30", "#38b4f7", "#0141a8") %>% rev) 
    filename <- paste0(FIGDIR, "/", SAMPLE, "_K", k, "_dt_", DENSITY.THRESHOLD, "_", opt$trait.name, "_", category, "_GeneCountBarPlot")
    pdf(paste0(filename, ".pdf"), width=2.5, height=3/60*k+1)
    print(p)
    dev.off()

    ## Enrichment Plot
    column.here <- ifelse(category == "LinkedGenes", "LinkedGenes_Enrichment", paste0(category, "_enrichment"))
    toplot <- toplot %>% arrange(desc(get(column.here)))
    label.order <- toplot$truncatedLabel %>% rev
    toplot <- toplot %>%
        mutate(truncatedLabel = factor(truncatedLabel, levels=label.order))

    p <- toplot %>% ggplot(aes(x=truncatedLabel, y=get(column.here))) + geom_col(fill="gray30") + coord_flip() + mytheme +
        geom_text(aes(label=significant, y=max(get(column.here))*1.1), nudge_x=-0.5, size=3, color='gray') +
        xlab("Program") + ylab(paste0(category, "Enrichment")) +
        ggtitle(paste0(opt$trait.name, " GWAS trait")) +
        theme.here

    filename <- paste0(FIGDIR, "/", SAMPLE, "_K", k, "_dt_", DENSITY.THRESHOLD, "_", opt$trait.name, "_", category, "_EnrichmentBarPlot")
    pdf(paste0(filename, ".pdf"), width=2.5, height=3/60*k+1)
    print(p)
    dev.off()

}


## output V2G2P genes
fdr.thr <- 0.05
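## for each gene category, write three files: the per-program table of significant genes
## (significant<key>.df.txt), the flat list of unique significant genes (significant<key>.txt),
## and a gene-to-programs lookup table (significant<key>.formatted.df.txt)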
for(key in category_ary) {
    significantGenes.df <- regulator.programGene.combined.pval.nonbatch %>%
        subset(get(paste0(key, "_fisher.p.adjust")) < fdr.thr) %>%
        select(ProgramID, paste0(key, "_Gene_LinkedToTopic_CandidateGene")) %>%
        as.data.frame
    write.table(significantGenes.df, paste0(OUTDIR, "/significant", key, ".df.txt"), sep="\t", quote=F, row.names=F)

    significantGenes <- significantGenes.df %>%
        pull(get(paste0(key, "_Gene_LinkedToTopic_CandidateGene"))) %>%
        paste0(collapse=",") %>%
        strsplit(split=",") %>%
        unlist %>%
        unique
    write.table(significantGenes, paste0(OUTDIR, "/significant", key, ".txt"), sep="\n", quote=F, row.names=F, col.names=F)

    significantGenes.formatted.df <- significantGenes.df %>%
        separate_rows(paste0(key, "_Gene_LinkedToTopic_CandidateGene")) %>%
        mutate(t = gsub(paste0("K", k, "_"), "", ProgramID)) %>%
        arrange(t) %>%
        group_by(get(paste0(key, "_Gene_LinkedToTopic_CandidateGene"))) %>%
        summarize(ProgramID = paste0(ProgramID, collapse=",")) %>%
        `colnames<-`(c(paste0(key,"_Gene_LinkedToTopic_CandidateGene"), "ProgramID")) %>%
        as.data.frame
    write.table(significantGenes.formatted.df, paste0(OUTDIR, "/significant", key, ".formatted.df.txt"), sep="\t", quote=F, row.names=F)
}

## add program gene and regulator membership to linked genes table
if(opt$perturbSeq) {
    linkedSignificantGenes.formatted.df <- read.delim(paste0(OUTDIR, "/significantLinkedGenes.formatted.df.txt"), stringsAsFactors=F) %>% `colnames<-`(c("Gene", "LinkedPrograms"))
    for(key in c("ProgramGene", "Regulator")) {
        assign(paste0(key, "SignificantGenes.formatted.df"), regulator.programGene.combined.pval.nonbatch %>%
                                                             subset(get(paste0("LinkedGenes_fisher.p.adjust")) < fdr.thr) %>%
                                                             select(ProgramID, paste0(key, "_Gene_LinkedToTopic_CandidateGene")) %>%
                                                             separate_rows(paste0(key, "_Gene_LinkedToTopic_CandidateGene")) %>%
                                                             mutate(t = gsub(paste0("K", k, "_"), "", ProgramID)) %>%
                                                             arrange(t) %>%
                                                             group_by(get(paste0(key, "_Gene_LinkedToTopic_CandidateGene"))) %>%
                                                             summarize(ProgramID = paste0(ProgramID, collapse=",")) %>%
                                                             `colnames<-`(c("Gene", key)) %>%
                                                             as.data.frame)
    }
    significantGenes.formatted.df <- merge(linkedSignificantGenes.formatted.df, ProgramGeneSignificantGenes.formatted.df, by="Gene", all.x=T) %>% merge(RegulatorSignificantGenes.formatted.df, by="Gene", all.x=T) %>% `colnames<-`(c("Gene", "LinkedProgram", "PartOfProgram", "RegulatorOfProgram"))
    write.table(significantGenes.formatted.df, paste0(OUTDIR, "/significantLinkedGenes.formatted.df.txt"), sep="\t", quote=F, row.names=F)
}
library(conflicted)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")


packages <- c("optparse","dplyr", "cowplot", "ggplot2", "gplots", "data.table", "reshape2",
              "tidyr", "grid", "gtable", "gridExtra","ggrepel","ramify",
              "ggpubr","gridExtra",
              "org.Hs.eg.db","limma","fgsea", "conflicted",
              "cluster","textshape","readxl", 
              "ggdist", "gghalves", "Seurat", "writexl",
              "stringi") 
xfun::pkg_attach(packages)
conflict_prefer("select","dplyr") # multiple packages have select(), prioritize dplyr
conflict_prefer("melt", "reshape2") 
conflict_prefer("slice", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("filter", "dplyr")
conflict_prefer("combine", "dplyr")
conflict_prefer("list", "base")
conflict_prefer("desc", "dplyr")
conflict_prefer("Position", "ggplot2")
conflict_prefer("first", "dplyr")


# setwd("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/figures/helper_scripts/")
# source("../../figures/helper_scripts/plot_helper_functions.R")
# source("../../figures/helper_scripts/load_data.R")
# source("../../figures/helper_scripts/load_edgeR.R")
# setwd("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/sample_cNMF_PoPS_input_table/")

source("./workflow/scripts/helper_functions.R")

option.list <- list(
    make_option("--figdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/figures/all_genes/2kG.library.ctrl.only/K25/threshold_0_2/", help="Figure directory"),
    make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211206_ctrl_only_snakemake/analysis/all_genes/2kG.library.ctrl.only/K25/threshold_0_2/", help="Output directory"),
    ## make_option("--olddatadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/data/", help="Input 10x data directory"),
    make_option("--datadir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/data/", help="Input 10x data directory"),
    make_option("--topic.model.result.dir", type="character", default="/scratch/groups/engreitz/Users/kangh/Perturb-seq_CAD/210625_snakemake_output/top3000VariableGenes_acrossK/2kG.library/", help="Topic model results directory"),
    make_option("--sampleName", type="character", default="2kG.library.ctrl.only", help="Name of Samples to be processed, separated by commas"),
    make_option("--K.val", type="numeric", default=25, help="K value to analyze"),
    ## make_option("--cell.count.thr", type="numeric", default=2, help="filter threshold for number of cells per guide (greater than the input number)"),
    ## make_option("--guide.count.thr", type="numeric", default=1, help="filter threshold for number of guide per perturbation (greater than the input number)"),
    make_option("--density.thr", type="character", default="0.2", help="concensus cluster threshold, 2 for no filtering"),
    make_option("--perturbSeq", type="logical", default=F, help="T for Perturb-seq experiment, F for no perturbation"),

    ## PoPS results
    make_option("--preds_with_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6_cNMF60.preds", help="PoPS Score with cNMF input"),
    make_option("--preds_without_cNMF", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/CAD_aug6.preds", help="PoPS Score with cNMF input"),


    ## summary plot parameters
    make_option("--test.type", type="character", default="per.guide.wilcoxon", help="Significance test to threshold perturbation results"),
    make_option("--adj.p.value.thr", type="numeric", default=0.1, help="adjusted p-value threshold"),
    make_option("--recompute", type="logical", default=F, help="T for recomputing statistical tests and F for not recompute")

)
opt <- parse_args(OptionParser(option_list=option.list))


## ## all genes directories (for sdev)
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/figures/2kG.library/all_genes/2kG.library/K60/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/"
## opt$K.val <- 60
## opt$sampleName <- "2kG.library"
## opt$perturbSeq <- TRUE

## ## K562 gwps 2k overdispersed genes
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/K80/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K80/threshold_0_2/"
## opt$K.val <- 80
## opt$sampleName <- "WeissmanK562gwps"
## opt$perturbSeq <- TRUE


## ## K562 gwps 2k overdispersed genes
## opt$figdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/figures/top2000VariableGenes/WeissmanK562gwps/K90/"
## opt$outdir <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K90/threshold_0_2/"
## opt$K.val <- 90
## opt$sampleName <- "WeissmanK562gwps"
## opt$perturbSeq <- TRUE


OUTDIRSAMPLE <- OUTDIR <- opt$outdir
DATADIR <- opt$datadir
SAMPLE=opt$sampleName
DENSITY.THRESHOLD <- gsub("\\.","_", opt$density.thr)
k <- opt$K.val
## OUTDIRSAMPLE=paste0(OUTDIR, SAMPLE, "/K",k,"/threshold_", DENSITY.THRESHOLD, "/")
SUBSCRIPT=paste0("k_", k,".dt_",DENSITY.THRESHOLD,".minGuidePerPtb_",opt$guide.count.thr,".minCellPerGuide_", opt$cell.count.thr)
SUBSCRIPT.SHORT=paste0("k_", k,".dt_",DENSITY.THRESHOLD)
if(!dir.exists(OUTDIR)) dir.create(OUTDIR)
fdr.thr <- 0.05
db <- ifelse(grepl("mouse", SAMPLE), "org.Mm.eg.db", "org.Hs.eg.db")
library(db, character.only = TRUE) ## load the appropriate annotation database

if(grepl("2kG.library", SAMPLE)) {

    ## map ids ## slow version
    x <- org.Hs.egENSEMBL 
    mapped_genes <- mappedkeys(x)
    entrez.to.ensembl <- as.list(x[mapped_genes]) # EntrezID to Ensembl
    ensembl.to.entrez <- as.list(org.Hs.egENSEMBL2EG) # Ensembl to EntrezID

    y <- org.Hs.egGENENAME
    y_mapped_genes <- mappedkeys(y)
    entrez.to.genename <- as.list(y[y_mapped_genes])
    genename.to.entrez <- as.list(org.Hs.egGENENAME)

    ## map between EntrezID and Gene Symbol
    z <- org.Hs.egSYMBOL
    z_mapped_genes <- mappedkeys(z)
    entrez.to.symbol <- as.list(z[z_mapped_genes])

    entrez.to.symbol <- as.list(org.Hs.egSYMBOL)
    symbol.to.entrez <- as.list(org.Hs.egSYMBOL2EG)
}

## load 10X reference
gtf.10X.df <- readRDS(paste0("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/data/refdata-cellranger-arc-GRCh38-2020-A_genes.gtf_df.RDS")) ## load 10X gtf file
Gene.ENSEMBL.10X.df <- gtf.10X.df %>% 
    mutate(ENSGID = gene_id,
           Gene10X = gene_name) %>%
    select(Gene10X, ENSGID) %>%
    unique
## gtf <- importGTF("/home/groups/engreitz/Software/cellranger-arc-1.0.1/refdata-cellranger-arc-GRCh38-2020-A/genes/genes.gtf")

## helper function to map between ENSGID and SYMBOL
map.ENSGID.SYMBOL <- function(df) {
    ## need column `Gene` to be present in df
    ## detect gene data type (e.g. ENSGID, Entrez Symbol)
    gene.type <- ifelse(nrow(df) == sum(as.numeric(grepl("^ENS", df$Gene))),
                        "ENSGID",
                        "Gene")
    if(gene.type == "ENSGID") {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "ENSEMBL", column = "SYMBOL")
        df <- df %>% mutate(ENSGID = Gene, Gene = mapped.genes)
    } else {
        mapped.genes <- mapIds(get(db), keys=df$Gene, keytype = "SYMBOL", column = "ENSEMBL")
        df <- df %>% mutate(ENSGID = mapped.genes)
    }
    df <- df %>%
        mutate(Gene = ifelse(is.na(Gene), "NA", Gene),
               ENSGID = ifelse(is.na(ENSGID), "NA", ENSGID)) 
    df <- df %>% merge(Gene.ENSEMBL.10X.df, by="ENSGID", all.x=T)
    df <- df %>% mutate(Gene = ifelse(Gene == "NA", Gene10X, Gene))
    notMatched.df <- df %>% subset(is.na(Gene10X))
    if(nrow(notMatched.df) > 0) {
        notMatched.index <- notMatched.df %>% rownames
        match.df <- merge(notMatched.df %>% select(-ENSGID, -Gene10X), Gene.ENSEMBL.10X.df, by.x="Gene", by.y="Gene10X") %>% mutate(Gene10X = Gene) %>% select(all_of(df %>% colnames))
        matched.index <- notMatched.df %>% subset(Gene %in% c(match.df$Gene %>% unique)) %>% rownames %>% as.numeric
        df <- rbind(df[-matched.index,], match.df)
    }
    df <- df %>% mutate(OriginalGene = Gene,
                        Gene = Gene10X)
    ## toconvert <- df %>% subset(Gene != Gene10X) %>% select(ENSGID, Gene, Gene10X)
    ## toconvert.index <- toconvert %>% 
    ## df %>% subset(Gene %in% Gene.ENSEMBL.10X.df$Gene10X & is.na(ENSGID))    
    return(df)
}

## ## load Perturb-seq analysis results
## OUTDIRSAMPLE <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210707_snakemake_maxParallel/analysis/2kG.library/all_genes/2kG.library/K60/threshold_0_2/"
## SUBSCRIPT <- "k_60.dt_0_2.minGuidePerPtb_1.minCellPerGuide_2"
file.name <- paste0(OUTDIRSAMPLE,"/cNMF_results.",SUBSCRIPT.SHORT,".RData")
print(file.name) 
if(file.exists((file.name))) { 
    print(paste0("loading ",file.name))
    load(file.name)
}

if(opt$perturbSeq) {
    MAST.file.name <- paste0(OUTDIRSAMPLE, "/", SAMPLE, "_MAST_DEtopics.txt")
    print(paste0("loading ", MAST.file.name))
    MAST.df <- read.delim(MAST.file.name, stringsAsFactors=F, check.names=F)
    if(grepl("2kG.library", SAMPLE)){
        MAST.df <- MAST.df %>%
            subset(zlm.model.name == "batch.correction")
        multiTarget.genes <- MAST.df %>%
            subset(grepl("multiTarget", perturbation)) %>% 
            pull(perturbation) %>%
            unique %>%
            gsub("_multiTarget", "", .)
        MAST.df <- MAST.df %>%
            mutate(ProgramID = gsub("topic_", "K60_", primerid)) %>%
            subset(!(perturbation %in% multiTarget.genes) & ## for multiTarget genes, swap in the set where cells have GeneA and GeneA-and-GeneB guides.
                   !grepl("-and-", perturbation)) %>% ## remove the set where cells only have GeneA-and-GeneB guides.
            mutate(perturbation = gsub("_multiTarget", "", perturbation)) %>%
            group_by(zlm.model.name) %>%
            mutate(fdr.across.ptb = p.adjust(`Pr(>Chisq)`, method="fdr")) %>%
            subset(zlm.model.name %in% c("batch.correction")) %>%
            select(-zlm.model.name) %>%
            group_by(ProgramID) %>%
            arrange(desc(coef)) %>%
            mutate(coef_rank = 1:n()) %>%
            as.data.frame
    }
}


## make theta.zscore.rank.df
## if(!("theta.zscore.rank.df" %in% ls())) {
theta.zscore.rank.df <- theta.zscore %>%
    as.data.frame %>%
    mutate(Gene = rownames(.)) %>%
    melt(id.vars="Gene", variable.name="ProgramID", value.name="zscore.specificity") %>%
    mutate(ProgramID = paste0("K", k, "_", ProgramID)) %>%
    group_by(ProgramID) %>%
    arrange(desc(zscore.specificity)) %>%
    mutate(zscore.specificity.rank = 1:n()) %>%
    ungroup %>%
    arrange(ProgramID, zscore.specificity.rank) %>%
    as.data.frame
## }
theta.zscore.rank.df <- theta.zscore.rank.df %>% map.ENSGID.SYMBOL
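## theta.zscore.rank.df is now a long-format table with one row per (gene, program) pair:
## the cNMF z-score specificity of the gene for that program and its within-program rank
## (rank 1 = most program-specific), with gene identifiers harmonized via map.ENSGID.SYMBOL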

if(grepl("2kG.library", SAMPLE)) {

    ## load PoPS outputs
    preds <- read.table(file=opt$preds_without_cNMF,header=T, stringsAsFactors=F, sep="\t")

    ## load PoPS processed gene x feature score
    ## load("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/coefs.marginals.feature.outer.prod.RDS") ## takes a while
    OUTDIRPOPS <- "/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/"
    ## PoPS_Score.coefs.all.outer <- read.delim(paste0(OUTDIRPOPS, "/coefs.all.feature.outer.prod.txt"), stringsAsFactors=F) ## also takes a long time to load

    ## top.genes.in.top.features <- read.delim(file=paste0(OUTDIRPOPS, "/top.genes.in.top.features.coefs.txt"), stringsAsFactors=F)
    ## load top features defining each gene's PoPS score
    PoPS_preds.importance.score.key <- read.delim(paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/210831_PoPS/211108_withoutBBJ/outputs/PoPS_preds.importance.score.key.columns.txt"), stringsAsFactors=F)

    ## load top topic defining each gene's PoPS score
    PREFIX <- paste0("CAD_aug6_cNMF", k)
    load(paste0(OUTDIRPOPS, "/", PREFIX, "_coefs.defining.top.topic.RDS"))

    coefs <- read.delim(paste0(OUTDIRPOPS, "/", PREFIX, ".coefs"), header=T, stringsAsFactors=F)
    coefs.df <- coefs[4:nrow(coefs),]
    coefs.cNMF.df <- coefs.df %>% subset(grepl("zscore",parameter))
    coefs.cNMF.names <- coefs.cNMF.df %>% pull(parameter)

    ## load EdgeR log2fcs and p-values
    log2fc.edgeR <- read.table(paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/EdgeR/ALL_log2fcs_dup4_s4n3.99x.txt"), header=T, stringsAsFactors=F)
    p.value.edgeR <- read.table(paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/data/EdgeR/ALL_Pvalues_dup4_s4n3.99x.txt"), header=T, stringsAsFactors=F)

    ## load known CAD gene set
    params.known.CAD.Gene.set <- read.delim("/oak/stanford/groups/engreitz/Users/kangh/ECPerturbSeq2021-Analysis/data/known_CAD_gene_set.txt", stringsAsFactors=F, header=F) %>% as.matrix %>% as.character
}


## Helper function to add EntrezID and ENSGID
add.EntrezID.ENSGID <- function(df) {
    return ( df %>%
             mutate(EntrezID = symbol.to.entrez[.$Gene %>% as.character] %>% sapply("[[",1) %>% as.character) %>%
            mutate(ENSGID = entrez.to.ensembl[.$EntrezID %>% as.character] %>% sapply("[[", 1) %>% as.character) 
            )
}

## ## MAST results
## MAST.df.4n.original <- read.delim(paste0("/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220217_MAST/2kG.library4n3.99x_MAST.txt"), stringsAsFactors=F, check.names=F)
## MAST.df <- MAST.df.4n.original %>%
##     subset(!grepl("multiTarget", perturbation)) %>%
##     group_by(zlm.model.name) %>%
##     mutate(fdr.across.ptb = p.adjust(`Pr(>Chisq)`, method="fdr")) %>%
##     subset(zlm.model.name %in% c("batch.correction")) %>%
##     select(-zlm.model.name) %>%
##     mutate(ProgramID = gsub("topic_", "K60_", primerid)) %>%
##     group_by(ProgramID) %>%
##     arrange(desc(coef)) %>%
##     mutate(coef_rank = 1:n()) %>%
##     as.data.frame
## sigPerturbationsProgram <- MAST.df %>%
##     subset(fdr.across.ptb < fdr.thr) %>%
##     as.data.frame

if(grepl("2kG.library", SAMPLE)) {
    ## curated gene name conversion list
    geneNameConversion <- read.delim(paste0(DATADIR, "/heterogenous_geneName_conversion_toMatchENSGID.txt"), stringsAsFactors=F) %>%
        separate(col=Conversion, into=c("fromGene", "toGene"), sep=" -> ", remove=F)
    geneNameConversionBackToPerturbseq <- geneNameConversion %>% filter(!KeepForGWASTable)

    ## helper function to convert gene names
    convert.gene.names.toMatchENSGID <- function(Gene) stri_replace_all_regex(Gene, pattern = geneNameConversion$fromGene, replace = geneNameConversion$toGene, vectorize=F)

    convert.gene.names.toMatchPerturbseq <- function(Gene) stri_replace_all_regex(Gene, pattern = geneNameConversionBackToPerturbseq$toGene, replace = geneNameConversionBackToPerturbseq$fromGene, vectorize=F)
}

## new gene name conversion method ## 220628
ptb10xNames.df <- read.delim(paste0(DATADIR, "/220627_add_Perturbation 10X names.txt"), stringsAsFactors=F, check.names=F)
perturbseq.gene.names.to10X <- function(Gene) {
    if(Gene %in% ptb10xNames.df$Symbol) {
        out <- stri_replace_all_regex(Gene, pattern = ptb10xNames.df$Symbol, replace = ptb10xNames.df$`Name used by CellRanger`, vectorize=F)
    } else {
        out <- Gene
    }
    return(out)
}
tenX.gene.names.toperturbseq <- function(Gene) stri_replace_all_regex(Gene, pattern = ptb10xNames.df$`Name used by CellRanger`, replace = ptb10xNames.df$Symbol, vectorize=F) 
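## usage sketch (hypothetical symbols; assumes they appear in the 220627 conversion table):
## perturbseq.gene.names.to10X(someSymbol)     ## Perturb-seq symbol -> name used by CellRanger
## tenX.gene.names.toperturbseq(some10XName)   ## inverse mapping back to the Perturb-seq symbol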

print("loaded all prerequisite data")


##########################################################################################
ptb.topic.thr <- 0.05 ## FDR threshold
## Column: ProgramsRegulatedByThisGene
if(opt$perturbSeq) {
    ProgramsRegulatedByThisGene.df <- MAST.df %>%
        mutate(OriginalGene = perturbation,
               Gene = perturbseq.gene.names.to10X(perturbation),
               t = gsub("topic_", "", primerid) %>% as.numeric,
               ProgramID = paste0("K", k, "_",  t)) %>%
        mutate(significant = fdr.across.ptb < ptb.topic.thr &
                   (coef > log(1.1) | coef < log(0.9)),
               ProgramID = ifelse(significant, ProgramID, "NA")) %>%
        ## subset(fdr.across.ptb < ptb.topic.thr) %>%
        ## mutate(Gene = convert.gene.names.toMatchENSGID(perturbation),         
        map.ENSGID.SYMBOL
    ## mutate(EntrezID = symbol.to.entrez[.$Gene %>% as.character] %>% sapply("[[",1) %>% as.character) %>%
    ## mutate(ENSGID = entrez.to.ensembl[.$EntrezID %>% as.character] %>% sapply("[[", 1) %>% as.character)
    ProgramsRegulatedByThisGene.tokeep <- ProgramsRegulatedByThisGene.df %>%
        arrange(t) %>%
        ## select(Gene, OriginalGene, ENSGID, EntrezID, ProgramID) %>%
        select(Gene, OriginalGene, Gene10X, ENSGID, ProgramID) %>%
        unique %>% 
        ## group_by(ENSGID, Gene, OriginalGene, EntrezID) %>%
        group_by(ENSGID, Gene, Gene10X, OriginalGene) %>%        
        summarize(ProgramsRegulatedByThisGene=paste0(ProgramID, collapse="|")) %>%
        mutate(ProgramsRegulatedByThisGene = gsub("NA[|]|[|]NA", "", ProgramsRegulatedByThisGene),
               ProgramsRegulatedByThisGene = gsub("[|]NA[|]", "|", ProgramsRegulatedByThisGene),
               perturbed_gene = TRUE) %>% ## clean up "NA"s
        ## mutate(Gene = convert.gene.names.toMatchENSGID(rownames(.))) %>%
        ## mutate(Gene = perturbseq.gene.names.to10X(Gene)) %>%
        as.data.frame
    ## if(grepl("2kG.library", SAMPLE)) {
    ##     ProgramsRegulatedByThisGene.df <- ProgramsRegulatedByThisGene.df %>% mutate(PreviousConversionGene = convert.gene.names.toMatchENSGID(perturbation))
    ##     ProgramsRegulatedByThisGene.tokeep <- ProgramsRegulatedByThisGene.df %>%
    ##         mutate(PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene)) %>%
    ##         as.data.frame
    ## } 
}

## ## ## ## Column: ProgramsInWhichGeneIsInTop100ZScoreSpecificGenes
## ## ## params.top.n.genes.in.topic <- 100
## if(grepl("2kG.library", SAMPLE)) {
##     ProgramsInWhichGeneIsExpressed.df <- theta.zscore.rank.df %>%
##         mutate(OriginalGene = Gene,
##                PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene),
##                Gene = perturbseq.gene.names.to10X(Gene)) %>%
##         ## add.EntrezID.ENSGID %>%
##         as.data.frame
## } else {
##     ProgramsInWhichGeneIsExpressed.df <- theta.zscore.rank.df %>%
##         ## mutate(OriginalGene = Gene,
##         ##        PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene),
##         ##        Gene = perturbseq.gene.names.to10X(Gene)) %>%
##         ## add.EntrezID.ENSGID %>%
##         as.data.frame
## }

IncNMFAnalysis.tokeep <- theta.zscore.rank.df %>%
    select(Gene, Gene10X, ENSGID) %>%
    unique %>%
    mutate(IncNMFAnalysis = TRUE) %>%
    as.data.frame

params.top.n.genes.in.topic.list <- c(100, 300, 500)
ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes.tokeep <- Reduce(
    function(x, y, ...) full_join(x, y, by = c("Gene", "Gene10X", "ENSGID"), ...), ## want to join all the data frames that has columns c("Gene", "ProgramsInWhichGeneIsInTopNZScoreSpecificGenes")
    out <- lapply(params.top.n.genes.in.topic.list, function (params.top.n.genes.in.topic) { # for each selected number of top genes we want to include
        column.name <- paste0("ProgramsInWhichGeneIsInTop", params.top.n.genes.in.topic, "ZScoreSpecificGenes") # store the column name in a variable
        ## out <- ProgramsInWhichGeneIsExpressed.df %>%
        ##     group_by(ProgramID) %>% 
        ## arrange(desc(Topic.zscore)) %>% # sort the z-score within each topic
        ## slice(1:params.top.n.genes.in.topic) %>% # select the top N genes
        out <- theta.zscore.rank.df %>%
            mutate(IsProgramGene = zscore.specificity.rank <= params.top.n.genes.in.topic,
                   ProgramID = ifelse(IsProgramGene, ProgramID, "NA")) %>%
            ## subset(zscore.specificity.rank <= params.top.n.genes.in.topic) %>%
            ## ungroup %>%
            group_by(Gene, Gene10X, ENSGID) %>% # for each gene
            summarize(!!column.name := paste0(ProgramID, collapse="|")) %>% # paste the topics each gene links to by ","
            mutate(!!column.name := gsub("NA[|]|[|]NA", "", get(column.name)),
                   !!column.name := gsub("[|]NA[|]", "|", get(column.name))) %>% ## clean up "NA"s
            as.data.frame
    }))
## if(grepl("2kG.library", SAMPLE)) {
##     ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes.tokeep <- ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes.tokeep %>%     
##         ## mutate(Gene = convert.gene.names.toMatchENSGID(Gene)) %>%
##         mutate(OriginalGene = Gene) %>%
##         mutate(PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene)) %>%
##         mutate(Gene = perturbseq.gene.names.to10X(Gene)) %>%
##         add.EntrezID.ENSGID %>%
##         as.data.frame
## }

## ## Column: PoPSEnrichedProgramsRegulatedByThisGene ## FDR < 0.05 and in PoPS prioritized features
## if(opt$perturbSeq & grepl("2kG.library", SAMPLE)){
## cNMF.PoPS.enriched.topics <- coefs.cNMF.names %>% gsub("zscore_K60_topic", "K60_", .) ## get the list of PoPS prioritized cNMF topics and reformat
## PoPSEnrichedProgramsRegulatedByThisGene.tokeep <- ProgramsRegulatedByThisGene.df %>%
##     select(Gene, OriginalGene, PreviousConversionGene, ENSGID, ProgramID) %>%
##     subset(ProgramID %in% cNMF.PoPS.enriched.topics) %>%
##     group_by(ENSGID, Gene, OriginalGene, PreviousConversionGene) %>%
##     summarize(PoPSEnrichedProgramsRegulatedByThisGene=paste0(ProgramID, collapse="|")) %>%
##     ## mutate(Gene = convert.gene.names.toMatchENSGID(Gene)) %>%
##     ## mutate(PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene)) %>%
##     ## mutate(Gene = perturbseq.gene.names.to10X(Gene)) %>%
##     as.data.frame
## }

## ## Column: PoPSEnrichedProgramsInWhichGeneIsInTop100ZScoreSpecificGenes
## cNMF.PoPS.enriched.topics <- coefs.cNMF.names %>% gsub("zscore_K60_topic", "K60_", .) ## get the list of PoPS prioritized cNMF topics and reformat
## params.top.n.genes.in.topic <- 300
## PoPSEnrichedProgramsInWhichGeneIsInTop300ZScoreSpecificGenes.tokeep <- ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes.tokeep %>% ## ProgramsInWhichGeneIsExpressed.df %>%
##     subset(ProgramID %in% cNMF.PoPS.enriched.topics) %>%
##     ## group_by(ProgramID) %>%
##     ## arrange(desc(Topic.zscore)) %>%
##     ## slice(1:params.top.n.genes.in.topic) %>%
##     ## ungroup %>%
##     subset(zscore.specificity.rank <= params.top.n.genes.in.topic) %>%
##     group_by(Gene, PreviousConversionGene, OriginalGene, ENSGID) %>%
##     summarize(PoPSEnrichedProgramsInWhichGeneIsInTop300ZScoreSpecificGenes = paste0(ProgramID, collapse="|")) %>%
##     ## mutate(Gene = convert.gene.names.toMatchENSGID(Gene)) %>%
##     ## mutate(PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene)) %>%
##     ## mutate(Gene = perturbseq.gene.names.to10X(Gene)) %>%
##     ## add.EntrezID.ENSGID %>%
##     as.data.frame


## ## PoPS.Score
## PoPS.Score.tokeep <- preds %>% select(ENSGID, PoPS_Score) %>%
##     mutate(EntrezID = ensembl.to.entrez[.$ENSGID %>% as.character] %>% sapply("[[",1) %>% as.character) %>%
##     mutate(Gene = entrez.to.symbol[.$EntrezID %>% as.character] %>% sapply("[[",1) %>% as.character,
##            ## Gene = convert.gene.names.toMatchENSGID(Gene)) %>%
##            PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene)) %>%
##     mutate(OriginalGene = Gene) %>%
##     mutate(Gene = perturbseq.gene.names.to10X(Gene)) %>%
##     subset(!(EntrezID == "NULL" & Gene == "NULL" & PreviousConversionGene == "NULL" & OriginalGene == "NULL")) %>%
##     as.data.frame


## ## ## PoPS.Rank [rank of the score among genes near this CredibleSet/GWAS signal]
## ## create later when we merge in CredibleSet/GWAS signal

## ## ## Top5ProgramsThatContributeToPoPSScore
## ## Top5ProgramsThatContributeToPoPSScore.tokeep <- coefs.defining.top.topic.df %>%
## ##     mutate(ProgramID = gsub("zscore_K60_topic", "K60_", topic),
## ##            Gene = convert.gene.names.toMatchENSGID(Gene)) %>%
## ##     group_by(Gene) %>%
## ##     arrange(desc(gene.feature_x_beta)) %>%
## ##     slice(1:5) %>%
## ##     summarize(Top5ProgramsThatContributeToPoPSScore = paste0(ProgramID, collapse="|")) %>%
## ##     add.EntrezID.ENSGID %>%
## ##     as.data.frame

## ## Top5FeaturesThatContributeToPoPSScore
## Top5FeaturesThatContributeToPoPSScore.tokeep <- PoPS_preds.importance.score.key %>%
##     group_by(Gene) %>%
##     arrange(desc(gene.feature_x_beta)) %>%
##     slice(1:5) %>%
##     mutate(to.display=paste0(pathway, ":", Long_Name)) %>%
##     summarize(Top5FeaturesThatContributeToPoPSScore = paste0(to.display, collapse="|")) %>%
##     ## mutate(Gene = convert.gene.names.toMatchENSGID(Gene)) %>%
##     mutate(OriginalGene = Gene) %>% 
##     mutate(PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene)) %>%
##     mutate(Gene = perturbseq.gene.names.to10X(Gene)) %>%
##     add.EntrezID.ENSGID %>%
##     as.data.frame
## }

## ## DoesThisGeneWhenPerturbedRegulateAKnownCADGene 
## if(opt$perturbSeq & grepl("2kG.library", SAMPLE)){
##     CAD.gene.regulator.log2fc <- log2fc.edgeR %>%
##         subset(grepl(paste0(params.known.CAD.Gene.set, ":ENSG") %>%
##                      paste0(collapse="|"), .$gene))
##     CAD.gene.regulator.p.value <- p.value.edgeR %>%
##         subset(grepl(paste0(params.known.CAD.Gene.set, ":ENSG") %>%
##                      paste0(collapse="|"), .$gene))
##     params.EdgeR.p.value.thr <- "0.05"
##     CAD.gene.regulator.df <- CAD.gene.regulator.p.value %>%
##         melt(id.vars="genes", value.name="p.value", variable.name="perturbation") %>%
##         subset(p.value < params.EdgeR.p.value.thr)
##     DoesThisGeneWhenPerturbedRegulateAKnownCADGene.tokeep <- CAD.gene.regulator.df %>%
##         separate(genes, into=c("Gene", "ENSGID"), sep=":") %>%
##         group_by(perturbation) %>%
##         summarize(DoesThisGeneWhenPerturbedRegulateAKnownCADGene = paste0(Gene, collapse="|"))
##     colnames(DoesThisGeneWhenPerturbedRegulateAKnownCADGene.tokeep)[colnames(DoesThisGeneWhenPerturbedRegulateAKnownCADGene.tokeep) == "perturbation"] <- "Gene"
##     DoesThisGeneWhenPerturbedRegulateAKnownCADGene.tokeep <- DoesThisGeneWhenPerturbedRegulateAKnownCADGene.tokeep %>%
##         ## mutate(Gene = convert.gene.names.toMatchENSGID(Gene)) %>%
##         mutate(OriginalGene = Gene) %>%
##         mutate(PreviousConversionGene = convert.gene.names.toMatchENSGID(Gene)) %>%
##         mutate(Gene = perturbseq.gene.names.to10X(Gene)) %>%
##         add.EntrezID.ENSGID %>%
##         as.data.frame
## }

## put together all
if(opt$perturbSeq){
    ## if(grepl("2kG.library", SAMPLE)) {
    ##     list.all <- list(ProgramsRegulatedByThisGene = ProgramsRegulatedByThisGene.tokeep,
    ##                      ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes = ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes.tokeep,
    ##                      PoPSEnrichedProgramsRegulatedByThisGene = PoPSEnrichedProgramsRegulatedByThisGene.tokeep,
    ##                      PoPSEnrichedProgramsInWhichGeneIsInTop300ZScoreSpecificGenes = PoPSEnrichedProgramsInWhichGeneIsInTop300ZScoreSpecificGenes.tokeep,
    ##                      PoPS.Score = PoPS.Score.tokeep,
    ##                      DoesThisGeneWhenPerturbedRegulateAKnownCADGene = DoesThisGeneWhenPerturbedRegulateAKnownCADGene.tokeep,
    ##                      IncNMFAnalysis = IncNMFAnalysis.tokeep) ## caveat: Not all perturbations are in common Gene Symbol format (e.g. FAM212A, Icam2, Slc9a3r2)
    ## } else {
        list.all <- list(ProgramsRegulatedByThisGene = ProgramsRegulatedByThisGene.tokeep,
                         ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes = ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes.tokeep,
                         IncNMFAnalysis = IncNMFAnalysis.tokeep) ## caveat: Not all perturbations are in common Gene Symbol format (e.g. FAM212A, Icam2, Slc9a3r2)
    ## }
} else {

    list.all <- list(ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes = ProgramsInWhichGeneIsInTopNListZScoreSpecificGenes.tokeep,
                     ## PoPSEnrichedProgramsInWhichGeneIsInTop300ZScoreSpecificGenes = PoPSEnrichedProgramsInWhichGeneIsInTop300ZScoreSpecificGenes.tokeep,
                     ## PoPS.Score = PoPS.Score.tokeep,
                     IncNMFAnalysis = IncNMFAnalysis.tokeep) 
}
## if(grepl("2kG.library", SAMPLE)) {
##     df.all <- Reduce(function(x, y, ...) full_join(x, y %>% select(-OriginalGene, -PreviousConversionGene), by = c("Gene", "ENSGID", "EntrezID"), ...), list.all)
## } else {
    df.all <- Reduce(function(x, y, ...) full_join(x, y, by = c("Gene", "Gene10X", "ENSGID"), ...), list.all)
## }
df.all[df.all == "NULL"] <- NA
na.row.index <- which(df.all$Gene %>% is.na & df.all$ENSGID %>% is.na & df.all$PoPS_Score %>% is.na)

## Perturb-seq specific trimming of duplicated Gene names
if(opt$perturbSeq) {
    ## if(grepl("2kG.library", SAMPLE)) {
    ##     df.all <- df.all %>% subset(!(is.na(Gene) & is.na(ProgramsRegulatedByThisGene) & !(Gene %in% c(theta.zscore %>% rownames, MAST.df$perturbation %>% unique)) & is.na(ENSGID))) %>% ## combine duplicated Gene names' rows
    ##         group_by(Gene, ENSGID) %>% ## consider implement match with 10X
    ##         summarize_all(., function(x) {
    ##             out <- paste0(x %>% unique, collapse="|")
    ##             if(grepl("TRUE", out %>% as.character)) return("TRUE") else return(out)
    ##         }) %>%
    ##         as.data.frame
    ## } else {
        df.all <- df.all %>% subset(!(is.na(Gene) & is.na(ProgramsRegulatedByThisGene) & !(Gene %in% c(theta.zscore %>% rownames, MAST.df$perturbation %>% unique)) & is.na(ENSGID))) %>% ## combine duplicated Gene names' rows
            group_by(Gene, Gene10X, ENSGID) %>%
            summarize_all(., function(x) {
                out <- paste0(x %>% unique, collapse="|")
                if(grepl("TRUE", out %>% as.character)) return("TRUE") else return(out)
            }) %>%
            mutate(IncNMFAnalysis = ifelse(IncNMFAnalysis == "TRUE", TRUE, FALSE),
                   perturbed_gene = ifelse(perturbed_gene == "TRUE", TRUE, FALSE)) %>%
            as.data.frame
    ## }        
}


## write table
print("writing prepared table to file")
write.table(df.all, file=paste0(OUTDIRSAMPLE, "/prepare_compute_enrichment.txt"), row.names=F, quote=F, sep="\t")
suppressPackageStartupMessages(library(optparse))

option.list <- list(
    # make_option("--outdir", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2106_FT007_Analysis/outputs/", help="Output directory"),
    # make_option("--datadir",type="character", default="/oak/stanford/groups/engreitz/Users/kangh/process_sequencing_data/210611_FT007_CM_CMO/gex_FT007_50k/outs/filtered_feature_bc_matrix/", help="Data directory"),
    # make_option("--project",type="character",default="/oak/stanford/groups/engreitz/Users/kangh/heart_atlas/2106_FT007_Analysis/",help="Project Directory"),
    make_option("--sampleName",type="character",default="gex_FT007_50k", help="Sample name"),
    make_option("--inputSeuratObject", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/data/FT010_fresh_3min.SeuratObject.RDS", help="Path to the Seurat Object"),
    make_option("--output_h5ad", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/data/FT010_fresh_3min.h5ad"),
    make_option("--output_gene_name_txt", type="character", default="/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/211011_Perturb-seq_Analysis_Pipeline_scratch/analysis/data/FT010_fresh_3min.h5ad.all.genes.txt"),
    make_option("--minUMIsPerCell", type="numeric", default=200),
    make_option("--minUniqueGenesPerCell", type="numeric", default=200)
    # make_option("--recompute", type="logical", default=F, help="T for recomputing UMAP from 10x count matrix")
)
opt <- parse_args(OptionParser(option_list=option.list))

suppressPackageStartupMessages(library(SeuratObject))
suppressPackageStartupMessages(library(Seurat))
suppressPackageStartupMessages(library(reticulate))
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(data.table))
suppressPackageStartupMessages(library(tidyr))
suppressPackageStartupMessages(library(Matrix))
# suppressPackageStartupMessages(library(readxl))
# suppressPackageStartupMessages(library(ggrepel))

# mytheme <- theme_classic() + theme(axis.text = element_text(size = 13), axis.title = element_text(size = 15), plot.title = element_text(hjust = 0.5))

# source("/oak/stanford/groups/engreitz/Users/kangh/2009_endothelial_perturbseq_analysis/topicModelAnalysis.functions.R")


## ## sdev for mouse ENCODE adrenal data
## opt$inputSeuratObject <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230117_snakemake_mouse_ENCODE_adrenal/analysis/data/mouse_ENCODE_adrenal.SeuratObject.RDS"


## ## sdev for IGVF_b01_LeftCortex
## opt$inputSeuratObject <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/data/IGVF_b01_LeftCortex.SeuratObject.RDS"
## opt$output_h5ad <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/data/IGVF_b01_LeftCortex.h5ad"
## opt$output_gene_name_txt <- "/oak/stanford/groups/engreitz/Users/kangh/IGVF/Cellular_Programs_Networks/230706_snakemake_igvf_b01_LeftCortex/analysis/data/IGVF_b01_LeftCortex.h5ad.all.genes.txt"
## opt$sampleName <- "IGVF_b01_LeftCortex"


#######################################################################
## Constants
# PROJECT=opt$project
# DATADIR=opt$datadir
# OUTDIR=opt$outdir
# FIGDIR= paste0(PROJECT, "/figures/")
SAMPLE=opt$sampleName
# OUTDIRSAMPLE=paste0(OUTDIR,"/",SAMPLE,"/")
# FIGDIRSAMPLE=paste0(FIGDIR,SAMPLE,"/")
# palette = colorRampPalette(c("#38b4f7", "white", "red"))(n=100)
# # create dir if not already
# check.dir <- c(OUTDIR, OUTDIRSAMPLE)
# invisible(lapply(check.dir, function(x) { if(!dir.exists(x)) dir.create(x) }))


## convert Seurat Object to h5ad
anndata <- import("anndata", convert = FALSE) ## load AnnData module

## load the Seurat object
s <- readRDS(opt$inputSeuratObject)

print("finished loading Seurat Object")

## set the default assay to RNA first to avoid errors such as "SCT is not an assay present in the given object. Available assays are: RNA" (can occur when retrieving gene names).
DefaultAssay(object = s) <- "RNA"
## only keep RNA counts and metadata. Remove any other items (e.g. PCA, SCTransform, UMAP) to avoid errors.
s.meta <- s[[]]
s.count <- s@assays$RNA@counts
s <- CreateSeuratObject(counts = s.count,
                        project = SAMPLE,
                        meta.data = s.meta)

## filter genes and cells again, for cases where the input file is a Seurat object and the create_seurat_object step was skipped
## remove likely non-protein-coding genes (LINC* lincRNAs, clone-based names such as AC######.*, mouse Gm*/ *Rik genes, and -ps pseudogenes) by gene-name pattern
tokeep <- which(!(grepl("^LINC|^[A-Za-z][A-Za-z][0-9][0-9][0-9][0-9][0-9][0-9]\\.|^Gm[0-9]|[0-9]Rik$|-ps", s %>% rownames)))
s.subset <- s[tokeep,]
print('finished subsetting to remove non-coding genes')
s.subset <- subset(s.subset, subset = nCount_RNA > opt$minUMIsPerCell & nFeature_RNA > opt$minUniqueGenesPerCell) # remove cells below the UMI and unique-gene thresholds (default 200 each)
print(paste0('removed cells with no more than ', opt$minUMIsPerCell, ' UMIs or no more than ', opt$minUniqueGenesPerCell, ' unique genes'))
## tokeep <- which(s.subset@assays$RNA@counts %>% apply(1, sum) > 10) # keep genes detected in more than 10 UMIs

## tokeep <- tryCatch(
##     {
##         which(s.subset@assays$RNA@counts %>% apply(1, function(x) ((x > 0) %>% as.numeric %>% sum > 10)))
##     },
##     error = function(cond) {
##         message("cannot remove genes expressed in fewer than 10 cells")
##         return(seq(1, s.subset %>% nrow(), by=1))
##     },
##     finally = {
##         message('removed genes expressed in less than 10 cells')
##     }
## )
## s <- s.subset[tokeep,]
s <- s.subset

adata <- anndata$AnnData(
    X = t(GetAssayData(object = s)),
    obs = data.frame(s@meta.data),
    var = s %>% rownames %>% as.data.frame %>% `rownames<-`(s %>% rownames) %>% `colnames<-`("Gene")
)

## adata <- anndata$AnnData(
##     X = t(GetAssayData(object = s) %>% as.matrix),
##     obs = data.frame(s@meta.data),
##     var = s %>% rownames %>% as.data.frame %>% `rownames<-`(s %>% rownames) %>% `colnames<-`("Gene")
## )

anndata$AnnData$write(adata, opt$output_h5ad)

## write gene names
gene.names <- s %>% rownames
write.table(gene.names, file=opt$output_gene_name_txt, col.names=F, row.names=F, quote=F, sep="\t")
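
After this conversion script runs, the .h5ad and the gene-name list can be cross-checked before using them downstream (e.g. as cNMF input). A minimal sketch in Python, assuming the default --output_h5ad and --output_gene_name_txt paths from the option list above (shortened here):

import scanpy as sc
import pandas as pd

adata = sc.read_h5ad("FT010_fresh_3min.h5ad")                    ## --output_h5ad
gene_names = pd.read_csv("FT010_fresh_3min.h5ad.all.genes.txt",
                         header=None, sep="\t")[0]               ## --output_gene_name_txt

print(adata.shape)                                               ## (cells, genes) after the UMI and unique-gene filters
assert adata.n_vars == len(gene_names)                           ## gene list should match the AnnData var axis
assert (adata.var["Gene"] == gene_names.values).all()            ## var was set to a single "Gene" column above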
import numpy as np
import scipy
import pandas as pd
# import scipy.sparse as sp
import scanpy as sc
import argparse
import os
import re
from cnmf import cNMF
from sklearn.decomposition import PCA


## argparse
parser = argparse.ArgumentParser()

## add arguments
parser.add_argument('--path_to_topics', type=str, help='path to the cNMF directory containing the topics (programs) to project data onto', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/analysis/top2000VariableGenes_acrossK/')
parser.add_argument('--topic_sampleName', type=str, help='sample name of the topics to project onto; use the same sample name as the cNMF directory', default='2kG.library_overdispersedGenes')
parser.add_argument('--X_normalized', type=str,  help='path to normalized input cell x gene matrix from cNMF pipeline', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/analysis/top2000VariableGenes_acrossK/2kG.library_overdispersedGenes/cnmf_tmp/2kG.library_overdispersedGenes.norm_counts.h5ad')
parser.add_argument('--outdir', dest = 'outdir', type=str, help = 'path to output directory', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/220716_snakemake_overdispersedGenes/analysis/top2000VariableGenes/2kG.library_overdispersedGenes/K60/threshold_0_2/')
parser.add_argument('--k', dest = 'k', type=int, help = 'number of components', default=60)
parser.add_argument('--density_threshold', dest = 'density_threshold', type=float, help = 'component spectra clustering threshold; use 2 for no filtering, 0.2 recommended (encoded as 0_2 in file paths)', default=0.2)

# ## sdev for K562 gwps
# parser.add_argument('--path_to_topics', type=str, help='path to the topic (cNMF directory) to project data on', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes_acrossK/')
# parser.add_argument('--topic_sampleName', type=str, help='sample name for topics to project on, use the same sample name as used for the cNMF directory', default='WeissmanK562gwps')
# parser.add_argument('--X_normalized', type=str,  help='path to normalized input cell x gene matrix from cNMF pipeline', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes_acrossK/WeissmanK562gwps/cnmf_tmp/WeissmanK562gwps.norm_counts.h5ad')
# parser.add_argument('--outdir', dest = 'outdir', type=str, help = 'path to output directory', default='/oak/stanford/groups/engreitz/Users/kangh/TeloHAEC_Perturb-seq_2kG/230104_snakemake_WeissmanLabData/analysis/top2000VariableGenes/WeissmanK562gwps/K35/threshold_0_2/')
# parser.add_argument('--k', dest = 'k', type=int, help = 'number of components', default='35')
# parser.add_argument('--density_threshold', dest = 'density_threshold', type=float, help = 'component spectra clustering threshold, 2 for no filtering, recommend 0_2 (means 0.2)', default="0.2")


args = parser.parse_args()

# ## sdev debug for Disha's error
# args.X_normalized = "/oak/stanford/groups/engreitz/Users/kangh/scratch_space/230612_debug_cNMF_pipeline_variance_explained/Merge_Pauletal_subset_SMC.norm_counts.h5ad"
# args.outdir = "/oak/stanford/groups/engreitz/Users/kangh/scratch_space/230612_debug_cNMF_pipeline_variance_explained/"



sample = args.topic_sampleName
# output_sample = args.output_sampleName
# tpm_counts_path = args.tpm_counts_path
OUTDIR = args.outdir
selected_K = args.k
density_threshold = args.density_threshold
output_directory = args.path_to_topics
run_name = args.topic_sampleName

if not os.path.exists(OUTDIR):
    raise Exception("Output directory does not exist")

cnmf_obj = cNMF(output_dir=output_directory, name=run_name)
usage_norm, gep_scores, gep_tpm, topgenes = cnmf_obj.load_results(K=selected_K, density_threshold=density_threshold)


# X_original = sc.read_h5ad(args.tpm_counts_path)
X_norm = sc.read_h5ad(args.X_normalized)
# sc.pp.normalize_per_cell(X_original, counts_per_cell_after=1e6) ## normalize X to TPM
# X_original.X[0:10,:].todense().sum(axis=1) ## check normalization results


## functions
def compute_Var(X):
    ## total variance: sum of per-gene (column) sample variances (ddof=1)
    if scipy.sparse.issparse(X):
        return np.sum(np.var(X.todense(), axis=0, ddof=1))
    else:
        return np.sum(np.var(X, axis=0, ddof=1))

# ## first turn X into TPM
# X_tpm_dense = X_original.X.todense()
# X_tpm_dense[0:10,].sum(axis=1) ## check TPM normalization

# X = X_norm.X.todense() ## 221203
X = X_norm.X
density_threshold_repr = str(density_threshold).replace('.', '_') ## cNMF encodes the threshold with '_' in file names (e.g. 0.2 -> "0_2")
H_path = cnmf_obj.paths['consensus_spectra__txt'] % (selected_K, density_threshold_repr) ## median spectra file (components x genes)
H_df = pd.read_csv(H_path, sep='\t', index_col=0).T
H = H_df.to_numpy()
H = (H/H.sum(0)) ## normalize each program's gene loadings to sum to 1
W_path = cnmf_obj.paths['consensus_usages__txt'] % (selected_K, density_threshold_repr) ## consensus usage file (cells x components)
W_df = pd.read_csv(W_path, sep='\t', index_col=0)
W = W_df.to_numpy()
WH = W @ H.T ## reconstructed cell x gene matrix
# diff = X - WH

# X_col_sd = np.std(X_tpm_dense, axis=0, ddof=1) ## column normalization, shape: (1,17472)
# type(gep_tpm) ## data frame
# gep_tpm.shape ## (17472, 60)
# H_tmp = gep_tpm.to_numpy().T ## shape: (60, 17472)
# H = gep_tpm.to_numpy().T / X_col_sd ## normalize TPM spectra matrix (gene x component)
# X = X_tpm_dense / X_col_sd
# X[:,0:10].std(axis=0) ## check if standard deviation of gene expression is 1
# W = usage_norm.to_numpy()
# W[0:10,].sum(axis=1) ## to check if sum of each cell's usage is 1
# WH = W @ H
diff = X - WH
diff_sumOfSquaresError = (np.asarray(diff)**2).sum()
# X_sumOfSquares = (np.asarray(X)**2).sum()
# WH_sumOfSquares = (np.asarray(WH)**2).sum()
Var_diff = compute_Var(diff)
Var_X = compute_Var(X)
TotalVarianceExplained = 1 - Var_diff / Var_X

def computeVarianceExplained(X, H, Var_X, i):
    ## variance explained by program i alone: B_k = X h_k / ||h_k||,
    ## then compare Var(X - B_k h_k^T) against the total variance Var(X)
    if not isinstance(H, pd.DataFrame):
        B_k = X @ H[i,:].T / np.sqrt((np.asarray(H[i,:])**2).sum())
        numerator = compute_Var(X - np.outer(B_k, H[i,:]))
    else:
        B_k = X @ H.iloc[i,:] / np.sqrt((H.iloc[i,:]**2).sum())
        numerator = compute_Var(X - np.outer(B_k, H.iloc[i,:]))
    return (1 - numerator / Var_X)

## initialize storage variable
V_k = np.empty([selected_K])

for i in range(selected_K):
    print(i)
    V_k[i] = computeVarianceExplained(X, H.T, Var_X, i)

ProgramID = ['K' + str(selected_K) + '_' + str(i+1) for i in range(selected_K)]

metrics_df = pd.DataFrame({'VarianceExplained': V_k,
                           'ProgramID': ProgramID })
metrics_summary = pd.DataFrame({'Sum' : metrics_df['VarianceExplained'].sum(),
                                'Median' : metrics_df['VarianceExplained'].median(),
                                'Max' : metrics_df['VarianceExplained'].max(),
                                'Total' : TotalVarianceExplained},
                               index = [0])
metrics_df.to_csv(os.path.join(OUTDIR, "metrics.varianceExplained.df.txt"), index = None, sep="\t")
metrics_summary.to_csv(os.path.join(OUTDIR, "summary.varianceExplained.df.txt"), index = None, sep="\t")
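
For reference, the two variance-explained quantities computed above can be written compactly (a notational summary of the code, not an additional method): with $X$ the normalized cell-by-gene matrix, $W$ the consensus usages, $H$ the column-normalized consensus spectra (genes by programs), $h_k$ the gene-loading vector of program $k$, and $\operatorname{Var}(\cdot)$ the sum of per-gene sample variances,

$$
\text{TotalVarianceExplained} = 1 - \frac{\operatorname{Var}\!\left(X - W H^{\top}\right)}{\operatorname{Var}(X)},
\qquad
b_k = \frac{X h_k}{\lVert h_k \rVert_2},
\qquad
V_k = 1 - \frac{\operatorname{Var}\!\left(X - b_k h_k^{\top}\right)}{\operatorname{Var}(X)}.
$$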